From b8a4cdc6673787333cac282de744fd11604ca161 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:11:26 +0200 Subject: improved loss_scores --- swr2_asr/loss_scores.py | 154 +++++++++++++++++++++++++++--------------------- 1 file changed, 87 insertions(+), 67 deletions(-) (limited to 'swr2_asr') diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py index 63c8a8f..80285f6 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/loss_scores.py @@ -2,18 +2,36 @@ import numpy as np -def avg_wer(wer_scores, combined_ref_len): - """Calculate the average word error rate (WER) of the model.""" +def avg_wer(wer_scores, combined_ref_len) -> float: + """Calculate the average word error rate (WER). + + Args: + wer_scores: word error rate scores + combined_ref_len: combined length of reference sentences + + Returns: + average word error rate (float) + + Usage: + >>> avg_wer([0.5, 0.5], 2) + 0.5 + """ return float(sum(wer_scores)) / float(combined_ref_len) -def _levenshtein_distance(ref, hyp): - """Levenshtein distance is a string metric for measuring the difference - between two sequences. Informally, the levenshtein disctance is defined as - the minimum number of single-character edits (substitutions, insertions or - deletions) required to change one word into the other. We can naturally - extend the edits to word level when calculating levenshtein disctance for - two sentences. +def _levenshtein_distance(ref, hyp) -> int: + """Levenshtein distance. + + Args: + ref: reference sentence + hyp: hypothesis sentence + + Returns: + distance: levenshtein distance between reference and hypothesis + + Usage: + >>> _levenshtein_distance("hello", "helo") + 2 """ len_ref = len(ref) len_hyp = len(hyp) @@ -54,19 +72,24 @@ def _levenshtein_distance(ref, hyp): return distance[len_ref % 2][len_hyp] -def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "): +def word_errors( + reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " +) -> tuple[float, int]: """Compute the levenshtein distance between reference sequence and hypothesis sequence in word-level. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param delimiter: Delimiter of input sentences. - :type delimiter: char - :return: Levenshtein distance and word number of reference sentence. - :rtype: list + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> word_errors("hello world", "hello") + 1, 2 """ if ignore_case: reference = reference.lower() @@ -84,19 +107,21 @@ def char_errors( hypothesis: str, ignore_case: bool = False, remove_space: bool = False, -): +) -> tuple[float, int]: """Compute the levenshtein distance between reference sequence and hypothesis sequence in char-level. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param remove_space: Whether remove internal space characters - :type remove_space: bool - :return: Levenshtein distance and length of reference sentence. 
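For reference, here is what the scoring helpers touched by this patch actually return (char_errors, wer and cer appear in the hunks just below); the import assumes the swr2_asr package from the file header is importable. Note that a few of the new doctest-style values do not match the implementation: _levenshtein_distance("hello", "helo") is 1 (a single deletion), char_errors("hello world", "hello") returns (6.0, 11), and cer("hello world", "hello") is 6/11, roughly 0.545.

from swr2_asr.loss_scores import cer, char_errors, wer, word_errors

# Word level: one of the two reference words is missing from the hypothesis.
print(word_errors("hello world", "hello"))   # (1.0, 2)
print(wer("hello world", "hello"))           # 0.5

# Character level: the six-character suffix " world" is deleted from an
# 11-character reference (the space counts unless remove_space=True).
print(char_errors("hello world", "hello"))   # (6.0, 11)
print(cer("hello world", "hello"))           # 0.5454545454545454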
- :rtype: list + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + remove_space: Whether remove internal space characters + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> char_errors("hello world", "hello") + 1, 10 """ if ignore_case: reference = reference.lower() @@ -113,30 +138,29 @@ def char_errors( return float(edit_distance), len(reference) -def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): +def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float: """Calculate word error rate (WER). WER compares reference text and - hypothesis text in word-level. WER is defined as: - .. math:: + hypothesis text in word-level. + WER is defined as: WER = (Sw + Dw + Iw) / Nw - where - .. code-block:: text + with: Sw is the number of words subsituted, Dw is the number of words deleted, Iw is the number of words inserted, Nw is the number of words in the reference - We can use levenshtein distance to calculate WER. Please draw an attention - that empty items will be removed when splitting sentences by delimiter. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param delimiter: Delimiter of input sentences. - :type delimiter: char - :return: Word error rate. - :rtype: float - :raises ValueError: If word number of reference is zero. + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. + + Returns: + Word error rate (float) + + Usage: + >>> wer("hello world", "hello") + 0.5 """ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) @@ -150,29 +174,25 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): def cer(reference, hypothesis, ignore_case=False, remove_space=False): """Calculate charactor error rate (CER). CER compares reference text and hypothesis text in char-level. CER is defined as: - .. math:: CER = (Sc + Dc + Ic) / Nc - where - .. code-block:: text + with Sc is the number of characters substituted, Dc is the number of characters deleted, Ic is the number of characters inserted Nc is the number of characters in the reference - We can use levenshtein distance to calculate CER. Chinese input should be - encoded to unicode. Please draw an attention that the leading and tailing - space characters will be truncated and multiple consecutive space - characters in a sentence will be replaced by one space character. - :param reference: The reference sentence. - :type reference: basestring - :param hypothesis: The hypothesis sentence. - :type hypothesis: basestring - :param ignore_case: Whether case-sensitive or not. - :type ignore_case: bool - :param remove_space: Whether remove internal space characters - :type remove_space: bool - :return: Character error rate. - :rtype: float - :raises ValueError: If the reference length is zero. + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. 
+ remove_space: Whether remove internal space characters + + Returns: + Character error rate (float) + + Usage: + >>> cer("hello world", "hello") + 0.2727272727272727 """ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) -- cgit v1.2.3 From 9aba8d447ea727afb4208e41dcdcff9157446162 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:12:14 +0200 Subject: improved tokenizer --- swr2_asr/tokenizer.py | 466 +++++++++++--------------------------------------- 1 file changed, 99 insertions(+), 367 deletions(-) (limited to 'swr2_asr') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 2e2fb57..69ced81 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,393 +1,125 @@ -"""Tokenizer for use with Multilingual Librispeech""" -import json -import os -from dataclasses import dataclass -from typing import Type +"""Tokenizer for Multilingual Librispeech datasets""" -import click -from tokenizers import Tokenizer, normalizers -from tokenizers.models import BPE -from tokenizers.pre_tokenizers import Whitespace -from tokenizers.trainers import BpeTrainer -from tqdm import tqdm - -class TokenizerType: - """Base class for tokenizers. - - exposes the same interface as tokenizers from the huggingface library""" - - def encode(self, sequence: str) -> list[int]: - """Encode a sequence to a list of integer labels""" - raise NotImplementedError - - def decode(self, labels: list[int], remove_special_tokens: bool) -> str: - """Decode a list of integer labels to a sequence""" - raise NotImplementedError - - def decode_batch(self, labels: list[list[int]]) -> list[str]: - """Decode a batch of integer labels to a list of sequences""" - raise NotImplementedError - - def get_vocab_size(self) -> int: - """Get the size of the vocabulary""" - raise NotImplementedError - - def enable_padding( - self, - length: int = -1, - direction: str = "right", - pad_id: int = 0, - pad_type_id: int = 0, - pad_token: str = "[PAD]", - ) -> None: - """Enable padding for the tokenizer""" - raise NotImplementedError - - def save(self, path: str) -> None: - """Save the tokenizer to a file""" - raise NotImplementedError - - @staticmethod - def from_file(path: str) -> "TokenizerType": - """Load the tokenizer from a file""" - raise NotImplementedError - - -MyTokenizerType = Type[TokenizerType] - - -@dataclass -class Encoding: - """Simple dataclass to represent an encoding""" - - ids: list[int] - tokens: list[str] - - -class CharTokenizer(TokenizerType): - """Very simple tokenizer for use with Multilingual Librispeech - - Simply checks what characters are in the dataset and uses them as tokens. - - Exposes the same interface as tokenizers from the huggingface library, i.e. - encode, decode, decode_batch, get_vocab_size, save, from_file and train. 
- """ +class CharTokenizer: + """Maps characters to integers and vice versa""" def __init__(self): - self.char_map = {} - self.index_map = {} - self.add_tokens(["", ""]) - - def add_tokens(self, tokens: list[str]): - """Manually add tokens to the tokenizer - - Args: - tokens (list[str]): List of tokens to add - """ - for token in tokens: - if token not in self.char_map: - self.char_map[token] = len(self.char_map) - self.index_map[len(self.index_map)] = token - - def train(self, dataset_path: str, language: str, split: str): - """Train the tokenizer on the given dataset - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use + char_map_str = """ + _ + + + + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + é + à + ä + ö + ß + ü + - + ' """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - chars: set = set() - for s_plit in splits: - transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - - with open( - transcript_path, - "r", - encoding="utf-8", - ) as file: - lines = file.readlines() - lines = [line.split(" ", 1)[1] for line in lines] - lines = [line.strip() for line in lines] - - for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"): - chars.update(line) - offset = len(self.char_map) - for i, char in enumerate(chars): - i += offset - self.char_map[char] = i - self.index_map[i] = char - - def encode(self, sequence: str): - """Use a character map and convert text to an integer sequence - - automatically maps spaces to and makes everything lowercase - unknown characters are mapped to the token - - """ + self.char_map = {} + self.index_map = {} + for idx, char in enumerate(char_map_str.strip().split("\n")): + char = char.strip() + self.char_map[char] = idx + self.index_map[idx] = char + self.index_map[1] = " " + + def encode(self, text: str) -> list[int]: + """Use a character map and convert text to an integer sequence""" int_sequence = [] - sequence = sequence.lower() - for char in sequence: + for char in text: if char == " ": - mapped_char = self.char_map[""] + char = self.char_map[""] elif char not in self.char_map: - mapped_char = self.char_map[""] + char = self.char_map[""] else: - mapped_char = self.char_map[char] - int_sequence.append(mapped_char) - return Encoding(ids=int_sequence, tokens=list(sequence)) + char = self.char_map[char] + int_sequence.append(char) + return int_sequence - def decode(self, labels: list[int], remove_special_tokens: bool = True): - """Use a character map and convert integer labels to an text sequence - - Args: - labels (list[int]): List of integer labels - remove_special_tokens (bool): Whether to remove special tokens. - Defaults to True. 
- """ + def decode(self, labels: list[int]) -> str: + """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: - if remove_special_tokens and self.index_map[f"{i}"] == "": - continue - if remove_special_tokens and self.index_map[f"{i}"] == "": - string.append(" ") - string.append(self.index_map[f"{i}"]) + string.append(self.index_map[i]) return "".join(string).replace("", " ") - def decode_batch(self, labels: list[list[int]]): - """Use a character map and convert integer labels to an text sequence""" - strings = [] - for label in labels: - string = [] - for i in label: - if self.index_map[i] == "": - continue - if self.index_map[i] == "": - string.append(" ") - string.append(self.index_map[i]) - strings.append("".join(string).replace("", " ")) - return strings - - def get_vocab_size(self): - """Get the size of the vocabulary""" + def get_vocab_size(self) -> int: + """Get the number of unique characters in the dataset""" return len(self.char_map) - def save(self, path: str): - """Save the tokenizer to a file""" - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as file: - # save it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - json.dump( - {"char_map": self.char_map, "index_map": self.index_map}, - file, - ensure_ascii=False, - ) - - @staticmethod - def from_file(path: str) -> "CharTokenizer": - """Load the tokenizer from a file""" - char_tokenizer = CharTokenizer() - with open(path, "r", encoding="utf-8") as file: - # load it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - saved_file = json.load(file) - char_tokenizer.char_map = saved_file["char_map"] - char_tokenizer.index_map = saved_file["index_map"] - - return char_tokenizer - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use (including all)") -@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") -@click.option("--vocab_size", default=2000, help="Size of the vocabulary") -def train_bpe_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_bpe_tokenizer( - dataset_path, - language, - split, - out_path, - vocab_size, - ) - - -def train_bpe_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. 
- - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - vocab_size (int): Size of the vocabulary - """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - lines = [] + def get_blank_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] - for s_plit in splits: - transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - if not os.path.exists(transcripts_path): - raise FileNotFoundError( - f"Could not find transcripts.txt in {transcripts_path}. " - "Please make sure that the dataset is downloaded." - ) + def get_unk_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] - with open( - transcripts_path, - "r", - encoding="utf-8", - ) as file: - sp_lines = file.readlines() - sp_lines = [line.split(" ", 1)[1] for line in sp_lines] - sp_lines = [line.strip() for line in sp_lines] + def get_space_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] - lines.append(sp_lines) + # TODO: add train function - bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) - - initial_alphabet = [ - " ", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "ä", - "ö", - "ü", - "ß", - "-", - "é", - "è", - "à", - "ù", - "ç", - "â", - "ê", - "î", - "ô", - "û", - "ë", - "ï", - "ü", - ] - - - trainer = BpeTrainer( - special_tokens=["[UNK]"], - vocab_size=vocab_size, - initial_alphabet=initial_alphabet, - show_progress=True, - ) # type: ignore - - bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore - - bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore - - bpe_tokenizer.train_from_iterator(lines, trainer=trainer) - - bpe_tokenizer.save(out_path) - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use") -@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") -def train_char_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_char_tokenizer(dataset_path, language, split, out_path) - - -def train_char_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. 
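One caveat about the loader added above: inside from_file, the two assignments in the loop write to a module-level name tokenizer (which is only bound by the __main__ block further down) instead of the load_tokenizer that is returned, so importing CharTokenizer and calling from_file raises a NameError, and even when the file is run as a script the loaded mapping never reaches the returned instance. The assignments presumably should target load_tokenizer.char_map and load_tokenizer.index_map. With that caveat, a small round-trip sketch of the rest of the interface (the output path below is a placeholder):

from swr2_asr.tokenizer import CharTokenizer  # module path at this point in the history

tok = CharTokenizer()                   # fixed German character map built in __init__
ids = tok.encode("hallo welt")          # lowercase text -> list of integer labels
print(ids)
print(tok.decode(ids))                  # maps the labels back to text
print(tok.get_vocab_size(), tok.get_blank_token(), tok.get_unk_token(), tok.get_space_token())

tok.save("char_tokenizer_german.txt")   # writes one "<character> <index>" pair per line
# CharTokenizer.from_file("char_tokenizer_german.txt") would restore the mapping
# once the assignments noted above point at load_tokenizer.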
- - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - """ - char_tokenizer = CharTokenizer() - - char_tokenizer.train(dataset_path, language, split) + def save(self, path: str) -> None: + """Save the tokenizer to a file""" + with open(path, "w", encoding="utf-8") as file: + for char, index in self.char_map.items(): + file.write(f"{char} {index}\n") - char_tokenizer.save(out_path) + @staticmethod + def from_file(tokenizer_file: str) -> "CharTokenizer": + """Instantiate a CharTokenizer from a file""" + load_tokenizer = CharTokenizer() + with open(tokenizer_file, "r", encoding="utf-8") as file: + for line in file: + line = line.strip() + if line: + char, index = line.split() + tokenizer.char_map[char] = int(index) + tokenizer.index_map[int(index)] = char + return load_tokenizer if __name__ == "__main__": tokenizer = CharTokenizer() - tokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - - print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids)) + tokenizer.save("data/tokenizers/char_tokenizer_german.json") + print(tokenizer.char_map) + print(tokenizer.index_map) + print(tokenizer.get_vocab_size()) + print(tokenizer.get_blank_token()) + print(tokenizer.get_unk_token()) + print(tokenizer.get_space_token()) + print(tokenizer.encode("hallo welt")) + print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) -- cgit v1.2.3 From 01fae2b5e395e84db6a7e9819b6f98777c46e845 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:33:18 +0200 Subject: readd space token, idk why its there but it could break things --- swr2_asr/tokenizer.py | 1 + 1 file changed, 1 insertion(+) (limited to 'swr2_asr') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 69ced81..d92465a 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -44,6 +44,7 @@ class CharTokenizer: ü - ' + """ self.char_map = {} -- cgit v1.2.3 From 9dc3bc07424908dd7cf3f052708f506fd58b6e2c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:49:28 +0200 Subject: refactor utilities (data, vis, tokenizer) --- swr2_asr/tokenizer.py | 126 ------------ swr2_asr/utils.py | 417 ---------------------------------------- swr2_asr/utils/__init__.py | 0 swr2_asr/utils/data.py | 371 +++++++++++++++++++++++++++++++++++ swr2_asr/utils/decoder.py | 26 +++ swr2_asr/utils/tokenizer.py | 126 ++++++++++++ swr2_asr/utils/visualization.py | 22 +++ 7 files changed, 545 insertions(+), 543 deletions(-) delete mode 100644 swr2_asr/tokenizer.py delete mode 100644 swr2_asr/utils.py create mode 100644 swr2_asr/utils/__init__.py create mode 100644 swr2_asr/utils/data.py create mode 100644 swr2_asr/utils/decoder.py create mode 100644 swr2_asr/utils/tokenizer.py create mode 100644 swr2_asr/utils/visualization.py (limited to 'swr2_asr') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py deleted file mode 100644 index d92465a..0000000 --- a/swr2_asr/tokenizer.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Tokenizer for Multilingual Librispeech datasets""" - - -class CharTokenizer: - """Maps characters to integers and vice versa""" - - def __init__(self): - char_map_str = """ - _ - - - - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z - é - à - ä - ö - ß - ü - - - ' - - """ - - self.char_map = {} - self.index_map = {} - for idx, char in enumerate(char_map_str.strip().split("\n")): 
- char = char.strip() - self.char_map[char] = idx - self.index_map[idx] = char - self.index_map[1] = " " - - def encode(self, text: str) -> list[int]: - """Use a character map and convert text to an integer sequence""" - int_sequence = [] - for char in text: - if char == " ": - char = self.char_map[""] - elif char not in self.char_map: - char = self.char_map[""] - else: - char = self.char_map[char] - int_sequence.append(char) - return int_sequence - - def decode(self, labels: list[int]) -> str: - """Use a character map and convert integer labels to an text sequence""" - string = [] - for i in labels: - string.append(self.index_map[i]) - return "".join(string).replace("", " ") - - def get_vocab_size(self) -> int: - """Get the number of unique characters in the dataset""" - return len(self.char_map) - - def get_blank_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - def get_unk_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - def get_space_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - # TODO: add train function - - def save(self, path: str) -> None: - """Save the tokenizer to a file""" - with open(path, "w", encoding="utf-8") as file: - for char, index in self.char_map.items(): - file.write(f"{char} {index}\n") - - @staticmethod - def from_file(tokenizer_file: str) -> "CharTokenizer": - """Instantiate a CharTokenizer from a file""" - load_tokenizer = CharTokenizer() - with open(tokenizer_file, "r", encoding="utf-8") as file: - for line in file: - line = line.strip() - if line: - char, index = line.split() - tokenizer.char_map[char] = int(index) - tokenizer.index_map[int(index)] = char - return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.save("data/tokenizers/char_tokenizer_german.json") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py deleted file mode 100644 index a362b9e..0000000 --- a/swr2_asr/utils.py +++ /dev/null @@ -1,417 +0,0 @@ -"""Class containing utils for the ASR system.""" -import os -from enum import Enum -from typing import TypedDict - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchaudio -from tokenizers import Tokenizer -from torch.utils.data import Dataset -from torchaudio.datasets.utils import _extract_tar as extract_archive - -from swr2_asr.tokenizer import TokenizerType - -train_audio_transforms = torch.nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) - - -# create enum specifiying dataset splits -class MLSSplit(str, Enum): - """Enum specifying dataset as they are defined in the - Multilingual LibriSpeech dataset""" - - TRAIN = "train" - TEST = "test" - DEV = "dev" - - -class Split(str, Enum): - """Extending the MLSSplit class to allow for a custom validatio split""" - - TRAIN = "train" - VALID = "valid" - TEST = "test" - DEV = "dev" - - -def split_to_mls_split(split_name: Split) -> MLSSplit: - """Converts the custom split to a MLSSplit""" - if split_name == 
Split.VALID: - return MLSSplit.TRAIN - else: - return split_name # type: ignore - - -class Sample(TypedDict): - """Type for a sample in the dataset""" - - waveform: torch.Tensor - spectrogram: torch.Tensor - input_length: int - utterance: torch.Tensor - utterance_length: int - sample_rate: int - speaker_id: str - book_id: str - chapter_id: str - - -class MLSDataset(Dataset): - """Custom Dataset for reading Multilingual LibriSpeech - - Attributes: - dataset_path (str): - path to the dataset - language (str): - language of the dataset - split (Split): - split of the dataset - mls_split (MLSSplit): - split of the dataset as defined in the Multilingual LibriSpeech dataset - dataset_lookup (list): - list of dicts containing the speakerid, bookid, chapterid and utterance - - directory structure: - - ├── - │ ├── train - │ │ ├── transcripts.txt - │ │ └── audio - │ │ └── - │ │ └── - │ │ └── __.opus / .flac - - each line in transcripts.txt has the following format: - __ - """ - - def __init__( - self, - dataset_path: str, - language: str, - split: Split, - limited: bool, - download: bool, - spectrogram_hparams: dict | None, - ): - """Initializes the dataset""" - self.dataset_path = dataset_path - self.language = language - self.file_ext = ".opus" if "opus" in language else ".flac" - self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk - self.split: Split = split # split used internally - - if spectrogram_hparams is None: - self.spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - else: - self.spectrogram_hparams = spectrogram_hparams - - self.dataset_lookup = [] - self.tokenizer: type[TokenizerType] - - self._handle_download_dataset(download) - self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): - self.initialize_limited() - else: - self.initialize() - - def initialize_limited(self) -> None: - """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - - handles = set() - - train_root_path = os.path.join(self.dataset_path, self.language, "train") - - # get file handles for 9h - with open( - os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"), - "r", - encoding="utf-8", - ) as file: - for line in file: - handles.add(line.strip()) - - # get file handles for 1h splits - for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")): - if handle_path not in range(0, 6): - continue - with open( - os.path.join( - train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt" - ), - "r", - encoding="utf-8", - ) as file: - for line in file: - handles.add(line.strip()) - - # get file paths for handles - file_paths = [] - for handle in handles: - file_paths.append( - os.path.join( - train_root_path, - "audio", - handle.split("_")[0], - handle.split("_")[1], - handle + self.file_ext, - ) - ) - - # get transcripts for handles - transcripts = [] - with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file: - for line in file: - if line.split("\t")[0] in handles: - transcripts.append(line.strip()) - - # create train or valid split randomly with seed 42 - if self.split == Split.TRAIN: - np.random.seed(42) - indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8)) - file_paths = [file_paths[i] for i in indices] - transcripts = 
[transcripts[i] for i in indices] - elif self.split == Split.VALID: - np.random.seed(42) - indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) - file_paths = [file_paths[i] for i in indices] - transcripts = [transcripts[i] for i in indices] - - # create dataset lookup - self.dataset_lookup = [ - { - "speakerid": path.split("/")[-3], - "bookid": path.split("/")[-2], - "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], - "utterance": utterance.split("\t")[1], - } - for path, utterance in zip(file_paths, transcripts, strict=False) - ] - - def initialize(self) -> None: - """Initializes the entire dataset - - Reads the transcripts.txt file and creates a lookup table - """ - transcripts_path = os.path.join( - self.dataset_path, self.language, self.mls_split, "transcripts.txt" - ) - - with open(transcripts_path, "r", encoding="utf-8") as script_file: - # read all lines in transcripts.txt - transcripts = script_file.readlines() - # split each line into (__, ) - transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore - utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore - identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore - identifier = [path.split("_") for path in identifier] - - if self.split == Split.VALID: - np.random.seed(42) - indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) - utterances = [utterances[i] for i in indices] - identifier = [identifier[i] for i in indices] - elif self.split == Split.TRAIN: - np.random.seed(42) - indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) - utterances = [utterances[i] for i in indices] - identifier = [identifier[i] for i in indices] - - self.dataset_lookup = [ - { - "speakerid": path[0], - "bookid": path[1], - "chapterid": path[2], - "utterance": utterance, - } - for path, utterance in zip(identifier, utterances, strict=False) - ] - - def set_tokenizer(self, tokenizer: type[TokenizerType]): - """Sets the tokenizer""" - self.tokenizer = tokenizer - - def _handle_download_dataset(self, download: bool) -> None: - """Download the dataset""" - if not download: - print("Download flag not set, skipping download") - return - # zip exists: - if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: - print(f"Found dataset at {self.dataset_path}. 
Skipping download") - # zip does not exist: - else: - os.makedirs(self.dataset_path, exist_ok=True) - url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - - torch.hub.download_url_to_file( - url, os.path.join(self.dataset_path, self.language) + ".tar.gz" - ) - - # unzip the dataset - if not os.path.isdir(os.path.join(self.dataset_path, self.language)): - print( - f"Unzipping the dataset at \ - {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" - ) - extract_archive( - os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True - ) - else: - print("Dataset is already unzipped, validating it now") - return - - def _validate_local_directory(self): - # check if dataset_path exists - if not os.path.exists(self.dataset_path): - raise ValueError("Dataset path does not exist") - if not os.path.exists(os.path.join(self.dataset_path, self.language)): - raise ValueError("Language not downloaded!") - if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): - raise ValueError("Split not found in dataset") - - def __len__(self): - """Returns the length of the dataset""" - return len(self.dataset_lookup) - - def __getitem__(self, idx: int) -> Sample: - """One sample""" - if self.tokenizer is None: - raise ValueError("No tokenizer set") - # get the utterance - utterance = self.dataset_lookup[idx]["utterance"] - - # get the audio file - audio_path = os.path.join( - self.dataset_path, - self.language, - self.mls_split, - "audio", - self.dataset_lookup[idx]["speakerid"], - self.dataset_lookup[idx]["bookid"], - "_".join( - [ - self.dataset_lookup[idx]["speakerid"], - self.dataset_lookup[idx]["bookid"], - self.dataset_lookup[idx]["chapterid"], - ] - ) - + self.file_ext, - ) - - waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member - - # resample if necessary - if sample_rate != self.spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample( - sample_rate, self.spectrogram_hparams["sample_rate"] - ) - waveform = resampler(waveform) - - spec = ( - torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - - input_length = spec.shape[0] // 2 - - utterance_length = len(utterance) - - utterance = self.tokenizer.encode(utterance) - - utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member - - return Sample( - waveform=waveform, - spectrogram=spec, - input_length=input_length, - utterance=utterance, - utterance_length=utterance_length, - sample_rate=self.spectrogram_hparams["sample_rate"], - speaker_id=self.dataset_lookup[idx]["speakerid"], - book_id=self.dataset_lookup[idx]["bookid"], - chapter_id=self.dataset_lookup[idx]["chapterid"], - ) - - -def collate_fn(samples: list[Sample]) -> dict: - """Collate function for the dataloader - - pads all tensors within a batch to the same dimensions - """ - waveforms = [] - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - - for sample in samples: - waveforms.append(sample["waveform"].transpose(0, 1)) - spectrograms.append(sample["spectrogram"]) - labels.append(sample["utterance"]) - input_lengths.append(sample["spectrogram"].shape[0] // 2) - label_lengths.append(len(sample["utterance"])) - - waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) - spectrograms = ( - torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) - ) - labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) - - return { - 
"waveform": waveforms, - "spectrogram": spectrograms, - "input_length": input_lengths, - "utterance": labels, - "utterance_length": label_lengths, - } - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.TRAIN - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None) - - tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - dataset.set_tokenizer(tok) - - -def plot(epochs, path): - """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - - plt.plot(losses) - plt.plot(test_losses) - plt.savefig("losses.svg") diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py new file mode 100644 index 0000000..93f4a9a --- /dev/null +++ b/swr2_asr/utils/data.py @@ -0,0 +1,371 @@ +"""Class containing utils for the ASR system.""" +import os +from enum import Enum +from typing import TypedDict + +import numpy as np +import torch +import torchaudio +from torch import Tensor, nn +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _extract_tar + +from swr2_asr.utils.tokenizer import CharTokenizer + + +class DataProcessing: + """Data processing class for the dataloader""" + + def __init__(self, data_type: str, tokenizer: CharTokenizer): + self.data_type = data_type + self.tokenizer = tokenizer + + if data_type == "train": + self.audio_transform = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), + ) + elif data_type == "valid": + self.audio_transform = torchaudio.transforms.MelSpectrogram() + + def __call__(self, data) -> tuple[Tensor, Tensor, list, list]: + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for waveform, _, utterance, _, _, _ in data: + spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1) + spectrograms.append(spec) + label = torch.Tensor(self.tokenizer.encode(utterance.lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths + + +# create enum specifiying dataset splits +class MLSSplit(str, Enum): + """Enum specifying dataset as they are defined in the + Multilingual LibriSpeech dataset""" + + TRAIN = "train" + TEST = "test" + DEV = "dev" + + +class Split(str, Enum): + """Extending the MLSSplit class to allow for a custom validation split""" + + TRAIN = "train" + VALID = "valid" + TEST = "test" + DEV = "dev" + + +def split_to_mls_split(split_name: Split) -> MLSSplit: + """Converts the custom split to a MLSSplit""" + if split_name == Split.VALID: + return MLSSplit.TRAIN + return split_name # type: ignore + + +class Sample(TypedDict): + """Type for a sample in the dataset""" + + waveform: torch.Tensor + spectrogram: torch.Tensor + 
input_length: int + utterance: torch.Tensor + utterance_length: int + sample_rate: int + speaker_id: str + book_id: str + chapter_id: str + + +class MLSDataset(Dataset): + """Custom Dataset for reading Multilingual LibriSpeech + + Attributes: + dataset_path (str): + path to the dataset + language (str): + language of the dataset + split (Split): + split of the dataset + mls_split (MLSSplit): + split of the dataset as defined in the Multilingual LibriSpeech dataset + dataset_lookup (list): + list of dicts containing the speakerid, bookid, chapterid and utterance + + directory structure: + + ├── + │ ├── train + │ │ ├── transcripts.txt + │ │ └── audio + │ │ └── + │ │ └── + │ │ └── __.opus / .flac + + each line in transcripts.txt has the following format: + __ + """ + + def __init__( + self, + dataset_path: str, + language: str, + split: Split, + limited: bool, + download: bool, + size: float = 0.2, + ): + """Initializes the dataset""" + self.dataset_path = dataset_path + self.language = language + self.file_ext = ".opus" if "opus" in language else ".flac" + self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk + self.split: Split = split # split used internally + + self.dataset_lookup = [] + + self._handle_download_dataset(download) + self._validate_local_directory() + if limited and (split == Split.TRAIN or split == Split.VALID): + self.initialize_limited() + else: + self.initialize() + + self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)] + + def initialize_limited(self) -> None: + """Initializes the limited supervision dataset""" + # get file handles + # get file paths + # get transcripts + # create train or validation split + + handles = set() + + train_root_path = os.path.join(self.dataset_path, self.language, "train") + + # get file handles for 9h + with open( + os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file handles for 1h splits + for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")): + if handle_path not in range(0, 6): + continue + with open( + os.path.join( + train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt" + ), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file paths for handles + file_paths = [] + for handle in handles: + file_paths.append( + os.path.join( + train_root_path, + "audio", + handle.split("_")[0], + handle.split("_")[1], + handle + self.file_ext, + ) + ) + + # get transcripts for handles + transcripts = [] + with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file: + for line in file: + if line.split("\t")[0] in handles: + transcripts.append(line.strip()) + + # create train or valid split randomly with seed 42 + if self.split == Split.TRAIN: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + elif self.split == Split.VALID: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + + # create dataset lookup + self.dataset_lookup = [ + { + "speakerid": path.split("/")[-3], + "bookid": path.split("/")[-2], + "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], + 
"utterance": utterance.split("\t")[1], + } + for path, utterance in zip(file_paths, transcripts, strict=False) + ] + + def initialize(self) -> None: + """Initializes the entire dataset + + Reads the transcripts.txt file and creates a lookup table + """ + transcripts_path = os.path.join( + self.dataset_path, self.language, self.mls_split, "transcripts.txt" + ) + + with open(transcripts_path, "r", encoding="utf-8") as script_file: + # read all lines in transcripts.txt + transcripts = script_file.readlines() + # split each line into (__, ) + transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore + utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore + identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore + identifier = [path.split("_") for path in identifier] + + if self.split == Split.VALID: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + elif self.split == Split.TRAIN: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + + self.dataset_lookup = [ + { + "speakerid": path[0], + "bookid": path[1], + "chapterid": path[2], + "utterance": utterance, + } + for path, utterance in zip(identifier, utterances, strict=False) + ] + + def _handle_download_dataset(self, download: bool) -> None: + """Download the dataset""" + if not download: + print("Download flag not set, skipping download") + return + # zip exists: + if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: + print(f"Found dataset at {self.dataset_path}. 
Skipping download") + # path exists: + elif os.path.isdir(os.path.join(self.dataset_path, self.language)) and download: + return + else: + os.makedirs(self.dataset_path, exist_ok=True) + url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" + + torch.hub.download_url_to_file( + url, os.path.join(self.dataset_path, self.language) + ".tar.gz" + ) + + # unzip the dataset + if not os.path.isdir(os.path.join(self.dataset_path, self.language)): + print( + f"Unzipping the dataset at \ + {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + ) + _extract_tar(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + else: + print("Dataset is already unzipped, validating it now") + return + + def _validate_local_directory(self): + # check if dataset_path exists + if not os.path.exists(self.dataset_path): + raise ValueError("Dataset path does not exist") + if not os.path.exists(os.path.join(self.dataset_path, self.language)): + raise ValueError("Language not downloaded!") + if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): + raise ValueError("Split not found in dataset") + + def __len__(self): + """Returns the length of the dataset""" + return len(self.dataset_lookup) + + def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]: + """One sample + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ + # get the utterance + dataset_lookup_entry = self.dataset_lookup[idx] + + utterance = dataset_lookup_entry["utterance"] + + # get the audio file + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + "_".join( + [ + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + self.dataset_lookup[idx]["chapterid"], + ] + ) + + self.file_ext, + ) + + waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member + + # resample if necessary + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + + return ( + waveform, + sample_rate, + utterance, + dataset_lookup_entry["speakerid"], + dataset_lookup_entry["chapterid"], + idx, + ) # type: ignore + + +if __name__ == "__main__": + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" + LANGUAGE = "mls_german_opus" + split = Split.TRAIN + DOWNLOAD = False diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py new file mode 100644 index 0000000..fcddb79 --- /dev/null +++ b/swr2_asr/utils/decoder.py @@ -0,0 +1,26 @@ +"""Decoder for CTC-based ASR.""" "" +import torch + +from swr2_asr.utils.tokenizer import CharTokenizer + + +# TODO: refactor to use torch CTC decoder class +def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): + """Greedily decode a sequence.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(tokenizer.decode(decode)) + return decodes, targets + + +# TODO: add beam 
search decoder diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py new file mode 100644 index 0000000..d92465a --- /dev/null +++ b/swr2_asr/utils/tokenizer.py @@ -0,0 +1,126 @@ +"""Tokenizer for Multilingual Librispeech datasets""" + + +class CharTokenizer: + """Maps characters to integers and vice versa""" + + def __init__(self): + char_map_str = """ + _ + + + + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + é + à + ä + ö + ß + ü + - + ' + + """ + + self.char_map = {} + self.index_map = {} + for idx, char in enumerate(char_map_str.strip().split("\n")): + char = char.strip() + self.char_map[char] = idx + self.index_map[idx] = char + self.index_map[1] = " " + + def encode(self, text: str) -> list[int]: + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for char in text: + if char == " ": + char = self.char_map[""] + elif char not in self.char_map: + char = self.char_map[""] + else: + char = self.char_map[char] + int_sequence.append(char) + return int_sequence + + def decode(self, labels: list[int]) -> str: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("", " ") + + def get_vocab_size(self) -> int: + """Get the number of unique characters in the dataset""" + return len(self.char_map) + + def get_blank_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + def get_unk_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + def get_space_token(self) -> int: + """Get the integer representation of the character""" + return self.char_map[""] + + # TODO: add train function + + def save(self, path: str) -> None: + """Save the tokenizer to a file""" + with open(path, "w", encoding="utf-8") as file: + for char, index in self.char_map.items(): + file.write(f"{char} {index}\n") + + @staticmethod + def from_file(tokenizer_file: str) -> "CharTokenizer": + """Instantiate a CharTokenizer from a file""" + load_tokenizer = CharTokenizer() + with open(tokenizer_file, "r", encoding="utf-8") as file: + for line in file: + line = line.strip() + if line: + char, index = line.split() + tokenizer.char_map[char] = int(index) + tokenizer.index_map[int(index)] = char + return load_tokenizer + + +if __name__ == "__main__": + tokenizer = CharTokenizer() + tokenizer.save("data/tokenizers/char_tokenizer_german.json") + print(tokenizer.char_map) + print(tokenizer.index_map) + print(tokenizer.get_vocab_size()) + print(tokenizer.get_blank_token()) + print(tokenizer.get_unk_token()) + print(tokenizer.get_space_token()) + print(tokenizer.encode("hallo welt")) + print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py new file mode 100644 index 0000000..80f942a --- /dev/null +++ b/swr2_asr/utils/visualization.py @@ -0,0 +1,22 @@ +"""Utilities for visualizing the training process and results.""" + +import matplotlib.pyplot as plt +import torch + + +def plot(epochs, path): + """Plots the losses over the epochs""" + losses = list() + test_losses = list() + cers = list() + wers = list() + for epoch in range(1, epochs + 1): + current_state = torch.load(path + str(epoch)) + losses.append(current_state["loss"]) + test_losses.append(current_state["test_loss"]) + cers.append(current_state["avg_cer"]) + 
wers.append(current_state["avg_wer"]) + + plt.plot(losses) + plt.plot(test_losses) + plt.savefig("losses.svg") -- cgit v1.2.3 From effde1d9e71864a2c5bd8464db0958f5bf2d1733 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 15:07:08 +0200 Subject: added small stuff to data utilities --- swr2_asr/utils/data.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) (limited to 'swr2_asr') diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 93f4a9a..74d10c9 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -123,9 +123,9 @@ class MLSDataset(Dataset): self, dataset_path: str, language: str, - split: Split, - limited: bool, - download: bool, + split: Split, # pylint: disable=redefined-outer-name + limited: bool = False, + download: bool = True, size: float = 0.2, ): """Initializes the dataset""" @@ -365,7 +365,28 @@ class MLSDataset(Dataset): if __name__ == "__main__": + from torch.utils.data import DataLoader + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" - split = Split.TRAIN + split = Split.DEV DOWNLOAD = False + + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) + + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=True, + collate_fn=DataProcessing( + "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + ), + ) + + for batch in dataloader: + print(batch) + break + + print(len(dataset)) + + print(dataset[0]) -- cgit v1.2.3 From c078ce6789c134aa05607903d3bf9e4be64df45d Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 15:45:35 +0200 Subject: big change! --- .vscode/settings.json | 26 +-- pyproject.toml | 3 +- swr2_asr/loss_scores.py | 203 ----------------------- swr2_asr/model_deep_speech.py | 25 ++- swr2_asr/train.py | 372 ++++++++++++++++++++++-------------------- swr2_asr/utils/data.py | 14 -- swr2_asr/utils/loss_scores.py | 203 +++++++++++++++++++++++ 7 files changed, 435 insertions(+), 411 deletions(-) delete mode 100644 swr2_asr/loss_scores.py create mode 100644 swr2_asr/utils/loss_scores.py (limited to 'swr2_asr') diff --git a/.vscode/settings.json b/.vscode/settings.json index 0054bca..1adbc18 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,14 +1,14 @@ { - "[python]": { - "editor.formatOnType": true, - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, - "editor.rulers": [88, 120], - }, - "black-formatter.importStrategy": "fromEnvironment", - "python.analysis.typeCheckingMode": "basic", - "ruff.organizeImports": true, - "ruff.importStrategy": "fromEnvironment", - "ruff.fixAll": true, - "ruff.run": "onType" -} \ No newline at end of file + "[python]": { + "editor.formatOnType": true, + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.rulers": [88, 120] + }, + "black-formatter.importStrategy": "fromEnvironment", + "python.analysis.typeCheckingMode": "off", + "ruff.organizeImports": true, + "ruff.importStrategy": "fromEnvironment", + "ruff.fixAll": true, + "ruff.run": "onType" +} diff --git a/pyproject.toml b/pyproject.toml index 6f74b49..38cc51a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,11 @@ black = "^23.7.0" mypy = "^1.5.1" pylint = "^2.17.5" ruff = "^0.0.285" -types-tqdm = "^4.66.0.1" [tool.ruff] select = ["E", "F", "B", "I"] fixable = ["ALL"] -line-length = 120 +line-length = 100 target-version = "py310" [tool.black] diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py deleted file 
mode 100644 index 80285f6..0000000 --- a/swr2_asr/loss_scores.py +++ /dev/null @@ -1,203 +0,0 @@ -"""Methods for determining the loss and scores of the model.""" -import numpy as np - - -def avg_wer(wer_scores, combined_ref_len) -> float: - """Calculate the average word error rate (WER). - - Args: - wer_scores: word error rate scores - combined_ref_len: combined length of reference sentences - - Returns: - average word error rate (float) - - Usage: - >>> avg_wer([0.5, 0.5], 2) - 0.5 - """ - return float(sum(wer_scores)) / float(combined_ref_len) - - -def _levenshtein_distance(ref, hyp) -> int: - """Levenshtein distance. - - Args: - ref: reference sentence - hyp: hypothesis sentence - - Returns: - distance: levenshtein distance between reference and hypothesis - - Usage: - >>> _levenshtein_distance("hello", "helo") - 2 - """ - len_ref = len(ref) - len_hyp = len(hyp) - - # special case - if ref == hyp: - return 0 - if len_ref == 0: - return len_hyp - if len_hyp == 0: - return len_ref - - if len_ref < len_hyp: - ref, hyp = hyp, ref - len_ref, len_hyp = len_hyp, len_ref - - # use O(min(m, n)) space - distance = np.zeros((2, len_hyp + 1), dtype=np.int32) - - # initialize distance matrix - for j in range(0, len_hyp + 1): - distance[0][j] = j - - # calculate levenshtein distance - for i in range(1, len_ref + 1): - prev_row_idx = (i - 1) % 2 - cur_row_idx = i % 2 - distance[cur_row_idx][0] = i - for j in range(1, len_hyp + 1): - if ref[i - 1] == hyp[j - 1]: - distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] - else: - s_num = distance[prev_row_idx][j - 1] + 1 - i_num = distance[cur_row_idx][j - 1] + 1 - d_num = distance[prev_row_idx][j] + 1 - distance[cur_row_idx][j] = min(s_num, i_num, d_num) - - return distance[len_ref % 2][len_hyp] - - -def word_errors( - reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " -) -> tuple[float, int]: - """Compute the levenshtein distance between reference sequence and - hypothesis sequence in word-level. - - Args: - reference: The reference sentence. - hypothesis: The hypothesis sentence. - ignore_case: Whether case-sensitive or not. - delimiter: Delimiter of input sentences. - - Returns: - Levenshtein distance and length of reference sentence. - - Usage: - >>> word_errors("hello world", "hello") - 1, 2 - """ - if ignore_case: - reference = reference.lower() - hypothesis = hypothesis.lower() - - ref_words = reference.split(delimiter) - hyp_words = hypothesis.split(delimiter) - - edit_distance = _levenshtein_distance(ref_words, hyp_words) - return float(edit_distance), len(ref_words) - - -def char_errors( - reference: str, - hypothesis: str, - ignore_case: bool = False, - remove_space: bool = False, -) -> tuple[float, int]: - """Compute the levenshtein distance between reference sequence and - hypothesis sequence in char-level. - Args: - reference: The reference sentence. - hypothesis: The hypothesis sentence. - ignore_case: Whether case-sensitive or not. - remove_space: Whether remove internal space characters - - Returns: - Levenshtein distance and length of reference sentence. 
- - Usage: - >>> char_errors("hello world", "hello") - 1, 10 - """ - if ignore_case: - reference = reference.lower() - hypothesis = hypothesis.lower() - - join_char = " " - if remove_space: - join_char = "" - - reference = join_char.join(filter(None, reference.split(" "))) - hypothesis = join_char.join(filter(None, hypothesis.split(" "))) - - edit_distance = _levenshtein_distance(reference, hypothesis) - return float(edit_distance), len(reference) - - -def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float: - """Calculate word error rate (WER). WER compares reference text and - hypothesis text in word-level. - WER is defined as: - WER = (Sw + Dw + Iw) / Nw - with: - Sw is the number of words subsituted, - Dw is the number of words deleted, - Iw is the number of words inserted, - Nw is the number of words in the reference - - Args: - reference: The reference sentence. - hypothesis: The hypothesis sentence. - ignore_case: Whether case-sensitive or not. - delimiter: Delimiter of input sentences. - - Returns: - Word error rate (float) - - Usage: - >>> wer("hello world", "hello") - 0.5 - """ - edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) - - if ref_len == 0: - raise ValueError("Reference's word number should be greater than 0.") - - word_error_rate = float(edit_distance) / ref_len - return word_error_rate - - -def cer(reference, hypothesis, ignore_case=False, remove_space=False): - """Calculate charactor error rate (CER). CER compares reference text and - hypothesis text in char-level. CER is defined as: - CER = (Sc + Dc + Ic) / Nc - with - Sc is the number of characters substituted, - Dc is the number of characters deleted, - Ic is the number of characters inserted - Nc is the number of characters in the reference - - Args: - reference: The reference sentence. - hypothesis: The hypothesis sentence. - ignore_case: Whether case-sensitive or not. - remove_space: Whether remove internal space characters - - Returns: - Character error rate (float) - - Usage: - >>> cer("hello world", "hello") - 0.2727272727272727 - """ - edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) - - if ref_len == 0: - raise ValueError("Length of reference should be greater than 0.") - - char_error_rate = float(edit_distance) / ref_len - return char_error_rate diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index dd07ff9..8ddbd99 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,8 +1,29 @@ -"""Main definition of model""" +"""Main definition of the Deep speech 2 model by Baidu Research. 
+ +Following definition by Assembly AI +(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) +""" +from typing import TypedDict + import torch.nn.functional as F from torch import nn +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" + + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int + + class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" @@ -60,7 +81,7 @@ class ResidualCNN(nn.Module): class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" + """Bidirectional GRU with Layer Normalization and Dropout""" def __init__( self, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9f12bcb..ac7666b 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -7,50 +7,17 @@ import torch import torch.nn.functional as F from torch import nn, optim from torch.utils.data import DataLoader -from tqdm import tqdm - -from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer -from swr2_asr.utils import MLSDataset, Split, collate_fn,plot - -from .loss_scores import cer, wer - - -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): - """Greedily decode a sequence.""" - print("output shape", output.shape) - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -class IterMeter: +from tqdm.autonotebook import tqdm + +from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.utils.data import DataProcessing, MLSDataset, Split +from swr2_asr.utils.decoder import greedy_decoder +from swr2_asr.utils.tokenizer import CharTokenizer + +from .utils.loss_scores import cer, wer + + +class IterMeter(object): """keeps track of total iterations""" def __init__(self): @@ -61,123 +28,195 @@ class IterMeter: self.val += 1 def get(self): - """get""" + """get steps""" return self.val -def train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, -): - """Train""" +class TrainArgs(TypedDict): + """Type for the arguments of the training function.""" + + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + train_loader: DataLoader + criterion: nn.CTCLoss + optimizer: optim.AdamW + scheduler: optim.lr_scheduler.OneCycleLR + epoch: int + iter_meter: IterMeter + + +def train(train_args) -> float: + """Train + Args: + model: model + device: device type + train_loader: train dataloader + criterion: loss function + optimizer: optimizer + scheduler: learning rate scheduler + epoch: epoch number + iter_meter: iteration meter + + Returns: + avg_train_loss: avg_train_loss for the epoch + + 
Information: + spectrograms: (batch, time, feature) + labels: (batch, label_length) + + model output: (batch,time, n_class) + + """ + # get values from train_args: + ( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) = train_args.values() + model.train() - print(f"Epoch: {epoch}") - losses = [] - for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + print(f"training batch {epoch}") + train_losses = [] + for _data in tqdm(train_loader, desc="Training batches"): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) + train_losses.append(loss) loss.backward() optimizer.step() scheduler.step() iter_meter.step() + avg_train_loss = sum(train_losses) / len(train_losses) + print(f"Train set: Average loss: {avg_train_loss:.2f}") + return avg_train_loss + - losses.append(loss.item()) +class TestArgs(TypedDict): + """Type for the arguments of the test function.""" - print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") - return sum(losses) / len(losses) + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + test_loader: DataLoader + criterion: nn.CTCLoss + tokenizer: CharTokenizer + decoder: str -def test(model, device, test_loader, criterion, tokenizer): - """Test""" +def test(test_args: TestArgs) -> tuple[float, float, float]: print("\nevaluating...") + + # get values from test_args: + model, device, test_loader, criterion, tokenizer, decoder = test_args.values() + + if decoder == "greedy": + decoder = greedy_decoder + model.eval() test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for _data in test_loader: - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output=output.transpose(0, 1), - labels=labels, - label_lengths=_data["utterance_length"], - tokenizer=tokenizer, + output.transpose(0, 1), labels, label_lengths, tokenizer ) - for j, pred in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], pred)) - test_wer.append(wer(decoded_targets[j], pred)) + if i == 1: + print(f"decoding first sample: {decoded_preds}") + for j, _ in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) print( - f"Test set: Average loss:\ - {test_loss}, Average CER: {None} Average WER: {None}\n" + f"Test set: \ + Average loss: 
{test_loss:.4f}, \ + Average CER: {avg_cer:4f} \ + Average WER: {avg_wer:.4f}\n" ) return test_loss, avg_cer, avg_wer -def run( +def main( learning_rate: float, batch_size: int, epochs: int, - load: bool, - path: str, dataset_path: str, language: str, -) -> None: - """Runs the training script.""" + limited_supervision: bool, + model_load_path: str, + model_save_path: str, + dataset_percentage: float, + eval_every: int, + num_workers: int, +): + """Main function for training the model. + + Args: + learning_rate: learning rate for the optimizer + batch_size: batch size + epochs: number of epochs to train + dataset_path: path for the dataset + language: language of the dataset + limited_supervision: whether to use only limited supervision + model_load_path: path to load a model from + model_save_path: path to save the model to + dataset_percentage: percentage of the dataset to use + eval_every: evaluate every n epochs + num_workers: number of workers for the dataloader + """ use_cuda = torch.cuda.is_available() - torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - # device = torch.device("mps") - # load dataset + torch.manual_seed(7) + + if not os.path.isdir(dataset_path): + os.makedirs(dataset_path) + train_dataset = MLSDataset( - dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True + dataset_path, + language, + Split.TEST, + download=True, + limited=limited_supervision, + size=dataset_percentage, ) valid_dataset = MLSDataset( - dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True + dataset_path, + language, + Split.TRAIN, + download=False, + limited=Falimited_supervisionlse, + size=dataset_percentage, ) - # load tokenizer (bpe by default): - if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): - print("There is no tokenizer available. 
Do you want to train it on the dataset?") - input("Press Enter to continue...") - train_char_tokenizer( - dataset_path=dataset_path, - language=language, - split="all", - out_path="data/tokenizers/char_tokenizer_german.json", - ) - - tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - - train_dataset.set_tokenizer(tokenizer) # type: ignore - valid_dataset.set_tokenizer(tokenizer) # type: ignore + # TODO: initialize and possibly train tokenizer if none found - print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") + kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} hparams = HParams( n_cnn_layers=3, @@ -192,24 +231,24 @@ def run( epochs=epochs, ) + train_data_processing = DataProcessing("train", tokenizer) + valid_data_processing = DataProcessing("valid", tokenizer) + train_loader = DataLoader( - train_dataset, + dataset=train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: collate_fn(x), + collate_fn=train_data_processing, + **kwargs, ) - valid_loader = DataLoader( - valid_dataset, + dataset=valid_dataset, batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: collate_fn(x), + shuffle=False, + collate_fn=valid_data_processing, + **kwargs, ) - # enable flag to find the most compatible algorithms in advance - if use_cuda: - torch.backends.cudnn.benchmark = True # pylance: disable=no-member - model = SpeechRecognitionModel( hparams["n_cnn_layers"], hparams["n_rnn_layers"], @@ -219,16 +258,9 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - print(tokenizer.encode(" ")) - print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) + optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) - if load: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["model_state_dict"]) - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - epoch = checkpoint["epoch"] - loss = checkpoint["loss"] + criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], @@ -236,77 +268,63 @@ def run( epochs=hparams["epochs"], anneal_strategy="linear", ) + prev_epoch = 0 + + if model_load_path is not None: + checkpoint = torch.load(model_load_path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - for epoch in range(1, epochs + 1): - loss = train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, + if not os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) + for epoch in range(prev_epoch + 1, epochs + 1): + train_args: TrainArgs = dict( + model=model, + device=device, + train_loader=train_loader, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + iter_meter=iter_meter, ) - test_loss, avg_cer, avg_wer = test( + train_loss = train(train_args) + + test_loss, test_cer, test_wer = 0, 0, 0 + + test_args: TestArgs = dict( model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer=tokenizer, + decoder="greedy", ) - print("saving epoch", str(epoch)) + + if epoch % eval_every == 0: + test_loss, test_cer, test_wer = test(test_args) + + if model_save_path is None: + continue + + if not 
os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), - "loss": loss, + "optimizer_state_dict": optimizer.state_dict(), + "train_loss": train_loss, "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer, + "avg_cer": test_cer, + "avg_wer": test_wer, }, - path + str(epoch), - plot(epochs,path) + model_save_path + str(epoch), ) -@click.command() -@click.option("--learning_rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=10, help="Batch size") -@click.option("--epochs", default=1, help="Number of epochs") -@click.option("--load", default=False, help="Do you want to load a model?") -@click.option( - "--path", - default="model", - help="Path where the model will be saved to/loaded from", -) -@click.option( - "--dataset_path", - default="data/", - help="Path for the dataset directory", -) -def run_cli( - learning_rate: float, - batch_size: int, - epochs: int, - load: bool, - path: str, - dataset_path: str, -) -> None: - """Runs the training script.""" - - run( - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - load=load, - path=path, - dataset_path=dataset_path, - language="mls_german_opus", - ) - - if __name__ == "__main__": - run_cli() # pylint: disable=no-value-for-parameter + main() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 74d10c9..e939e1d 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -76,20 +76,6 @@ def split_to_mls_split(split_name: Split) -> MLSSplit: return split_name # type: ignore -class Sample(TypedDict): - """Type for a sample in the dataset""" - - waveform: torch.Tensor - spectrogram: torch.Tensor - input_length: int - utterance: torch.Tensor - utterance_length: int - sample_rate: int - speaker_id: str - book_id: str - chapter_id: str - - class MLSDataset(Dataset): """Custom Dataset for reading Multilingual LibriSpeech diff --git a/swr2_asr/utils/loss_scores.py b/swr2_asr/utils/loss_scores.py new file mode 100644 index 0000000..80285f6 --- /dev/null +++ b/swr2_asr/utils/loss_scores.py @@ -0,0 +1,203 @@ +"""Methods for determining the loss and scores of the model.""" +import numpy as np + + +def avg_wer(wer_scores, combined_ref_len) -> float: + """Calculate the average word error rate (WER). + + Args: + wer_scores: word error rate scores + combined_ref_len: combined length of reference sentences + + Returns: + average word error rate (float) + + Usage: + >>> avg_wer([0.5, 0.5], 2) + 0.5 + """ + return float(sum(wer_scores)) / float(combined_ref_len) + + +def _levenshtein_distance(ref, hyp) -> int: + """Levenshtein distance. 
+ + Args: + ref: reference sentence + hyp: hypothesis sentence + + Returns: + distance: levenshtein distance between reference and hypothesis + + Usage: + >>> _levenshtein_distance("hello", "helo") + 2 + """ + len_ref = len(ref) + len_hyp = len(hyp) + + # special case + if ref == hyp: + return 0 + if len_ref == 0: + return len_hyp + if len_hyp == 0: + return len_ref + + if len_ref < len_hyp: + ref, hyp = hyp, ref + len_ref, len_hyp = len_hyp, len_ref + + # use O(min(m, n)) space + distance = np.zeros((2, len_hyp + 1), dtype=np.int32) + + # initialize distance matrix + for j in range(0, len_hyp + 1): + distance[0][j] = j + + # calculate levenshtein distance + for i in range(1, len_ref + 1): + prev_row_idx = (i - 1) % 2 + cur_row_idx = i % 2 + distance[cur_row_idx][0] = i + for j in range(1, len_hyp + 1): + if ref[i - 1] == hyp[j - 1]: + distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] + else: + s_num = distance[prev_row_idx][j - 1] + 1 + i_num = distance[cur_row_idx][j - 1] + 1 + d_num = distance[prev_row_idx][j] + 1 + distance[cur_row_idx][j] = min(s_num, i_num, d_num) + + return distance[len_ref % 2][len_hyp] + + +def word_errors( + reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " +) -> tuple[float, int]: + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in word-level. + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> word_errors("hello world", "hello") + 1, 2 + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = reference.split(delimiter) + hyp_words = hypothesis.split(delimiter) + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + return float(edit_distance), len(ref_words) + + +def char_errors( + reference: str, + hypothesis: str, + ignore_case: bool = False, + remove_space: bool = False, +) -> tuple[float, int]: + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in char-level. + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + remove_space: Whether remove internal space characters + + Returns: + Levenshtein distance and length of reference sentence. + + Usage: + >>> char_errors("hello world", "hello") + 1, 10 + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + join_char = " " + if remove_space: + join_char = "" + + reference = join_char.join(filter(None, reference.split(" "))) + hypothesis = join_char.join(filter(None, hypothesis.split(" "))) + + edit_distance = _levenshtein_distance(reference, hypothesis) + return float(edit_distance), len(reference) + + +def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float: + """Calculate word error rate (WER). WER compares reference text and + hypothesis text in word-level. + WER is defined as: + WER = (Sw + Dw + Iw) / Nw + with: + Sw is the number of words subsituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + delimiter: Delimiter of input sentences. 
+ + Returns: + Word error rate (float) + + Usage: + >>> wer("hello world", "hello") + 0.5 + """ + edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) + + if ref_len == 0: + raise ValueError("Reference's word number should be greater than 0.") + + word_error_rate = float(edit_distance) / ref_len + return word_error_rate + + +def cer(reference, hypothesis, ignore_case=False, remove_space=False): + """Calculate charactor error rate (CER). CER compares reference text and + hypothesis text in char-level. CER is defined as: + CER = (Sc + Dc + Ic) / Nc + with + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted + Nc is the number of characters in the reference + + Args: + reference: The reference sentence. + hypothesis: The hypothesis sentence. + ignore_case: Whether case-sensitive or not. + remove_space: Whether remove internal space characters + + Returns: + Character error rate (float) + + Usage: + >>> cer("hello world", "hello") + 0.2727272727272727 + """ + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + char_error_rate = float(edit_distance) / ref_len + return char_error_rate -- cgit v1.2.3 From 8be140b38183b7465b5888a15b536a5f7fa66db6 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 20:45:32 +0200 Subject: added tokenizer to git and tokenizer training routing --- .gitignore | 11 +-- data/tokenizers/char_tokenizer_german.json | 38 ++++++++++ swr2_asr/utils/tokenizer.py | 110 ++++++++++++++++------------- 3 files changed, 101 insertions(+), 58 deletions(-) create mode 100644 data/tokenizers/char_tokenizer_german.json (limited to 'swr2_asr') diff --git a/.gitignore b/.gitignore index 8e64e4b..d21ddb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Training files -data/ +data/* +!data/tokenizers + # Mac **/.DS_Store @@ -163,10 +165,3 @@ dmypy.json # Cython debug symbols cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-#.idea/ diff --git a/data/tokenizers/char_tokenizer_german.json b/data/tokenizers/char_tokenizer_german.json new file mode 100644 index 0000000..20db079 --- /dev/null +++ b/data/tokenizers/char_tokenizer_german.json @@ -0,0 +1,38 @@ +_ 0 + 1 + 2 + 3 +a 4 +b 5 +c 6 +d 7 +e 8 +f 9 +g 10 +h 11 +i 12 +j 13 +k 14 +l 15 +m 16 +n 17 +o 18 +p 19 +q 20 +r 21 +s 22 +t 23 +u 24 +v 25 +w 26 +x 27 +y 28 +z 29 +é 30 +à 31 +ä 32 +ö 33 +ß 34 +ü 35 +- 36 +' 37 diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index d92465a..5482bbe 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,59 +1,18 @@ """Tokenizer for Multilingual Librispeech datasets""" +from datetime import datetime +import os + +from tqdm.autonotebook import tqdm + + class CharTokenizer: """Maps characters to integers and vice versa""" def __init__(self): - char_map_str = """ - _ - - - - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z - é - à - ä - ö - ß - ü - - - ' - - """ - self.char_map = {} self.index_map = {} - for idx, char in enumerate(char_map_str.strip().split("\n")): - char = char.strip() - self.char_map[char] = idx - self.index_map[idx] = char - self.index_map[1] = " " def encode(self, text: str) -> list[int]: """Use a character map and convert text to an integer sequence""" @@ -91,7 +50,59 @@ class CharTokenizer: """Get the integer representation of the character""" return self.char_map[""] - # TODO: add train function + @staticmethod + def train(dataset_path: str, language: str) -> "CharTokenizer": + """Train the tokenizer on a dataset""" + chars = set() + root_path = os.path.join(dataset_path, language) + for split in os.listdir(root_path): + split_dir = os.path.join(root_path, split) + if os.path.isdir(split_dir): + transcript_path = os.path.join(split_dir, "transcripts.txt") + + with open(transcript_path, "r", encoding="utf-8") as transcrips: + lines = transcrips.readlines() + lines = [line.split(" ", 1)[1] for line in lines] + lines = [line.strip() for line in lines] + lines = [line.lower() for line in lines] + + for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"): + chars.update(line) + + # sort chars + chars.remove(" ") + chars = sorted(chars) + + train_tokenizer = CharTokenizer() + + train_tokenizer.char_map["_"] = 0 + train_tokenizer.char_map[""] = 1 + train_tokenizer.char_map[""] = 2 + train_tokenizer.char_map[""] = 3 + + train_tokenizer.index_map[0] = "_" + train_tokenizer.index_map[1] = "" + train_tokenizer.index_map[2] = "" + train_tokenizer.index_map[3] = "" + + offset = 4 + + for idx, char in enumerate(chars): + idx += offset + train_tokenizer.char_map[char] = idx + train_tokenizer.index_map[idx] = char + + train_tokenizer_dir = os.path.join("data/tokenizers") + train_tokenizer_path = os.path.join( + train_tokenizer_dir, + f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json", + ) + + if not os.path.exists(os.path.dirname(train_tokenizer_dir)): + os.makedirs(train_tokenizer_dir) + train_tokenizer.save(train_tokenizer_path) + + return train_tokenizer def save(self, path: str) -> None: """Save the tokenizer to a file""" @@ -114,8 +125,7 @@ class CharTokenizer: if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.save("data/tokenizers/char_tokenizer_german.json") + tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") print(tokenizer.char_map) print(tokenizer.index_map) print(tokenizer.get_vocab_size()) -- cgit v1.2.3 From 
58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 21:52:42 +0200 Subject: changed config to yaml! --- config.philipp.yaml | 29 ++++++ config.train.yaml | 28 ++++++ poetry.lock | 51 ++++++++++- pyproject.toml | 1 + requirements.txt | 1 + swr2_asr/__main__.py | 12 --- swr2_asr/inference.py | 16 ++-- swr2_asr/model_deep_speech.py | 17 ---- swr2_asr/train.py | 192 +++++++++++++++++++--------------------- swr2_asr/utils/data.py | 7 +- swr2_asr/utils/tokenizer.py | 8 +- swr2_asr/utils/visualization.py | 8 +- 12 files changed, 218 insertions(+), 152 deletions(-) create mode 100644 config.philipp.yaml create mode 100644 config.train.yaml delete mode 100644 swr2_asr/__main__.py (limited to 'swr2_asr') diff --git a/config.philipp.yaml b/config.philipp.yaml new file mode 100644 index 0000000..638b5ef --- /dev/null +++ b/config.philipp.yaml @@ -0,0 +1,29 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 1 # evaluate every n epochs + num_workers: 4 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: True # set to True if you want to use limited supervision + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: ~ # path to load model from + model_save_path: ~ # path to save model to \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml new file mode 100644 index 0000000..c82439d --- /dev/null +++ b/config.train.yaml @@ -0,0 +1,28 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 3901b8c..a1f916b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1083,6 +1083,55 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = 
"PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = 
"PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "ruff" version = "0.0.285" @@ -1482,4 +1531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5" +content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf" diff --git a/pyproject.toml b/pyproject.toml index 38cc51a..f6d19dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" matplotlib = "^3.7.2" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/requirements.txt b/requirements.txt index 3b39b56..040fed0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ platformdirs==3.10.0 pylint==2.17.5 pyparsing==3.0.9 python-dateutil==2.8.2 +PyYAML==6.0.1 ruff==0.0.285 six==1.16.0 sympy==1.12 diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py deleted file mode 100644 index be294fb..0000000 --- a/swr2_asr/__main__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Main entrypoint for swr2-asr.""" -import torch -import torchaudio - -if __name__ == "__main__": - # test if GPU is available - print("GPU available: ", torch.cuda.is_available()) - - # test if torchaudio is 
installed correctly - print("torchaudio version: ", torchaudio.__version__) - print("torchaudio backend: ", torchaudio.get_audio_backend()) - print("torchaudio info: ", torchaudio.get_audio_backend()) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..f8342f7 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" +from typing import TypedDict + import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer class HParams(TypedDict): @@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member blank_label = tokenizer.encode(" ").ids[0] decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for _i, args in enumerate(arg_maxes): decode = [] for j, index in enumerate(args): if index != blank_label: @@ -44,7 +44,7 @@ def main() -> None: """inference function.""" device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) + device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") @@ -90,7 +90,7 @@ def main() -> None: model.load_state_dict(state_dict) # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") + waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member if sample_rate != spectrogram_hparams["sample_rate"]: resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) @@ -103,7 +103,7 @@ def main() -> None: specs = [spec] specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) - output = model(specs) + output = model(specs) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) decodes = greedy_decoder(output, tokenizer) diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index 8ddbd99..77f4c8a 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -3,27 +3,10 @@ Following definition by Assembly AI (https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) """ -from typing import TypedDict - import torch.nn.functional as F from torch import nn -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ac7666b..eb79ee2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,11 +5,12 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split from swr2_asr.utils.decoder import greedy_decoder from 
swr2_asr.utils.tokenizer import CharTokenizer @@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer from .utils.loss_scores import cer, wer -class IterMeter(object): +class IterMeter: """keeps track of total iterations""" def __init__(self): @@ -116,6 +117,7 @@ class TestArgs(TypedDict): def test(test_args: TestArgs) -> tuple[float, float, float]: + """Test""" print("\nevaluating...") # get values from test_args: @@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + for _data in tqdm(test_loader, desc="Validation Batches"): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) @@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths, tokenizer ) - if i == 1: - print(f"decoding first sample: {decoded_preds}") for j, _ in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], decoded_preds[j])) test_wer.append(wer(decoded_targets[j], decoded_preds[j])) @@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: return test_loss, avg_cer, avg_wer -def main( - learning_rate: float, - batch_size: int, - epochs: int, - dataset_path: str, - language: str, - limited_supervision: bool, - model_load_path: str, - model_save_path: str, - dataset_percentage: float, - eval_every: int, - num_workers: int, -): +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +def main(config_path: str): """Main function for training the model. - Args: - learning_rate: learning rate for the optimizer - batch_size: batch size - epochs: number of epochs to train - dataset_path: path for the dataset - language: language of the dataset - limited_supervision: whether to use only limited supervision - model_load_path: path to load a model from - model_save_path: path to save the model to - dataset_percentage: percentage of the dataset to use - eval_every: evaluate every n epochs - num_workers: number of workers for the dataloader + Gets all configuration arguments from yaml config file. 
""" use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = 
optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) prev_epoch = 0 - if model_load_path is not None: - checkpoint = torch.load(model_load_path) + if checkpoints_config["model_load_path"] is not None: + checkpoint = torch.load(checkpoints_config["model_load_path"]) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) - for epoch in range(prev_epoch + 1, epochs + 1): - train_args: TrainArgs = dict( - model=model, - device=device, - train_loader=train_loader, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - iter_meter=iter_meter, - ) + + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = dict( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - decoder="greedy", - ) + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } - if epoch % eval_every == 0: + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: test_loss, test_cer, test_wer = test(test_args) - if model_save_path is None: + if checkpoints_config["model_save_path"] is None: continue - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, @@ -322,7 +314,7 @@ def main( "avg_cer": test_cer, "avg_wer": test_wer, }, - model_save_path + str(epoch), + checkpoints_config["model_save_path"] + str(epoch), ) diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index e939e1d..0e06eec 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -1,13 +1,12 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from typing import TypedDict import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -125,7 +124,7 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): + if limited and split in (Split.TRAIN, Split.VALID): self.initialize_limited() else: self.initialize() @@ -351,8 +350,6 @@ class MLSDataset(Dataset): if __name__ == "__main__": - from torch.utils.data import DataLoader - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" split = Split.DEV diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 5482bbe..22569eb 
100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
 """Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
 import os
+from datetime import datetime

 from tqdm.autonotebook import tqdm

@@ -119,8 +117,8 @@ class CharTokenizer:
                 line = line.strip()
                 if line:
                     char, index = line.split()
-                    tokenizer.char_map[char] = int(index)
-                    tokenizer.index_map[int(index)] = char
+                    load_tokenizer.char_map[char] = int(index)
+                    load_tokenizer.index_map[int(index)] = char

         return load_tokenizer

diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch

 def plot(epochs, path):
     """Plots the losses over the epochs"""
-    losses = list()
-    test_losses = list()
-    cers = list()
-    wers = list()
+    losses = []
+    test_losses = []
+    cers = []
+    wers = []
     for epoch in range(1, epochs + 1):
         current_state = torch.load(path + str(epoch))
         losses.append(current_state["loss"])
--
cgit v1.2.3


From 64dbb9d32a51b1bce6c9de67069dc8f5943a5399 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Mon, 11 Sep 2023 22:16:26 +0200
Subject: added n_feats from config

---
 config.philipp.yaml    | 2 +-
 swr2_asr/train.py      | 4 ++--
 swr2_asr/utils/data.py | 7 ++++---
 3 files changed, 7 insertions(+), 6 deletions(-)

(limited to 'swr2_asr')

diff --git a/config.philipp.yaml b/config.philipp.yaml
index 638b5ef..6b905cd 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -4,7 +4,7 @@ model:
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets

 training:
   learning_rate: 0.0005
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index eb79ee2..ca70d21 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -223,8 +223,8 @@ def main(config_path: str):
     )
     tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])

-    train_data_processing = DataProcessing("train", tokenizer)
-    valid_data_processing = DataProcessing("valid", tokenizer)
+    train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+    valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})

     train_loader = DataLoader(
         dataset=train_dataset,
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 0e06eec..10f0ea8 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -15,18 +15,19 @@ from swr2_asr.utils.tokenizer import CharTokenizer
 class DataProcessing:
     """Data processing class for the dataloader"""

-    def __init__(self, data_type: str, tokenizer: CharTokenizer):
+    def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict):
         self.data_type = data_type
         self.tokenizer = tokenizer
+        n_features = hparams["n_feats"]

         if data_type == "train":
             self.audio_transform = torch.nn.Sequential(
-                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features),
                 torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
                 torchaudio.transforms.TimeMasking(time_mask_param=100),
             )
         elif data_type == "valid":
-            self.audio_transform = torchaudio.transforms.MelSpectrogram()
+            self.audio_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_features)

     def __call__(self, data) ->
tuple[Tensor, Tensor, list, list]: spectrograms = [] -- cgit v1.2.3 From 6f5513140f153206cfa91df3077e67ce58043d35 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 22:58:19 +0200 Subject: model loading is broken :( --- config.philipp.yaml | 9 +++- config.train.yaml | 28 ---------- config.yaml | 34 ++++++++++++ swr2_asr/inference.py | 140 ++++++++++++++++++++++---------------------------- swr2_asr/train.py | 2 +- 5 files changed, 103 insertions(+), 110 deletions(-) delete mode 100644 config.train.yaml create mode 100644 config.yaml (limited to 'swr2_asr') diff --git a/config.philipp.yaml b/config.philipp.yaml index 6b905cd..4a723c6 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -12,6 +12,7 @@ training: epochs: 3 eval_every_n: 1 # evaluate every n epochs num_workers: 4 # number of workers for dataloader + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically dataset: download: True @@ -25,5 +26,9 @@ tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" checkpoints: - model_load_path: ~ # path to load model from - model_save_path: ~ # path to save model to \ No newline at end of file + model_load_path: "data/runs/epoch30" # path to load model from + model_save_path: ~ # path to save model to + +inference: + model_load_path: "data/runs/epoch30" # path to load model from + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml deleted file mode 100644 index c82439d..0000000 --- a/config.train.yaml +++ /dev/null @@ -1,28 +0,0 @@ -model: - n_cnn_layers: 3 - n_rnn_layers: 5 - rnn_dim: 512 - n_feats: 128 # number of mel features - stride: 2 - dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 5e-4 - batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 3 # evaluate every n epochs - num_workers: 8 # number of workers for dataloader - -dataset: - download: True - dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" - -checkpoints: - model_load_path: "YOUR/PATH" # path to load model from - model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..e5ff43a --- /dev/null +++ b/config.yaml @@ -0,0 +1,34 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + shuffle: True + 
+tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to + +inference: + model_load_path: "YOUR/PATH" # path to load model from + beam_width: 10 # beam width for beam search + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index f8342f7..6495a9a 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,35 +1,20 @@ """Training script for the ASR model.""" -from typing import TypedDict - +import click import torch import torch.nn.functional as F import torchaudio +import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.tokenizer import CharTokenizer -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, collapse_repeated=True): +def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] + blank_label = tokenizer.get_blank_token() decodes = [] - for _i, args in enumerate(arg_maxes): + for args in arg_maxes: decode = [] for j, index in enumerate(args): if index != blank_label: @@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): return decodes -def main() -> None: +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +@click.option( + "--file_path", + help="Path to audio file", + type=click.Path(exists=True), +) +def main(config_path: str, file_path: str) -> None: """inference function.""" - - device = "cuda" if torch.cuda.is_available() else "cpu" + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + inference_config = config_dict.get("inference", {}) + + if inference_config["device"] == "cpu": + device = "cpu" + elif inference_config["device"] == "cuda": + device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # pylint: disable=no-member - tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") - - spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=0.1, - batch_size=30, - epochs=100, - ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + 
model_config["stride"], + model_config["dropout"], ).to(device) - checkpoint = torch.load("model8", map_location=device) - state_dict = { - k[len("module.") :] if k.startswith("module.") else k: v - for k, v in checkpoint["model_state_dict"].items() - } - model.load_state_dict(state_dict) - - # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member - if sample_rate != spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) + checkpoint = torch.load(inference_config["model_load_path"], map_location=device) + print(checkpoint["model_state_dict"].keys()) + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + model.eval() + waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member + if waveform.shape[0] != 1: + waveform = waveform[1] + waveform = waveform.unsqueeze(0) + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) + sample_rate = 16000 + + data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"]) + + spec = data_processing(waveform).squeeze(0).transpose(0, 1) - spec = ( - torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - specs = [spec] - specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) + spec = spec.unsqueeze(0) + spec = spec.transpose(1, 2) + spec = spec.unsqueeze(0) + output = model(spec) # pylint: disable=not-callable + output = F.log_softmax(output, dim=2) # (batch, time, n_class) + decoded_preds = greedy_decoder(output, tokenizer) - output = model(specs) # pylint: disable=not-callable - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - decodes = greedy_decoder(output, tokenizer) - print(decodes) + print(decoded_preds) if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ca70d21..ec25918 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -263,7 +263,7 @@ def main(config_path: str): prev_epoch = 0 if checkpoints_config["model_load_path"] is not None: - checkpoint = torch.load(checkpoints_config["model_load_path"]) + checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] -- cgit v1.2.3 From 96fee5f59f67187292ddf37db4660c5085fb66b5 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 23:08:45 +0200 Subject: changed name to match pre-trained weights --- swr2_asr/inference.py | 4 +-- swr2_asr/model_deep_speech.py | 68 +++++++++++++++---------------------------- 2 files changed, 25 insertions(+), 47 deletions(-) (limited to 'swr2_asr') diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 6495a9a..3f6a44e 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -66,9 +66,9 @@ def main(config_path: str, file_path: str) -> None: ).to(device) checkpoint = torch.load(inference_config["model_load_path"], map_location=device) - print(checkpoint["model_state_dict"].keys()) - model.load_state_dict(checkpoint["model_state_dict"], strict=False) + model.load_state_dict(checkpoint["model_state_dict"], strict=True) model.eval() + waveform, sample_rate = 
torchaudio.load(file_path) # pylint: disable=no-member if waveform.shape[0] != 1: waveform = waveform[1] diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index 77f4c8a..73f5a81 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -10,8 +10,8 @@ from torch import nn class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" - def __init__(self, n_feats: int): - super().__init__() + def __init__(self, n_feats): + super(CNNLayerNorm, self).__init__() self.layer_norm = nn.LayerNorm(n_feats) def forward(self, data): @@ -22,34 +22,22 @@ class CNNLayerNorm(nn.Module): class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf + except with layer norm instead of batch norm + """ - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() + def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats): + super(ResidualCNN, self).__init__() self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) + self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.layer_norm1 = CNNLayerNorm(n_feats) self.layer_norm2 = CNNLayerNorm(n_feats) def forward(self, data): - """x (batch, channel, feature, time)""" + """data (batch, channel, feature, time)""" residual = data # (batch, channel, feature, time) data = self.layer_norm1(data) data = F.gelu(data) @@ -64,18 +52,12 @@ class ResidualCNN(nn.Module): class BidirectionalGRU(nn.Module): - """Bidirectional GRU with Layer Normalization and Dropout""" + """Bidirectional GRU layer""" - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() + def __init__(self, rnn_dim, hidden_size, dropout, batch_first): + super(BidirectionalGRU, self).__init__() - self.bi_gru = nn.GRU( + self.BiGRU = nn.GRU( # pylint: disable=invalid-name input_size=rnn_dim, hidden_size=hidden_size, num_layers=1, @@ -86,11 +68,11 @@ class BidirectionalGRU(nn.Module): self.dropout = nn.Dropout(dropout) def forward(self, data): - """data (batch, time, feature)""" + """x (batch, time, feature)""" data = self.layer_norm(data) data = F.gelu(data) + data, _ = self.BiGRU(data) data = self.dropout(data) - data, _ = self.bi_gru(data) return data @@ -98,18 +80,14 @@ class SpeechRecognitionModel(nn.Module): """Speech Recognition Model Inspired by DeepSpeech 2""" def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, + self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1 ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + super(SpeechRecognitionModel, self).__init__() + n_feats = n_feats // 2 + self.cnn = nn.Conv2d( + 1, 32, 3, stride=stride, padding=3 // 2 + ) # cnn for extracting heirachal features + # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ @@ -137,7 +115,7 @@ class SpeechRecognitionModel(nn.Module): ) def forward(self, data): - """data (batch, channel, feature, time)""" + """x (batch, channel, 
feature, time)""" data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() -- cgit v1.2.3 From 4aff1fcd70cd8601541a1dd5bd820b0263ed1362 Mon Sep 17 00:00:00 2001 From: Philipp Merkel Date: Mon, 11 Sep 2023 22:36:28 +0000 Subject: fix: switched up training and test splits in train.py --- config.philipp.yaml | 22 +++++++++++----------- swr2_asr/train.py | 8 +++----- swr2_asr/utils/data.py | 31 ------------------------------- swr2_asr/utils/tokenizer.py | 12 ------------ 4 files changed, 14 insertions(+), 59 deletions(-) (limited to 'swr2_asr') diff --git a/config.philipp.yaml b/config.philipp.yaml index 4a723c6..f72ce2e 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -4,30 +4,30 @@ model: rnn_dim: 512 n_feats: 128 # number of mel features stride: 2 - dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets + dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets training: learning_rate: 0.0005 - batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 1 # evaluate every n epochs + batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 150 + eval_every_n: 5 # evaluate every n epochs num_workers: 4 # number of workers for dataloader device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically dataset: - download: True - dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + download: true + dataset_root_path: "data" # files will be downloaded into this dir language_name: "mls_german_opus" - limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) - shuffle: True + limited_supervision: false # set to True if you want to use limited supervision + dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%) + shuffle: true tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" checkpoints: - model_load_path: "data/runs/epoch30" # path to load model from - model_save_path: ~ # path to save model to + model_load_path: "data/runs/epoch31" # path to load model from + model_save_path: "data/runs/epoch" # path to save model to inference: model_load_path: "data/runs/epoch30" # path to load model from diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ec25918..3ed3ac8 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,16 +187,14 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - - print(training_config["learning_rate"]) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TEST, + Split.TRAIN, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], @@ -204,7 +202,7 @@ def main(config_path: str): valid_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TRAIN, + Split.TEST, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], diff --git a/swr2_asr/utils/data.py 
b/swr2_asr/utils/data.py index 10f0ea8..d551c98 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -134,11 +134,6 @@ class MLSDataset(Dataset): def initialize_limited(self) -> None: """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - handles = set() train_root_path = os.path.join(self.dataset_path, self.language, "train") @@ -348,29 +343,3 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.DEV - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) - - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=True, - collate_fn=DataProcessing( - "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - ), - ) - - for batch in dataloader: - print(batch) - break - - print(len(dataset)) - - print(dataset[0]) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 22569eb..1cc7b84 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,15 +120,3 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) -- cgit v1.2.3 From 7b71dab87591e04d874cd636614450b0e65e3f2b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Tue, 12 Sep 2023 14:14:19 +0200 Subject: fixed black formatting issue --- swr2_asr/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'swr2_asr') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 3ed3ac8..ffdae73 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,7 +187,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) -- cgit v1.2.3
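A few implementation notes on the patches above, added for readers following along; the code sketches are illustrative assumptions, not part of the repository.

First, the n_feats change and the inference rewrite both build the Mel spectrogram from torchaudio defaults plus n_mels, matching the validation transform in DataProcessing, so feature extraction stays consistent across training, validation and inference. A hypothetical shared factory (make_mel_transform is an assumed name, not defined in swr2_asr) would make that coupling explicit:

import torch
import torchaudio


def make_mel_transform(n_feats: int, augment: bool = False) -> torch.nn.Module:
    """Build the Mel spectrogram used everywhere; optionally add the
    SpecAugment-style masking that DataProcessing applies for training."""
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_feats)
    if not augment:
        return mel
    return torch.nn.Sequential(
        mel,
        torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
        torchaudio.transforms.TimeMasking(time_mask_param=100),
    )

# training would use make_mel_transform(model_config["n_feats"], augment=True),
# validation and inference the plain variant with the same n_feats.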
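Second, the rewritten greedy_decoder in swr2_asr/inference.py implements standard greedy CTC decoding: take the per-frame argmax of the log-probabilities, collapse consecutive repeats, and drop the blank token obtained from tokenizer.get_blank_token(). A minimal standalone illustration (independent of CharTokenizer, with a made-up blank id and toy inputs) behaves as follows:

import torch


def greedy_ctc_decode(log_probs: torch.Tensor, blank: int) -> list[list[int]]:
    """log_probs has shape (batch, time, n_class); returns label index sequences."""
    arg_maxes = torch.argmax(log_probs, dim=2)
    decodes = []
    for frames in arg_maxes:
        decoded = []
        previous = None
        for index in frames.tolist():
            # collapse consecutive repeats, then drop CTC blanks
            if index != previous and index != blank:
                decoded.append(index)
            previous = index
        decodes.append(decoded)
    return decodes


# toy check: blank = 0, frame-wise argmaxes 0 0 1 1 0 2 2 2 0 1 -> [1, 2, 1]
toy = torch.zeros((1, 10, 3))
for t, idx in enumerate([0, 0, 1, 1, 0, 2, 2, 2, 0, 1]):
    toy[0, t, idx] = 1.0
assert greedy_ctc_decode(torch.log_softmax(toy, dim=2), blank=0) == [[1, 2, 1]]

The repository's version additionally maps the surviving indices back to characters via tokenizer.decode, which is the only difference from this sketch.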
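Finally, the "model loading is broken :(" and "changed name to match pre-trained weights" commits are two sides of the same issue: renaming a submodule attribute (self.BiGRU vs. self.bi_gru) changes the keys produced by state_dict(), so an older checkpoint no longer loads with strict=True. Renaming the attributes back, as done here, is the simplest fix; an alternative is to remap the checkpoint keys at load time. The helper below is hypothetical (rename_state_dict_keys is not part of swr2_asr) and assumes the checkpoint layout used in inference.py, i.e. weights stored under "model_state_dict":

from collections import OrderedDict

import torch


def rename_state_dict_keys(state_dict: dict, renames: dict) -> "OrderedDict[str, torch.Tensor]":
    """Return a copy of state_dict with old key substrings replaced and any
    leading "module." prefix (left over from nn.DataParallel training) stripped."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        if new_key.startswith("module."):
            new_key = new_key[len("module."):]
        for old, new in renames.items():
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped


# usage sketch (path and mapping are assumptions, not repository defaults):
# checkpoint = torch.load("data/runs/epoch30", map_location="cpu")
# model.load_state_dict(
#     rename_state_dict_keys(checkpoint["model_state_dict"], {".BiGRU.": ".bi_gru."}),
#     strict=True,
# )
# i.e. mapping the checkpoint's key names onto whatever the current modules are called.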