From 9dc3bc07424908dd7cf3f052708f506fd58b6e2c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:49:28 +0200 Subject: refactor utilities (data, vis, tokenizer) --- swr2_asr/utils/tokenizer.py | 126 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 swr2_asr/utils/tokenizer.py (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py new file mode 100644 index 0000000..d92465a --- /dev/null +++ b/swr2_asr/utils/tokenizer.py @@ -0,0 +1,126 @@ +"""Tokenizer for Multilingual Librispeech datasets""" + + +class CharTokenizer: + """Maps characters to integers and vice versa""" + + def __init__(self): + char_map_str = """ + _ + + + + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + é + à + ä + ö + ß + ü + - + ' + + """ + + self.char_map = {} + self.index_map = {} + for idx, char in enumerate(char_map_str.strip().split("\n")): + char = char.strip() + self.char_map[char] = idx + self.index_map[idx] = char + self.index_map[1] = " " + + def encode(self, text: str) -> list[int]: + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for char in text: + if char == " ": + char = self.char_map[""] + elif char not in self.char_map: + char = self.char_map[""] + else: + char = self.char_map[char] + int_sequence.append(char) + return int_sequence + + def decode(self, labels: list[int]) -> str: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("", " ") + + def get_vocab_size(self) -> int: + """Get the number of unique characters in the dataset""" + return len(self.char_map) + + def get_blank_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + def get_unk_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + def get_space_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + # TODO: add train function + + def save(self, path: str) -> None: + """Save the tokenizer to a file""" + with open(path, "w", encoding="utf-8") as file: + for char, index in self.char_map.items(): + file.write(f"{char} {index}\n") + + @staticmethod + def from_file(tokenizer_file: str) -> "CharTokenizer": + """Instantiate a CharTokenizer from a file""" + load_tokenizer = CharTokenizer() + with open(tokenizer_file, "r", encoding="utf-8") as file: + for line in file: + line = line.strip() + if line: + char, index = line.split() + tokenizer.char_map[char] = int(index) + tokenizer.index_map[int(index)] = char + return load_tokenizer + + +if __name__ == "__main__": + tokenizer = CharTokenizer() + tokenizer.save("data/tokenizers/char_tokenizer_german.json") + print(tokenizer.char_map) + print(tokenizer.index_map) + print(tokenizer.get_vocab_size()) + print(tokenizer.get_blank_token()) + print(tokenizer.get_unk_token()) + print(tokenizer.get_space_token()) + print(tokenizer.encode("hallo welt")) + print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) -- cgit v1.2.3 From 8be140b38183b7465b5888a15b536a5f7fa66db6 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 20:45:32 +0200 Subject: added tokenizer to git and tokenizer training routing --- .gitignore | 11 +-- data/tokenizers/char_tokenizer_german.json | 38 ++++++++++ swr2_asr/utils/tokenizer.py | 110 ++++++++++++++++------------- 3 files changed, 101 insertions(+), 58 deletions(-) create mode 100644 data/tokenizers/char_tokenizer_german.json (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/.gitignore b/.gitignore index 8e64e4b..d21ddb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Training files -data/ +data/* +!data/tokenizers + # Mac **/.DS_Store @@ -163,10 +165,3 @@ dmypy.json # Cython debug symbols cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/data/tokenizers/char_tokenizer_german.json b/data/tokenizers/char_tokenizer_german.json new file mode 100644 index 0000000..20db079 --- /dev/null +++ b/data/tokenizers/char_tokenizer_german.json @@ -0,0 +1,38 @@ +_ 0 + 1 + 2 + 3 +a 4 +b 5 +c 6 +d 7 +e 8 +f 9 +g 10 +h 11 +i 12 +j 13 +k 14 +l 15 +m 16 +n 17 +o 18 +p 19 +q 20 +r 21 +s 22 +t 23 +u 24 +v 25 +w 26 +x 27 +y 28 +z 29 +é 30 +à 31 +ä 32 +ö 33 +ß 34 +ü 35 +- 36 +' 37 diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index d92465a..5482bbe 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,59 +1,18 @@ """Tokenizer for Multilingual Librispeech datasets""" +from datetime import datetime +import os + +from tqdm.autonotebook import tqdm + + class CharTokenizer: """Maps characters to integers and vice versa""" def __init__(self): - char_map_str = """ - _ - - - - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z - é - à - ä - ö - ß - ü - - - ' - - """ - self.char_map = {} self.index_map = {} - for idx, char in enumerate(char_map_str.strip().split("\n")): - char = char.strip() - self.char_map[char] = idx - self.index_map[idx] = char - self.index_map[1] = " " def encode(self, text: str) -> list[int]: """Use a character map and convert text to an integer sequence""" @@ -91,7 +50,59 @@ class CharTokenizer: """Get the integer representation of the character""" return self.char_map[""] - # TODO: add train function + @staticmethod + def train(dataset_path: str, language: str) -> "CharTokenizer": + """Train the tokenizer on a dataset""" + chars = set() + root_path = os.path.join(dataset_path, language) + for split in os.listdir(root_path): + split_dir = os.path.join(root_path, split) + if os.path.isdir(split_dir): + transcript_path = os.path.join(split_dir, "transcripts.txt") + + with open(transcript_path, "r", encoding="utf-8") as transcrips: + lines = transcrips.readlines() + lines = [line.split(" ", 1)[1] for line in lines] + lines = [line.strip() for line in lines] + lines = [line.lower() for line in lines] + + for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"): + chars.update(line) + + # sort chars + chars.remove(" ") + chars = sorted(chars) + + train_tokenizer = CharTokenizer() + + train_tokenizer.char_map["_"] = 0 + train_tokenizer.char_map[""] = 1 + train_tokenizer.char_map[""] = 2 + train_tokenizer.char_map[""] = 3 + + train_tokenizer.index_map[0] = "_" + train_tokenizer.index_map[1] = "" + train_tokenizer.index_map[2] = "" + train_tokenizer.index_map[3] = "" + + offset = 4 + + for idx, char in enumerate(chars): + idx += offset + train_tokenizer.char_map[char] = idx + train_tokenizer.index_map[idx] = char + + train_tokenizer_dir = os.path.join("data/tokenizers") + train_tokenizer_path = os.path.join( + train_tokenizer_dir, + f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json", + ) + + if not os.path.exists(os.path.dirname(train_tokenizer_dir)): + os.makedirs(train_tokenizer_dir) + train_tokenizer.save(train_tokenizer_path) + + return train_tokenizer def save(self, path: str) -> None: """Save the tokenizer to a file""" @@ -114,8 +125,7 @@ class CharTokenizer: if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.save("data/tokenizers/char_tokenizer_german.json") + tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") print(tokenizer.char_map) print(tokenizer.index_map) print(tokenizer.get_vocab_size()) -- cgit v1.2.3 From 58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 21:52:42 +0200 Subject: changed config to yaml! --- config.philipp.yaml | 29 ++++++ config.train.yaml | 28 ++++++ poetry.lock | 51 ++++++++++- pyproject.toml | 1 + requirements.txt | 1 + swr2_asr/__main__.py | 12 --- swr2_asr/inference.py | 16 ++-- swr2_asr/model_deep_speech.py | 17 ---- swr2_asr/train.py | 192 +++++++++++++++++++--------------------- swr2_asr/utils/data.py | 7 +- swr2_asr/utils/tokenizer.py | 8 +- swr2_asr/utils/visualization.py | 8 +- 12 files changed, 218 insertions(+), 152 deletions(-) create mode 100644 config.philipp.yaml create mode 100644 config.train.yaml delete mode 100644 swr2_asr/__main__.py (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/config.philipp.yaml b/config.philipp.yaml new file mode 100644 index 0000000..638b5ef --- /dev/null +++ b/config.philipp.yaml @@ -0,0 +1,29 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 1 # evaluate every n epochs + num_workers: 4 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: True # set to True if you want to use limited supervision + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: ~ # path to load model from + model_save_path: ~ # path to save model to \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml new file mode 100644 index 0000000..c82439d --- /dev/null +++ b/config.train.yaml @@ -0,0 +1,28 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 3901b8c..a1f916b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1083,6 +1083,55 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "ruff" version = "0.0.285" @@ -1482,4 +1531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5" +content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf" diff --git a/pyproject.toml b/pyproject.toml index 38cc51a..f6d19dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" matplotlib = "^3.7.2" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/requirements.txt b/requirements.txt index 3b39b56..040fed0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ platformdirs==3.10.0 pylint==2.17.5 pyparsing==3.0.9 python-dateutil==2.8.2 +PyYAML==6.0.1 ruff==0.0.285 six==1.16.0 sympy==1.12 diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py deleted file mode 100644 index be294fb..0000000 --- a/swr2_asr/__main__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Main entrypoint for swr2-asr.""" -import torch -import torchaudio - -if __name__ == "__main__": - # test if GPU is available - print("GPU available: ", torch.cuda.is_available()) - - # test if torchaudio is installed correctly - print("torchaudio version: ", torchaudio.__version__) - print("torchaudio backend: ", torchaudio.get_audio_backend()) - print("torchaudio info: ", torchaudio.get_audio_backend()) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..f8342f7 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" +from typing import TypedDict + import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer class HParams(TypedDict): @@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member blank_label = tokenizer.encode(" ").ids[0] decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for _i, args in enumerate(arg_maxes): decode = [] for j, index in enumerate(args): if index != blank_label: @@ -44,7 +44,7 @@ def main() -> None: """inference function.""" device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) + device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") @@ -90,7 +90,7 @@ def main() -> None: model.load_state_dict(state_dict) # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") + waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member if sample_rate != spectrogram_hparams["sample_rate"]: resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) @@ -103,7 +103,7 @@ def main() -> None: specs = [spec] specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) - output = model(specs) + output = model(specs) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) decodes = greedy_decoder(output, tokenizer) diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index 8ddbd99..77f4c8a 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -3,27 +3,10 @@ Following definition by Assembly AI (https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) """ -from typing import TypedDict - import torch.nn.functional as F from torch import nn -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ac7666b..eb79ee2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,11 +5,12 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split from swr2_asr.utils.decoder import greedy_decoder from swr2_asr.utils.tokenizer import CharTokenizer @@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer from .utils.loss_scores import cer, wer -class IterMeter(object): +class IterMeter: """keeps track of total iterations""" def __init__(self): @@ -116,6 +117,7 @@ class TestArgs(TypedDict): def test(test_args: TestArgs) -> tuple[float, float, float]: + """Test""" print("\nevaluating...") # get values from test_args: @@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + for _data in tqdm(test_loader, desc="Validation Batches"): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) @@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths, tokenizer ) - if i == 1: - print(f"decoding first sample: {decoded_preds}") for j, _ in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], decoded_preds[j])) test_wer.append(wer(decoded_targets[j], decoded_preds[j])) @@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: return test_loss, avg_cer, avg_wer -def main( - learning_rate: float, - batch_size: int, - epochs: int, - dataset_path: str, - language: str, - limited_supervision: bool, - model_load_path: str, - model_save_path: str, - dataset_percentage: float, - eval_every: int, - num_workers: int, -): +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +def main(config_path: str): """Main function for training the model. - Args: - learning_rate: learning rate for the optimizer - batch_size: batch size - epochs: number of epochs to train - dataset_path: path for the dataset - language: language of the dataset - limited_supervision: whether to use only limited supervision - model_load_path: path to load a model from - model_save_path: path to save the model to - dataset_percentage: percentage of the dataset to use - eval_every: evaluate every n epochs - num_workers: number of workers for the dataloader + Gets all configuration arguments from yaml config file. """ use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) prev_epoch = 0 - if model_load_path is not None: - checkpoint = torch.load(model_load_path) + if checkpoints_config["model_load_path"] is not None: + checkpoint = torch.load(checkpoints_config["model_load_path"]) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) - for epoch in range(prev_epoch + 1, epochs + 1): - train_args: TrainArgs = dict( - model=model, - device=device, - train_loader=train_loader, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - iter_meter=iter_meter, - ) + + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = dict( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - decoder="greedy", - ) + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } - if epoch % eval_every == 0: + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: test_loss, test_cer, test_wer = test(test_args) - if model_save_path is None: + if checkpoints_config["model_save_path"] is None: continue - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, @@ -322,7 +314,7 @@ def main( "avg_cer": test_cer, "avg_wer": test_wer, }, - model_save_path + str(epoch), + checkpoints_config["model_save_path"] + str(epoch), ) diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index e939e1d..0e06eec 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -1,13 +1,12 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from typing import TypedDict import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -125,7 +124,7 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): + if limited and split in (Split.TRAIN, Split.VALID): self.initialize_limited() else: self.initialize() @@ -351,8 +350,6 @@ class MLSDataset(Dataset): if __name__ == "__main__": - from torch.utils.data import DataLoader - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" split = Split.DEV diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 5482bbe..22569eb 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,8 +1,6 @@ """Tokenizer for Multilingual Librispeech datasets""" - - -from datetime import datetime import os +from datetime import datetime from tqdm.autonotebook import tqdm @@ -119,8 +117,8 @@ class CharTokenizer: line = line.strip() if line: char, index = line.split() - tokenizer.char_map[char] = int(index) - tokenizer.index_map[int(index)] = char + load_tokenizer.char_map[char] = int(index) + load_tokenizer.index_map[int(index)] = char return load_tokenizer diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index 80f942a..a55d0d5 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -6,10 +6,10 @@ import torch def plot(epochs, path): """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() + losses = [] + test_losses = [] + cers = [] + wers = [] for epoch in range(1, epochs + 1): current_state = torch.load(path + str(epoch)) losses.append(current_state["loss"]) -- cgit v1.2.3 From 4aff1fcd70cd8601541a1dd5bd820b0263ed1362 Mon Sep 17 00:00:00 2001 From: Philipp Merkel Date: Mon, 11 Sep 2023 22:36:28 +0000 Subject: fix: switched up training and test splits in train.py --- config.philipp.yaml | 22 +++++++++++----------- swr2_asr/train.py | 8 +++----- swr2_asr/utils/data.py | 31 ------------------------------- swr2_asr/utils/tokenizer.py | 12 ------------ 4 files changed, 14 insertions(+), 59 deletions(-) (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/config.philipp.yaml b/config.philipp.yaml index 4a723c6..f72ce2e 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -4,30 +4,30 @@ model: rnn_dim: 512 n_feats: 128 # number of mel features stride: 2 - dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets + dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets training: learning_rate: 0.0005 - batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 1 # evaluate every n epochs + batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 150 + eval_every_n: 5 # evaluate every n epochs num_workers: 4 # number of workers for dataloader device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically dataset: - download: True - dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + download: true + dataset_root_path: "data" # files will be downloaded into this dir language_name: "mls_german_opus" - limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) - shuffle: True + limited_supervision: false # set to True if you want to use limited supervision + dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%) + shuffle: true tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" checkpoints: - model_load_path: "data/runs/epoch30" # path to load model from - model_save_path: ~ # path to save model to + model_load_path: "data/runs/epoch31" # path to load model from + model_save_path: "data/runs/epoch" # path to save model to inference: model_load_path: "data/runs/epoch30" # path to load model from diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ec25918..3ed3ac8 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,16 +187,14 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - - print(training_config["learning_rate"]) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TEST, + Split.TRAIN, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], @@ -204,7 +202,7 @@ def main(config_path: str): valid_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TRAIN, + Split.TEST, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 10f0ea8..d551c98 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -134,11 +134,6 @@ class MLSDataset(Dataset): def initialize_limited(self) -> None: """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - handles = set() train_root_path = os.path.join(self.dataset_path, self.language, "train") @@ -348,29 +343,3 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.DEV - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) - - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=True, - collate_fn=DataProcessing( - "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - ), - ) - - for batch in dataloader: - print(batch) - break - - print(len(dataset)) - - print(dataset[0]) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 22569eb..1cc7b84 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,15 +120,3 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) -- cgit v1.2.3