-rw-r--r-- | .github/workflows/format.yml | 13
-rw-r--r-- | .gitignore | 11
-rw-r--r-- | .pre-commit-config.yaml | 14
-rw-r--r-- | .vscode/settings.json | 26
-rw-r--r-- | config.cluster.yaml | 34
-rw-r--r-- | config.philipp.yaml | 34
-rw-r--r-- | config.yaml | 34
-rw-r--r-- | data/own/Silja_erklärt_die_welt.flac (renamed from Silja_erklärt_die_welt.flac) | bin | 8211720 -> 8211720 bytes
-rw-r--r-- | data/own/marvin_rede.flac (renamed from marvin_rede.flac) | bin | 3976026 -> 3976026 bytes
-rw-r--r-- | data/own/valentins_diss_gegen_marvin.flac (renamed from valentins_diss_gegen_marvin.flac) | bin | 706852 -> 706852 bytes
-rw-r--r-- | data/tokenizers/char_tokenizer_german.json | 38
-rwxr-xr-x | hpc_train.sh | 2
-rw-r--r-- | mypy.ini | 17
-rw-r--r-- | poetry.lock | 157
-rw-r--r-- | pyproject.toml | 4
-rw-r--r-- | requirements.txt | 41
-rw-r--r-- | swr2_asr/__main__.py | 12
-rw-r--r-- | swr2_asr/inference.py | 146
-rw-r--r-- | swr2_asr/model_deep_speech.py | 74
-rw-r--r-- | swr2_asr/tokenizer.py | 393
-rw-r--r-- | swr2_asr/train.py | 424
-rw-r--r-- | swr2_asr/utils/__init__.py (renamed from tests/__init__.py) | 0
-rw-r--r-- | swr2_asr/utils/data.py (renamed from swr2_asr/utils.py) | 244
-rw-r--r-- | swr2_asr/utils/decoder.py | 26
-rw-r--r-- | swr2_asr/utils/loss_scores.py (renamed from swr2_asr/loss_scores.py) | 154
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 122
-rw-r--r-- | swr2_asr/utils/visualization.py | 22
27 files changed, 943 insertions, 1099 deletions
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 7138f69..a939663 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -13,17 +13,10 @@ jobs: python-version: "3.10" - name: Install dependencies run: | - python -m pip install -U pip poetry - poetry --version - poetry check --no-interaction - poetry config virtualenvs.in-project true - poetry install --no-interaction + python -m pip install -r requirements.txt - name: Check for format issues run: | - make format-check + black --check swr2_asr - name: Run pylint run: | - poetry run pylint --fail-under=9 swr2_asr - # run: | - # poetry run mypy --strict swr2_asr - #- name: Run mypy + pylint --fail-under=9 swr2_asr @@ -1,5 +1,7 @@ # Training files -data/ +data/* +!data/tokenizers +!data/own # Mac **/.DS_Store @@ -163,10 +165,3 @@ dmypy.json # Cython debug symbols cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index af91a5e..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -repos: - - repo: https://github.com/psf/black.git - rev: 23.7.0 - hooks: - - id: black - - repo: https://github.com/pre-commit/mirrors-mypy - rev: '' - hooks: - - id: mypy - - repo: https://github.com/pylint-dev/pylint.git - rev: v2.17.5 - hooks: - - id: pylint -
\ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 0054bca..1adbc18 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,14 +1,14 @@ { - "[python]": { - "editor.formatOnType": true, - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, - "editor.rulers": [88, 120], - }, - "black-formatter.importStrategy": "fromEnvironment", - "python.analysis.typeCheckingMode": "basic", - "ruff.organizeImports": true, - "ruff.importStrategy": "fromEnvironment", - "ruff.fixAll": true, - "ruff.run": "onType" -}
\ No newline at end of file + "[python]": { + "editor.formatOnType": true, + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.rulers": [88, 120] + }, + "black-formatter.importStrategy": "fromEnvironment", + "python.analysis.typeCheckingMode": "off", + "ruff.organizeImports": true, + "ruff.importStrategy": "fromEnvironment", + "ruff.fixAll": true, + "ruff.run": "onType" +} diff --git a/config.cluster.yaml b/config.cluster.yaml new file mode 100644 index 0000000..a3def0e --- /dev/null +++ b/config.cluster.yaml @@ -0,0 +1,34 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 64 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 150 + eval_every_n: 5 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically + +dataset: + download: True + dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: "data/runs/epoch31" # path to load model from + model_save_path: "data/runs/epoch" # path to save model to + +inference: + model_load_path: ~ # path to load model from + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file diff --git a/config.philipp.yaml b/config.philipp.yaml new file mode 100644 index 0000000..f72ce2e --- /dev/null +++ b/config.philipp.yaml @@ -0,0 +1,34 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 150 + eval_every_n: 5 # evaluate every n epochs + num_workers: 4 # number of workers for dataloader + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically + +dataset: + download: true + dataset_root_path: "data" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: false # set to True if you want to use limited supervision + dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%) + shuffle: true + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: "data/runs/epoch31" # path to load model from + model_save_path: "data/runs/epoch" # path to save model to + +inference: + model_load_path: "data/runs/epoch30" # path to load model from + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..e5ff43a --- /dev/null +++ b/config.yaml @@ -0,0 +1,34 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to + +inference: + model_load_path: "YOUR/PATH" # path to load model from + beam_width: 10 # beam width for beam search + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file diff --git a/Silja_erklärt_die_welt.flac b/data/own/Silja_erklärt_die_welt.flac Binary files differindex da51c0c..da51c0c 100644 --- a/Silja_erklärt_die_welt.flac +++ b/data/own/Silja_erklärt_die_welt.flac diff --git a/marvin_rede.flac b/data/own/marvin_rede.flac Binary files differindex 43d171d..43d171d 100644 --- a/marvin_rede.flac +++ b/data/own/marvin_rede.flac diff --git a/valentins_diss_gegen_marvin.flac b/data/own/valentins_diss_gegen_marvin.flac Binary files differindex 9c15644..9c15644 100644 --- a/valentins_diss_gegen_marvin.flac +++ b/data/own/valentins_diss_gegen_marvin.flac diff --git a/data/tokenizers/char_tokenizer_german.json b/data/tokenizers/char_tokenizer_german.json new file mode 100644 index 0000000..20db079 --- /dev/null +++ b/data/tokenizers/char_tokenizer_german.json @@ -0,0 +1,38 @@ +_ 0 +<BLANK> 1 +<UNK> 2 +<SPACE> 3 +a 4 +b 5 +c 6 +d 7 +e 8 +f 9 +g 10 +h 11 +i 12 +j 13 +k 14 +l 15 +m 16 +n 17 +o 18 +p 19 +q 20 +r 21 +s 22 +t 23 +u 24 +v 25 +w 26 +x 27 +y 28 +z 29 +é 30 +à 31 +ä 32 +ö 33 +ß 34 +ü 35 +- 36 +' 37 diff --git a/hpc_train.sh b/hpc_train.sh index c7d1636..2280087 100755 --- a/hpc_train.sh +++ b/hpc_train.sh @@ -1,3 +1,3 @@ #!/bin/sh -yes no | python -m swr2_asr.train --epochs=100 --batch_size=30 --dataset_path=/mnt/lustre/mladm/mfa252/data +yes no | python -m swr2_asr.train --config_path config.cluster.yaml diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index c13aa05..0000000 --- a/mypy.ini +++ /dev/null @@ -1,17 +0,0 @@ -[mypy-AudioLoader.*] -ignore_missing_imports = True - -[mypy-torchaudio.*] -ignore_missing_imports = true - -[mypy-torch.*] -ignore_missing_imports = true - -[mypy-click.*] -ignore_missing_imports = true - -[mypy-tokenizers.*] -ignore_missing_imports = true - -[mypy-tqmd.*] -ignore_missing_imports = true
\ No newline at end of file diff --git a/poetry.lock b/poetry.lock index fcac817..a1f916b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "astroid" @@ -21,33 +21,33 @@ wrapt = [ [[package]] name = "black" -version = "23.7.0" +version = "23.9.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, - {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, - {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, - {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, - {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, - {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, - {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, - {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, - {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, - {file = "black-23.7.0-py3-none-any.whl", hash 
= "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, - {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"}, + {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"}, + {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"}, + {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"}, + {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"}, + {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"}, + {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"}, + {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"}, + {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"}, + {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"}, + {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"}, ] [package.dependencies] @@ -57,6 +57,7 @@ packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} 
+typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] @@ -80,28 +81,28 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cmake" -version = "3.27.2" +version = "3.27.4.1" description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" optional = false python-versions = "*" files = [ - {file = "cmake-3.27.2-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:96ac856c4d6b2104408848f0005a8ab2229d4135b171ea9a03e8c33039ede420"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:11fe6129d07982721c5965fd804a4056b8c6e9c4f482ac9e0fe41bb3abc1ab5f"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:f0c64e89e2ea59592980c4fe3821d712fee0e74cf87c2aaec5b3ab9aa809a57c"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ca7650477dff2a1138776b28b79c0e99127be733d3978922e8f87b56a433eed6"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ab2e40fe09e76a7ef67da2bbbf7a4cd1f52db4f1c7b6ccdda2539f918830343a"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:980ee19f12c808cb8ddb56fdcee832501a9f9631799d8b4fc625c0a0b5fb4c55"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:115d30ca0760e3861d9ad6b3288cd11ee72a785b81227da0c1765d3b84e2c009"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:efc338c939d6d435890a52458a260bf0942bd8392b648d7532a72c1ec0764e18"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:7f7438c60ccc01765b67abfb1797787c3b9459d500a804ed70a4cc181bc02204"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:294f008734267e0eee1574ad1b911bed137bc907ab19d60a618dab4615aa1fca"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:197a34dc62ee149ced343545fac67e5a30b93fda65250b065726f86ce92bdada"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:afb46ad883b174fb64347802ba5878423551dbd5847bb64669c39a5957c06eb7"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:83611ffd155e270a6b13bbf0cfd4e8688ebda634f448aa2e3734006c745bf33f"}, - {file = "cmake-3.27.2-py2.py3-none-win32.whl", hash = "sha256:53e12deb893da935e236f93accd47dbe2806620cd7654986234dc4487cc49652"}, - {file = "cmake-3.27.2-py2.py3-none-win_amd64.whl", hash = "sha256:611f9722c68c40352d38a6c01960ab038c3d0419e7aee3bf18f95b23031e0dfe"}, - {file = "cmake-3.27.2-py2.py3-none-win_arm64.whl", hash = "sha256:30620326b51ac2ce0d8f476747af6367a7ea21075c4d065fad9443904b07476a"}, - {file = "cmake-3.27.2.tar.gz", hash = "sha256:7cd6e2d7d5a1125f8c26c4f65214f8c942e3f276f98c16cb62ae382c35609f25"}, + {file = "cmake-3.27.4.1-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:65b79f1e8b6fa254697ee0b411aa4dff0d2309c1405af3448adf06cbd7ef0ac5"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:4a1d22ee72dcdc32d0f8bbf5691d2e9367585db8bfeafe7cffa2c4274127a801"}, + {file = 
"cmake-3.27.4.1-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:37e9cad75184fbefe837311528d026901278b606707e9d14b58e3767d49d0aa6"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2f217fb281b068696fdcc4b198de62e9ded8bc0a93877684afc59db3507ccb44"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:a3bd8b9d0e294bd2b3ce27a850c9d924aee7e4f4c0bb56d66641cc1544314f58"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:871c8b5eaac959f079c2389c7a7f198fa5f86a029e8726fcb1f3e13d030c33e9"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ec4a5bc2376dfc57065bfde6806183b331165f33457b7cc0fc0511260dde7c72"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c36eb106dec60198264b25d4bd23cd9ea30b0af9200a143ec1db887c095306f7"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:2b0b53ec2e45cfe9982d0adf833b3519efc328c1e3cffae4d237841a1ed6edf4"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:a504815bcba0ece9aafb48a6b7770d6479756fda92f8b62f9ab7ff8a403a12d5"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:1aca07fccfa042a0379bb027e30f090a8239b18fd3f959391c5d77c22dd0a809"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:315eb37233e2d0b8fa01580e33439eaeaef65f1e41ad9ca269cbe68cc0a039a4"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:94fec5a8bae1f3b62d8a8653ebcb7fa4007e2d0e713f94e4b2089f708c13548f"}, + {file = "cmake-3.27.4.1-py2.py3-none-win32.whl", hash = "sha256:8249a1ba7901b53661b44e59bdf6fd1e977e10843795788efe25d3374de6ed95"}, + {file = "cmake-3.27.4.1-py2.py3-none-win_amd64.whl", hash = "sha256:b72db11e13eafb46b9c53797d141e89886293db768feabef4b841accf666de54"}, + {file = "cmake-3.27.4.1-py2.py3-none-win_arm64.whl", hash = "sha256:0fb68660ce3954de99d1f41bedcf87063325c4cc891003f12de36472fa1efa28"}, + {file = "cmake-3.27.4.1.tar.gz", hash = "sha256:70526bbff5eeb7d4d6b921af1b80d2d29828302882f94a2cba93ad7d469b90f6"}, ] [package.extras] @@ -1083,6 +1084,55 @@ files = [ six = ">=1.5" [[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] name = "ruff" version = "0.0.285" description = "An extremely fast Python linter, written in Rust." @@ -1110,19 +1160,19 @@ files = [ [[package]] name = "setuptools" -version = "68.1.2" +version = "68.2.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.1.2-py3-none-any.whl", hash = "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b"}, - {file = "setuptools-68.1.2.tar.gz", hash = "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d"}, + {file = "setuptools-68.2.1-py3-none-any.whl", hash = "sha256:eff96148eb336377ab11beee0c73ed84f1709a40c0b870298b0d058828761bae"}, + {file = "setuptools-68.2.1.tar.gz", hash = "sha256:56ee14884fd8d0cd015411f4a13f40b4356775a0aefd9ebc1d3bfb9a1acb32f1"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5,<=7.1.2)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", 
"jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -1370,17 +1420,6 @@ tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] -name = "types-tqdm" -version = "4.66.0.2" -description = "Typing stubs for tqdm" -optional = false -python-versions = "*" -files = [ - {file = "types-tqdm-4.66.0.2.tar.gz", hash = "sha256:9553a5e44c1d485fce19f505b8bd65c0c3e87e870678d1f2ed764ae59a55d45f"}, - {file = "types_tqdm-4.66.0.2-py3-none-any.whl", hash = "sha256:13dddd38908834abdf0acdc2b70cab7ac4bcc5ad7356ced450471662e58a0ffc"}, -] - -[[package]] name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" @@ -1492,4 +1531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "097dc1f31ef2717f69dd11bdf98584ff46ef5f563e97c76ffe169591105d4287" +content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf" diff --git a/pyproject.toml b/pyproject.toml index 6f74b49..f6d19dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,18 +17,18 @@ mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" matplotlib = "^3.7.2" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" mypy = "^1.5.1" pylint = "^2.17.5" ruff = "^0.0.285" -types-tqdm = "^4.66.0.1" [tool.ruff] select = ["E", "F", "B", "I"] fixable = ["ALL"] -line-length = 120 +line-length = 100 target-version = "py310" [tool.black] diff --git a/requirements.txt b/requirements.txt index 409d7b5..040fed0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,40 @@ -AudioLoader @ git+https://github.com/marvinborner/AudioLoader.git -mido==1.3.0 -filelock==3.12.2 +astroid==2.15.6 +black==23.9.1 +click==8.1.7 +contourpy==1.1.0 +cycler==0.11.0 +dill==0.3.7 +filelock==3.12.3 +fonttools==4.42.1 +isort==5.12.0 Jinja2==3.1.2 +kiwisolver==1.4.5 +lazy-object-proxy==1.9.0 MarkupSafe==2.1.3 +matplotlib==3.7.2 +mccabe==0.7.0 +mido==1.3.0 mpmath==1.3.0 +mypy==1.5.1 +mypy-extensions==1.0.0 networkx==3.1 numpy==1.25.2 -sympy==1.12 -torch==2.0.1 -torchaudio==2.0.2 -black==23.7.0 -click==8.1.7 -mypy==1.5.1 +packaging==23.1 +pathspec==0.11.2 +Pillow==10.0.0 +platformdirs==3.10.0 pylint==2.17.5 +pyparsing==3.0.9 +python-dateutil==2.8.2 +PyYAML==6.0.1 +ruff==0.0.285 +six==1.16.0 +sympy==1.12 +tokenizers==0.13.3 +tomli==2.0.1 +tomlkit==0.12.1 +torch==2.0.0 +torchaudio==2.0.1 tqdm==4.66.1 typing_extensions==4.7.1 +wrapt==1.15.0 diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py deleted file mode 100644 index be294fb..0000000 --- a/swr2_asr/__main__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Main entrypoint for swr2-asr.""" -import torch -import torchaudio - -if __name__ == "__main__": - # test if GPU is available - print("GPU available: ", torch.cuda.is_available()) - - # test if torchaudio is installed correctly - print("torchaudio version: ", torchaudio.__version__) - print("torchaudio backend: ", torchaudio.get_audio_backend()) - print("torchaudio info: ", torchaudio.get_audio_backend()) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..3f6a44e 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,35 +1,20 @@ 
"""Training script for the ASR model.""" +import click import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio +import yaml -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, collapse_repeated=True): +def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] + blank_label = tokenizer.get_blank_token() decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for args in arg_maxes: decode = [] for j, index in enumerate(args): if index != blank_label: @@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): return decodes -def main() -> None: +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +@click.option( + "--file_path", + help="Path to audio file", + type=click.Path(exists=True), +) +def main(config_path: str, file_path: str) -> None: """inference function.""" + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) - device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) - - tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") - - spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=0.1, - batch_size=30, - epochs=100, - ) + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + inference_config = config_dict.get("inference", {}) + + if inference_config["device"] == "cpu": + device = "cpu" + elif inference_config["device"] == "cuda": + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) # pylint: disable=no-member + + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - checkpoint = torch.load("model8", map_location=device) - state_dict = { - k[len("module.") :] if k.startswith("module.") else k: v - for k, v in checkpoint["model_state_dict"].items() - } - model.load_state_dict(state_dict) - - # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") - if sample_rate != spectrogram_hparams["sample_rate"]: - resampler = 
torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) + checkpoint = torch.load(inference_config["model_load_path"], map_location=device) + model.load_state_dict(checkpoint["model_state_dict"], strict=True) + model.eval() + + waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member + if waveform.shape[0] != 1: + waveform = waveform[1] + waveform = waveform.unsqueeze(0) + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) + sample_rate = 16000 + + data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"]) + + spec = data_processing(waveform).squeeze(0).transpose(0, 1) - spec = ( - torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - specs = [spec] - specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) + spec = spec.unsqueeze(0) + spec = spec.transpose(1, 2) + spec = spec.unsqueeze(0) + output = model(spec) # pylint: disable=not-callable + output = F.log_softmax(output, dim=2) # (batch, time, n_class) + decoded_preds = greedy_decoder(output, tokenizer) - output = model(specs) - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - decodes = greedy_decoder(output, tokenizer) - print(decodes) + print(decoded_preds) if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index dd07ff9..73f5a81 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,4 +1,8 @@ -"""Main definition of model""" +"""Main definition of the Deep speech 2 model by Baidu Research. 
+ +Following definition by Assembly AI +(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) +""" import torch.nn.functional as F from torch import nn @@ -6,8 +10,8 @@ from torch import nn class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" - def __init__(self, n_feats: int): - super().__init__() + def __init__(self, n_feats): + super(CNNLayerNorm, self).__init__() self.layer_norm = nn.LayerNorm(n_feats) def forward(self, data): @@ -18,34 +22,22 @@ class CNNLayerNorm(nn.Module): class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf + except with layer norm instead of batch norm + """ - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() + def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats): + super(ResidualCNN, self).__init__() self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) + self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.layer_norm1 = CNNLayerNorm(n_feats) self.layer_norm2 = CNNLayerNorm(n_feats) def forward(self, data): - """x (batch, channel, feature, time)""" + """data (batch, channel, feature, time)""" residual = data # (batch, channel, feature, time) data = self.layer_norm1(data) data = F.gelu(data) @@ -60,18 +52,12 @@ class ResidualCNN(nn.Module): class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" + """Bidirectional GRU layer""" - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() + def __init__(self, rnn_dim, hidden_size, dropout, batch_first): + super(BidirectionalGRU, self).__init__() - self.bi_gru = nn.GRU( + self.BiGRU = nn.GRU( # pylint: disable=invalid-name input_size=rnn_dim, hidden_size=hidden_size, num_layers=1, @@ -82,11 +68,11 @@ class BidirectionalGRU(nn.Module): self.dropout = nn.Dropout(dropout) def forward(self, data): - """data (batch, time, feature)""" + """x (batch, time, feature)""" data = self.layer_norm(data) data = F.gelu(data) + data, _ = self.BiGRU(data) data = self.dropout(data) - data, _ = self.bi_gru(data) return data @@ -94,18 +80,14 @@ class SpeechRecognitionModel(nn.Module): """Speech Recognition Model Inspired by DeepSpeech 2""" def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, + self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1 ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + super(SpeechRecognitionModel, self).__init__() + n_feats = n_feats // 2 + self.cnn = nn.Conv2d( + 1, 32, 3, stride=stride, padding=3 // 2 + ) # cnn for extracting heirachal features + # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ @@ -133,7 +115,7 @@ class SpeechRecognitionModel(nn.Module): ) def forward(self, data): - """data (batch, channel, feature, time)""" + """x (batch, channel, feature, time)""" data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() diff --git 
a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py deleted file mode 100644 index 2e2fb57..0000000 --- a/swr2_asr/tokenizer.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Tokenizer for use with Multilingual Librispeech""" -import json -import os -from dataclasses import dataclass -from typing import Type - -import click -from tokenizers import Tokenizer, normalizers -from tokenizers.models import BPE -from tokenizers.pre_tokenizers import Whitespace -from tokenizers.trainers import BpeTrainer -from tqdm import tqdm - - -class TokenizerType: - """Base class for tokenizers. - - exposes the same interface as tokenizers from the huggingface library""" - - def encode(self, sequence: str) -> list[int]: - """Encode a sequence to a list of integer labels""" - raise NotImplementedError - - def decode(self, labels: list[int], remove_special_tokens: bool) -> str: - """Decode a list of integer labels to a sequence""" - raise NotImplementedError - - def decode_batch(self, labels: list[list[int]]) -> list[str]: - """Decode a batch of integer labels to a list of sequences""" - raise NotImplementedError - - def get_vocab_size(self) -> int: - """Get the size of the vocabulary""" - raise NotImplementedError - - def enable_padding( - self, - length: int = -1, - direction: str = "right", - pad_id: int = 0, - pad_type_id: int = 0, - pad_token: str = "[PAD]", - ) -> None: - """Enable padding for the tokenizer""" - raise NotImplementedError - - def save(self, path: str) -> None: - """Save the tokenizer to a file""" - raise NotImplementedError - - @staticmethod - def from_file(path: str) -> "TokenizerType": - """Load the tokenizer from a file""" - raise NotImplementedError - - -MyTokenizerType = Type[TokenizerType] - - -@dataclass -class Encoding: - """Simple dataclass to represent an encoding""" - - ids: list[int] - tokens: list[str] - - -class CharTokenizer(TokenizerType): - """Very simple tokenizer for use with Multilingual Librispeech - - Simply checks what characters are in the dataset and uses them as tokens. - - Exposes the same interface as tokenizers from the huggingface library, i.e. - encode, decode, decode_batch, get_vocab_size, save, from_file and train. 
- """ - - def __init__(self): - self.char_map = {} - self.index_map = {} - self.add_tokens(["<UNK>", "<SPACE>"]) - - def add_tokens(self, tokens: list[str]): - """Manually add tokens to the tokenizer - - Args: - tokens (list[str]): List of tokens to add - """ - for token in tokens: - if token not in self.char_map: - self.char_map[token] = len(self.char_map) - self.index_map[len(self.index_map)] = token - - def train(self, dataset_path: str, language: str, split: str): - """Train the tokenizer on the given dataset - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - chars: set = set() - for s_plit in splits: - transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - - with open( - transcript_path, - "r", - encoding="utf-8", - ) as file: - lines = file.readlines() - lines = [line.split(" ", 1)[1] for line in lines] - lines = [line.strip() for line in lines] - - for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"): - chars.update(line) - offset = len(self.char_map) - for i, char in enumerate(chars): - i += offset - self.char_map[char] = i - self.index_map[i] = char - - def encode(self, sequence: str): - """Use a character map and convert text to an integer sequence - - automatically maps spaces to <SPACE> and makes everything lowercase - unknown characters are mapped to the <UNK> token - - """ - int_sequence = [] - sequence = sequence.lower() - for char in sequence: - if char == " ": - mapped_char = self.char_map["<SPACE>"] - elif char not in self.char_map: - mapped_char = self.char_map["<UNK>"] - else: - mapped_char = self.char_map[char] - int_sequence.append(mapped_char) - return Encoding(ids=int_sequence, tokens=list(sequence)) - - def decode(self, labels: list[int], remove_special_tokens: bool = True): - """Use a character map and convert integer labels to an text sequence - - Args: - labels (list[int]): List of integer labels - remove_special_tokens (bool): Whether to remove special tokens. - Defaults to True. 
- """ - string = [] - for i in labels: - if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>": - continue - if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>": - string.append(" ") - string.append(self.index_map[f"{i}"]) - return "".join(string).replace("<SPACE>", " ") - - def decode_batch(self, labels: list[list[int]]): - """Use a character map and convert integer labels to an text sequence""" - strings = [] - for label in labels: - string = [] - for i in label: - if self.index_map[i] == "<UNK>": - continue - if self.index_map[i] == "<SPACE>": - string.append(" ") - string.append(self.index_map[i]) - strings.append("".join(string).replace("<SPACE>", " ")) - return strings - - def get_vocab_size(self): - """Get the size of the vocabulary""" - return len(self.char_map) - - def save(self, path: str): - """Save the tokenizer to a file""" - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as file: - # save it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - json.dump( - {"char_map": self.char_map, "index_map": self.index_map}, - file, - ensure_ascii=False, - ) - - @staticmethod - def from_file(path: str) -> "CharTokenizer": - """Load the tokenizer from a file""" - char_tokenizer = CharTokenizer() - with open(path, "r", encoding="utf-8") as file: - # load it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - saved_file = json.load(file) - char_tokenizer.char_map = saved_file["char_map"] - char_tokenizer.index_map = saved_file["index_map"] - - return char_tokenizer - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use (including all)") -@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") -@click.option("--vocab_size", default=2000, help="Size of the vocabulary") -def train_bpe_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_bpe_tokenizer( - dataset_path, - language, - split, - out_path, - vocab_size, - ) - - -def train_bpe_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - vocab_size (int): Size of the vocabulary - """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - lines = [] - - for s_plit in splits: - transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - if not os.path.exists(transcripts_path): - raise FileNotFoundError( - f"Could not find transcripts.txt in {transcripts_path}. " - "Please make sure that the dataset is downloaded." 
- ) - - with open( - transcripts_path, - "r", - encoding="utf-8", - ) as file: - sp_lines = file.readlines() - sp_lines = [line.split(" ", 1)[1] for line in sp_lines] - sp_lines = [line.strip() for line in sp_lines] - - lines.append(sp_lines) - - bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) - - initial_alphabet = [ - " ", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "ä", - "ö", - "ü", - "ß", - "-", - "é", - "è", - "à", - "ù", - "ç", - "â", - "ê", - "î", - "ô", - "û", - "ë", - "ï", - "ü", - ] - - - trainer = BpeTrainer( - special_tokens=["[UNK]"], - vocab_size=vocab_size, - initial_alphabet=initial_alphabet, - show_progress=True, - ) # type: ignore - - bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore - - bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore - - bpe_tokenizer.train_from_iterator(lines, trainer=trainer) - - bpe_tokenizer.save(out_path) - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use") -@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") -def train_char_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_char_tokenizer(dataset_path, language, split, out_path) - - -def train_char_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. 
- - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - """ - char_tokenizer = CharTokenizer() - - char_tokenizer.train(dataset_path, language, split) - - char_tokenizer.save(out_path) - - -if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - - print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids)) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9f12bcb..ffdae73 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,49 +5,17 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader -from tqdm import tqdm +from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer -from swr2_asr.utils import MLSDataset, Split, collate_fn,plot - -from .loss_scores import cer, wer - - -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): - """Greedily decode a sequence.""" - print("output shape", output.shape) - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets +from swr2_asr.utils.data import DataProcessing, MLSDataset, Split +from swr2_asr.utils.decoder import greedy_decoder +from swr2_asr.utils.tokenizer import CharTokenizer + +from .utils.loss_scores import cer, wer class IterMeter: @@ -61,252 +29,292 @@ class IterMeter: self.val += 1 def get(self): - """get""" + """get steps""" return self.val -def train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, -): - """Train""" +class TrainArgs(TypedDict): + """Type for the arguments of the training function.""" + + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + train_loader: DataLoader + criterion: nn.CTCLoss + optimizer: optim.AdamW + scheduler: optim.lr_scheduler.OneCycleLR + epoch: int + iter_meter: IterMeter + + +def train(train_args) -> float: + """Train + Args: + model: model + device: device type + train_loader: train dataloader + criterion: loss function + optimizer: optimizer + scheduler: learning rate scheduler + epoch: epoch number + iter_meter: iteration meter + + Returns: + avg_train_loss: avg_train_loss for the epoch + + Information: + spectrograms: (batch, time, feature) + labels: (batch, label_length) + + model output: (batch,time, n_class) + + """ + # get values from train_args: + ( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) = train_args.values() + model.train() - 
print(f"Epoch: {epoch}") - losses = [] - for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + print(f"training batch {epoch}") + train_losses = [] + for _data in tqdm(train_loader, desc="Training batches"): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) + train_losses.append(loss) loss.backward() optimizer.step() scheduler.step() iter_meter.step() + avg_train_loss = sum(train_losses) / len(train_losses) + print(f"Train set: Average loss: {avg_train_loss:.2f}") + return avg_train_loss + - losses.append(loss.item()) +class TestArgs(TypedDict): + """Type for the arguments of the test function.""" - print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") - return sum(losses) / len(losses) + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + test_loader: DataLoader + criterion: nn.CTCLoss + tokenizer: CharTokenizer + decoder: str -def test(model, device, test_loader, criterion, tokenizer): +def test(test_args: TestArgs) -> tuple[float, float, float]: """Test""" print("\nevaluating...") + + # get values from test_args: + model, device, test_loader, criterion, tokenizer, decoder = test_args.values() + + if decoder == "greedy": + decoder = greedy_decoder + model.eval() test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for _data in test_loader: - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + for _data in tqdm(test_loader, desc="Validation Batches"): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output=output.transpose(0, 1), - labels=labels, - label_lengths=_data["utterance_length"], - tokenizer=tokenizer, + output.transpose(0, 1), labels, label_lengths, tokenizer ) - for j, pred in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], pred)) - test_wer.append(wer(decoded_targets[j], pred)) + for j, _ in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) print( - f"Test set: Average loss:\ - {test_loss}, Average CER: {None} Average WER: {None}\n" + f"Test set: \ + Average loss: {test_loss:.4f}, \ + Average CER: {avg_cer:4f} \ + Average WER: {avg_wer:.4f}\n" ) return test_loss, avg_cer, avg_wer -def run( - learning_rate: float, - batch_size: int, - epochs: int, - load: bool, - path: str, - dataset_path: str, - language: str, -) -> None: - """Runs the training script.""" +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml 
config file", + type=click.Path(exists=True), +) +def main(config_path: str): + """Main function for training the model. + + Gets all configuration arguments from yaml config file. + """ use_cuda = torch.cuda.is_available() - torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - # device = torch.device("mps") - # load dataset + torch.manual_seed(7) + + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) + train_dataset = MLSDataset( - dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True + dataset_config["dataset_root_path"], + dataset_config["language_name"], + Split.TRAIN, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True + dataset_config["dataset_root_path"], + dataset_config["language_name"], + Split.TEST, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # load tokenizer (bpe by default): - if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): - print("There is no tokenizer available. 
Do you want to train it on the dataset?") - input("Press Enter to continue...") - train_char_tokenizer( - dataset_path=dataset_path, - language=language, - split="all", - out_path="data/tokenizers/char_tokenizer_german.json", - ) - - tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} - train_dataset.set_tokenizer(tokenizer) # type: ignore - valid_dataset.set_tokenizer(tokenizer) # type: ignore - - print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]}) + valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]}) train_loader = DataLoader( - train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: collate_fn(x), + dataset=train_dataset, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], + collate_fn=train_data_processing, + **kwargs, ) - valid_loader = DataLoader( - valid_dataset, - batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: collate_fn(x), + dataset=valid_dataset, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], + collate_fn=valid_data_processing, + **kwargs, ) - # enable flag to find the most compatible algorithms in advance - if use_cuda: - torch.backends.cudnn.benchmark = True # pylance: disable=no-member - model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - print(tokenizer.encode(" ")) - print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) - if load: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["model_state_dict"]) - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - epoch = checkpoint["epoch"] - loss = checkpoint["loss"] + + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) + criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) + prev_epoch = 0 + + if checkpoints_config["model_load_path"] is not None: + checkpoint = 
torch.load(checkpoints_config["model_load_path"], map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - for epoch in range(1, epochs + 1): - loss = train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, - ) - test_loss, avg_cer, avg_wer = test( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - ) - print("saving epoch", str(epoch)) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } + + train_loss = train(train_args) + + test_loss, test_cer, test_wer = 0, 0, 0 + + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } + + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: + test_loss, test_cer, test_wer = test(test_args) + + if checkpoints_config["model_save_path"] is None: + continue + + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), - "loss": loss, + "optimizer_state_dict": optimizer.state_dict(), + "train_loss": train_loss, "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer, + "avg_cer": test_cer, + "avg_wer": test_wer, }, - path + str(epoch), - plot(epochs,path) + checkpoints_config["model_save_path"] + str(epoch), ) -@click.command() -@click.option("--learning_rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=10, help="Batch size") -@click.option("--epochs", default=1, help="Number of epochs") -@click.option("--load", default=False, help="Do you want to load a model?") -@click.option( - "--path", - default="model", - help="Path where the model will be saved to/loaded from", -) -@click.option( - "--dataset_path", - default="data/", - help="Path for the dataset directory", -) -def run_cli( - learning_rate: float, - batch_size: int, - epochs: int, - load: bool, - path: str, - dataset_path: str, -) -> None: - """Runs the training script.""" - - run( - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - load=load, - path=path, - dataset_path=dataset_path, - language="mls_german_opus", - ) - - if __name__ == "__main__": - run_cli() # pylint: disable=no-value-for-parameter + main() # pylint: disable=no-value-for-parameter diff --git a/tests/__init__.py b/swr2_asr/utils/__init__.py index e69de29..e69de29 100644 --- a/tests/__init__.py +++ b/swr2_asr/utils/__init__.py diff --git a/swr2_asr/utils.py b/swr2_asr/utils/data.py index a362b9e..d551c98 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils/data.py @@ -1,23 +1,53 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from typing import TypedDict -import matplotlib.pyplot as plt import numpy as np import torch import torchaudio -from tokenizers import Tokenizer -from torch.utils.data import Dataset -from torchaudio.datasets.utils import _extract_tar as extract_archive +from torch import Tensor, nn +from torch.utils.data import 
DataLoader, Dataset +from torchaudio.datasets.utils import _extract_tar -from swr2_asr.tokenizer import TokenizerType +from swr2_asr.utils.tokenizer import CharTokenizer -train_audio_transforms = torch.nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) + +class DataProcessing: + """Data processing class for the dataloader""" + + def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict): + self.data_type = data_type + self.tokenizer = tokenizer + n_features = hparams["n_feats"] + + if data_type == "train": + self.audio_transform = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), + ) + elif data_type == "valid": + self.audio_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_features) + + def __call__(self, data) -> tuple[Tensor, Tensor, list, list]: + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for waveform, _, utterance, _, _, _ in data: + spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1) + spectrograms.append(spec) + label = torch.Tensor(self.tokenizer.encode(utterance.lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths # create enum specifiying dataset splits @@ -31,7 +61,7 @@ class MLSSplit(str, Enum): class Split(str, Enum): - """Extending the MLSSplit class to allow for a custom validatio split""" + """Extending the MLSSplit class to allow for a custom validation split""" TRAIN = "train" VALID = "valid" @@ -43,22 +73,7 @@ def split_to_mls_split(split_name: Split) -> MLSSplit: """Converts the custom split to a MLSSplit""" if split_name == Split.VALID: return MLSSplit.TRAIN - else: - return split_name # type: ignore - - -class Sample(TypedDict): - """Type for a sample in the dataset""" - - waveform: torch.Tensor - spectrogram: torch.Tensor - input_length: int - utterance: torch.Tensor - utterance_length: int - sample_rate: int - speaker_id: str - book_id: str - chapter_id: str + return split_name # type: ignore class MLSDataset(Dataset): @@ -94,10 +109,10 @@ class MLSDataset(Dataset): self, dataset_path: str, language: str, - split: Split, - limited: bool, - download: bool, - spectrogram_hparams: dict | None, + split: Split, # pylint: disable=redefined-outer-name + limited: bool = False, + download: bool = True, + size: float = 0.2, ): """Initializes the dataset""" self.dataset_path = dataset_path @@ -106,37 +121,19 @@ class MLSDataset(Dataset): self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally - if spectrogram_hparams is None: - self.spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - else: - self.spectrogram_hparams = spectrogram_hparams - self.dataset_lookup = [] - self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() - if limited and (split == 
Split.TRAIN or split == Split.VALID): + if limited and split in (Split.TRAIN, Split.VALID): self.initialize_limited() else: self.initialize() + self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)] + def initialize_limited(self) -> None: """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - handles = set() train_root_path = os.path.join(self.dataset_path, self.language, "train") @@ -246,10 +243,6 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] - def set_tokenizer(self, tokenizer: type[TokenizerType]): - """Sets the tokenizer""" - self.tokenizer = tokenizer - def _handle_download_dataset(self, download: bool) -> None: """Download the dataset""" if not download: @@ -258,7 +251,9 @@ class MLSDataset(Dataset): # zip exists: if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: print(f"Found dataset at {self.dataset_path}. Skipping download") - # zip does not exist: + # path exists: + elif os.path.isdir(os.path.join(self.dataset_path, self.language)) and download: + return else: os.makedirs(self.dataset_path, exist_ok=True) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" @@ -273,9 +268,7 @@ class MLSDataset(Dataset): f"Unzipping the dataset at \ {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" ) - extract_archive( - os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True - ) + _extract_tar(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) else: print("Dataset is already unzipped, validating it now") return @@ -293,12 +286,29 @@ class MLSDataset(Dataset): """Returns the length of the dataset""" return len(self.dataset_lookup) - def __getitem__(self, idx: int) -> Sample: - """One sample""" - if self.tokenizer is None: - raise ValueError("No tokenizer set") + def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]: + """One sample + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ # get the utterance - utterance = self.dataset_lookup[idx]["utterance"] + dataset_lookup_entry = self.dataset_lookup[idx] + + utterance = dataset_lookup_entry["utterance"] # get the audio file audio_path = os.path.join( @@ -321,97 +331,15 @@ class MLSDataset(Dataset): waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member # resample if necessary - if sample_rate != self.spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample( - sample_rate, self.spectrogram_hparams["sample_rate"] - ) + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) - spec = ( - torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - - input_length = spec.shape[0] // 2 - - utterance_length = len(utterance) - - utterance = self.tokenizer.encode(utterance) - - utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member - - return Sample( - waveform=waveform, - spectrogram=spec, - input_length=input_length, - utterance=utterance, - utterance_length=utterance_length, - sample_rate=self.spectrogram_hparams["sample_rate"], - speaker_id=self.dataset_lookup[idx]["speakerid"], - book_id=self.dataset_lookup[idx]["bookid"], - 
chapter_id=self.dataset_lookup[idx]["chapterid"], - ) - - -def collate_fn(samples: list[Sample]) -> dict: - """Collate function for the dataloader - - pads all tensors within a batch to the same dimensions - """ - waveforms = [] - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - - for sample in samples: - waveforms.append(sample["waveform"].transpose(0, 1)) - spectrograms.append(sample["spectrogram"]) - labels.append(sample["utterance"]) - input_lengths.append(sample["spectrogram"].shape[0] // 2) - label_lengths.append(len(sample["utterance"])) - - waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) - spectrograms = ( - torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) - ) - labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) - - return { - "waveform": waveforms, - "spectrogram": spectrograms, - "input_length": input_lengths, - "utterance": labels, - "utterance_length": label_lengths, - } - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.TRAIN - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None) - - tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - dataset.set_tokenizer(tok) - - -def plot(epochs, path): - """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - - plt.plot(losses) - plt.plot(test_losses) - plt.savefig("losses.svg") + return ( + waveform, + sample_rate, + utterance, + dataset_lookup_entry["speakerid"], + dataset_lookup_entry["chapterid"], + idx, + ) # type: ignore diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py new file mode 100644 index 0000000..fcddb79 --- /dev/null +++ b/swr2_asr/utils/decoder.py @@ -0,0 +1,26 @@ +"""Decoder for CTC-based ASR.""" "" +import torch + +from swr2_asr.utils.tokenizer import CharTokenizer + + +# TODO: refactor to use torch CTC decoder class +def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): + """Greedily decode a sequence.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(tokenizer.decode(decode)) + return decodes, targets + + +# TODO: add beam search decoder diff --git a/swr2_asr/loss_scores.py b/swr2_asr/utils/loss_scores.py index 63c8a8f..80285f6 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/utils/loss_scores.py @@ -2,18 +2,36 @@ import numpy as np -def avg_wer(wer_scores, combined_ref_len): - """Calculate the average word error rate (WER) of the model.""" +def avg_wer(wer_scores, combined_ref_len) -> float: + """Calculate the average word error rate (WER). 
+ + Args: + wer_scores: word error rate scores + combined_ref_len: combined length of reference sentences + + Returns: + average word error rate (float) + + Usage: + >>> avg_wer([0.5, 0.5], 2) + 0.5 + """ return float(sum(wer_scores)) / float(combined_ref_len) -def _levenshtein_distance(ref, hyp): - """Levenshtein distance is a string metric for measuring the difference - between two sequences. Informally, the levenshtein disctance is defined as - the minimum number of single-character edits (substitutions, insertions or - deletions) required to change one word into the other. We can naturally - extend the edits to word level when calculating levenshtein disctance for - two sentences. +def _levenshtein_distance(ref, hyp) -> int: + """Levenshtein distance. + + Args: + ref: reference sentence + hyp: hypothesis sentence + + Returns: + distance: levenshtein distance between reference and hypothesis + + Usage: + >>> _levenshtein_distance("hello", "helo") + 2 """ len_ref = len(ref) len_hyp = len(hyp) @@ -54,19 +72,24 @@ def _levenshtein_distance(ref, hyp): return distance[len_ref % 2][len_hyp] -def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "): +def word_errors( + reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " +) -> tuple[float, int]: """Compute the levenshtein distance between reference sequence and hypothesis sequence in word-level. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param delimiter: Delimiter of input sentences. - :type delimiter: char - :return: Levenshtein distance and word number of reference sentence. - :rtype: list + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> word_errors("hello world", "hello") + 1, 2 """ if ignore_case: reference = reference.lower() @@ -84,19 +107,21 @@ def char_errors( hypothesis: str, ignore_case: bool = False, remove_space: bool = False, -): +) -> tuple[float, int]: """Compute the levenshtein distance between reference sequence and hypothesis sequence in char-level. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param remove_space: Whether remove internal space characters - :type remove_space: bool - :return: Levenshtein distance and length of reference sentence. - :rtype: list + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + remove_space: Whether remove internal space characters + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> char_errors("hello world", "hello") + 1, 10 """ if ignore_case: reference = reference.lower() @@ -113,30 +138,29 @@ def char_errors( return float(edit_distance), len(reference) -def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): +def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float: """Calculate word error rate (WER). WER compares reference text and - hypothesis text in word-level. 
WER is defined as: - .. math:: + hypothesis text in word-level. + WER is defined as: WER = (Sw + Dw + Iw) / Nw - where - .. code-block:: text + with: Sw is the number of words subsituted, Dw is the number of words deleted, Iw is the number of words inserted, Nw is the number of words in the reference - We can use levenshtein distance to calculate WER. Please draw an attention - that empty items will be removed when splitting sentences by delimiter. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param delimiter: Delimiter of input sentences. - :type delimiter: char - :return: Word error rate. - :rtype: float - :raises ValueError: If word number of reference is zero. + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. + + Returns: + Word error rate (float) + + Usage: + >>> wer("hello world", "hello") + 0.5 """ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) @@ -150,29 +174,25 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): def cer(reference, hypothesis, ignore_case=False, remove_space=False): """Calculate charactor error rate (CER). CER compares reference text and hypothesis text in char-level. CER is defined as: - .. math:: CER = (Sc + Dc + Ic) / Nc - where - .. code-block:: text + with Sc is the number of characters substituted, Dc is the number of characters deleted, Ic is the number of characters inserted Nc is the number of characters in the reference - We can use levenshtein distance to calculate CER. Chinese input should be - encoded to unicode. Please draw an attention that the leading and tailing - space characters will be truncated and multiple consecutive space - characters in a sentence will be replaced by one space character. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param remove_space: Whether remove internal space characters - :type remove_space: bool - :return: Character error rate. - :rtype: float - :raises ValueError: If the reference length is zero. + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. 
+ remove_space: Whether remove internal space characters + + Returns: + Character error rate (float) + + Usage: + >>> cer("hello world", "hello") + 0.2727272727272727 """ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py new file mode 100644 index 0000000..1cc7b84 --- /dev/null +++ b/swr2_asr/utils/tokenizer.py @@ -0,0 +1,122 @@ +"""Tokenizer for Multilingual Librispeech datasets""" +import os +from datetime import datetime + +from tqdm.autonotebook import tqdm + + +class CharTokenizer: + """Maps characters to integers and vice versa""" + + def __init__(self): + self.char_map = {} + self.index_map = {} + + def encode(self, text: str) -> list[int]: + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for char in text: + if char == " ": + char = self.char_map["<SPACE>"] + elif char not in self.char_map: + char = self.char_map["<UNK>"] + else: + char = self.char_map[char] + int_sequence.append(char) + return int_sequence + + def decode(self, labels: list[int]) -> str: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("<SPACE>", " ") + + def get_vocab_size(self) -> int: + """Get the number of unique characters in the dataset""" + return len(self.char_map) + + def get_blank_token(self) -> int: + """Get the integer representation of the <BLANK> character""" + return self.char_map["<BLANK>"] + + def get_unk_token(self) -> int: + """Get the integer representation of the <UNK> character""" + return self.char_map["<UNK>"] + + def get_space_token(self) -> int: + """Get the integer representation of the <SPACE> character""" + return self.char_map["<SPACE>"] + + @staticmethod + def train(dataset_path: str, language: str) -> "CharTokenizer": + """Train the tokenizer on a dataset""" + chars = set() + root_path = os.path.join(dataset_path, language) + for split in os.listdir(root_path): + split_dir = os.path.join(root_path, split) + if os.path.isdir(split_dir): + transcript_path = os.path.join(split_dir, "transcripts.txt") + + with open(transcript_path, "r", encoding="utf-8") as transcrips: + lines = transcrips.readlines() + lines = [line.split(" ", 1)[1] for line in lines] + lines = [line.strip() for line in lines] + lines = [line.lower() for line in lines] + + for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"): + chars.update(line) + + # sort chars + chars.remove(" ") + chars = sorted(chars) + + train_tokenizer = CharTokenizer() + + train_tokenizer.char_map["_"] = 0 + train_tokenizer.char_map["<BLANK>"] = 1 + train_tokenizer.char_map["<UNK>"] = 2 + train_tokenizer.char_map["<SPACE>"] = 3 + + train_tokenizer.index_map[0] = "_" + train_tokenizer.index_map[1] = "<BLANK>" + train_tokenizer.index_map[2] = "<UNK>" + train_tokenizer.index_map[3] = "<SPACE>" + + offset = 4 + + for idx, char in enumerate(chars): + idx += offset + train_tokenizer.char_map[char] = idx + train_tokenizer.index_map[idx] = char + + train_tokenizer_dir = os.path.join("data/tokenizers") + train_tokenizer_path = os.path.join( + train_tokenizer_dir, + f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json", + ) + + if not os.path.exists(os.path.dirname(train_tokenizer_dir)): + os.makedirs(train_tokenizer_dir) + train_tokenizer.save(train_tokenizer_path) + + return train_tokenizer + + def save(self, path: str) -> None: + """Save 
the tokenizer to a file"""
+        with open(path, "w", encoding="utf-8") as file:
+            for char, index in self.char_map.items():
+                file.write(f"{char} {index}\n")
+
+    @staticmethod
+    def from_file(tokenizer_file: str) -> "CharTokenizer":
+        """Instantiate a CharTokenizer from a file"""
+        load_tokenizer = CharTokenizer()
+        with open(tokenizer_file, "r", encoding="utf-8") as file:
+            for line in file:
+                line = line.strip()
+                if line:
+                    char, index = line.split()
+                    load_tokenizer.char_map[char] = int(index)
+                    load_tokenizer.index_map[int(index)] = char
+        return load_tokenizer
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..a55d0d5
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+    """Plots the losses over the epochs"""
+    losses = []
+    test_losses = []
+    cers = []
+    wers = []
+    for epoch in range(1, epochs + 1):
+        current_state = torch.load(path + str(epoch))
+        losses.append(current_state["loss"])
+        test_losses.append(current_state["test_loss"])
+        cers.append(current_state["avg_cer"])
+        wers.append(current_state["avg_wer"])
+
+    plt.plot(losses)
+    plt.plot(test_losses)
+    plt.savefig("losses.svg")
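
Usage notes on the code introduced above.

main() in swr2_asr/train.py now reads its settings from a YAML file (default config.yaml) and splits it into model, training, dataset, tokenizer and checkpoints sections. The sketch below spells that structure out as the equivalent Python dict. Only keys the script actually accesses are listed; the values are illustrative placeholders (mostly taken from the removed hyperparameter and command-line defaults) and not necessarily the repository's real config.yaml.

config_dict = {
    "model": {
        "n_cnn_layers": 3,      # values mirror the removed HParams defaults
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
    },
    "training": {
        "learning_rate": 1e-3,  # old --learning_rate default
        "batch_size": 10,       # old --batch_size default
        "epochs": 1,            # old --epochs default
        "eval_every_n": 1,      # 0 disables the validation pass
        "num_workers": 4,       # placeholder, passed to the DataLoader on CUDA
    },
    "dataset": {
        "dataset_root_path": "data/",
        "language_name": "mls_german_opus",
        "limited_supervision": True,
        "dataset_percentage": 1.0,
        "shuffle": True,
        "download": True,
    },
    "tokenizer": {
        "tokenizer_path": "data/tokenizers/char_tokenizer_german.json",
    },
    "checkpoints": {
        "model_load_path": None,    # set to a checkpoint file to resume training
        "model_save_path": "model", # the epoch number is appended on save
    },
}

With such a file saved as config.yaml, training would presumably be started with something like: python -m swr2_asr.train --config_path config.yaml (the --config_path option is the one defined in the click command above).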
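
A minimal usage sketch for the new CharTokenizer in swr2_asr/utils/tokenizer.py: load the character map shipped under data/tokenizers/ and round-trip a transcript. The sample sentence is arbitrary, and note that encode() does not lowercase on its own; that step happens in DataProcessing.

from swr2_asr.utils.tokenizer import CharTokenizer

tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

text = "hallo welt"            # already lowercased, as DataProcessing would do
ids = tokenizer.encode(text)   # integers from the char map; spaces become the <SPACE> id
print(ids)
print(tokenizer.decode(ids))   # prints "hallo welt" again
print(tokenizer.get_vocab_size(), tokenizer.get_blank_token(), tokenizer.get_space_token())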
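
The greedy_decoder moved to swr2_asr/utils/decoder.py applies the usual CTC post-processing: take the argmax per time step, collapse consecutive repeats, and drop the blank token. A self-contained toy illustration of that collapse rule follows; the frame indices and blank id here are made up, not taken from the real char map.

import torch

frame_argmax = torch.tensor([1, 5, 5, 1, 6, 6, 6, 1, 7])  # per-frame argmax ids
BLANK = 1                                                  # assumed blank id for this toy example

decoded = []
prev = None
for idx in frame_argmax.tolist():
    if idx != BLANK and idx != prev:  # drop blanks, collapse repeated frames
        decoded.append(idx)
    prev = idx

print(decoded)  # [5, 6, 7]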
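
The scoring helpers now living in swr2_asr/utils/loss_scores.py can also be exercised directly; the reference/hypothesis pair below is an arbitrary example.

from swr2_asr.utils.loss_scores import cer, wer

reference = "das ist ein test"
hypothesis = "das ist ein fest"

print(wer(reference, hypothesis))  # 1 substituted word out of 4       -> 0.25
print(cer(reference, hypothesis))  # 1 substituted character out of 16 -> 0.0625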