From 2fd748fd6339573f1ce5b961a563197a3eacf284 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 20 Aug 2023 13:15:11 +0200
Subject: adjusted in train.py

---
 swr2_asr/train.py | 103 ++++++------------------------------------------------
 1 file changed, 11 insertions(+), 92 deletions(-)

(limited to 'swr2_asr/train.py')

diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 29f9372..2628028 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,93 +1,14 @@
 """Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
 import click
 import torch
-import torch.nn as nn
-import torch.optim as optim
 import torch.nn.functional as F
-from torch.utils.data import DataLoader
 import torchaudio
-from .loss_scores import cer, wer
-
-
-class TextTransform:
-    """Maps characters to integers and vice versa"""
-
-    def __init__(self):
-        char_map_str = """
-        ' 0
-        <SPACE> 1
-        a 2
-        b 3
-        c 4
-        d 5
-        e 6
-        f 7
-        g 8
-        h 9
-        i 10
-        j 11
-        k 12
-        l 13
-        m 14
-        n 15
-        o 16
-        p 17
-        q 18
-        r 19
-        s 20
-        t 21
-        u 22
-        v 23
-        w 24
-        x 25
-        y 26
-        z 27
-        ä 28
-        ö 29
-        ü 30
-        ß 31
-        - 32
-        é 33
-        è 34
-        à 35
-        ù 36
-        ç 37
-        â 38
-        ê 39
-        î 40
-        ô 41
-        û 42
-        ë 43
-        ï 44
-        ü 45
-        """
-        self.char_map = {}
-        self.index_map = {}
-        for line in char_map_str.strip().split("\n"):
-            char, index = line.split()
-            self.char_map[char] = int(index)
-            self.index_map[int(index)] = char
-        self.index_map[1] = " "
-
-    def text_to_int(self, text):
-        """Use a character map and convert text to an integer sequence"""
-        int_sequence = []
-        for char in text:
-            if char == " ":
-                mapped_char = self.char_map["<SPACE>"]
-            else:
-                mapped_char = self.char_map[char]
-            int_sequence.append(mapped_char)
-        return int_sequence
-
-    def int_to_text(self, labels):
-        """Use a character map and convert integer labels to an text sequence"""
-        string = []
-        for i in labels:
-            string.append(self.index_map[i])
-        return "".join(string).replace("<SPACE>", " ")
+from AudioLoader.speech import MultilingualLibriSpeech
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tokenizers import Tokenizer
+
+from .loss_scores import cer, wer

 train_audio_transforms = nn.Sequential(
     torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -97,7 +18,7 @@ train_audio_transforms = nn.Sequential(

 valid_audio_transforms = torchaudio.transforms.MelSpectrogram()

-text_transform = TextTransform()
+text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")


 def data_processing(data, data_type="train"):
@@ -114,7 +35,7 @@ def data_processing(data, data_type="train"):
         else:
             raise ValueError("data_type should be train or valid")
         spectrograms.append(spec)
-        label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
+        label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
         labels.append(label)
         input_lengths.append(spec.shape[0] // 2)
         label_lengths.append(len(label))
@@ -138,15 +59,13 @@ def greedy_decoder(
     targets = []
     for i, args in enumerate(arg_maxes):
         decode = []
-        targets.append(
-            text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
-        )
+        targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist()))
         for j, index in enumerate(args):
             if index != blank_label:
                 if collapse_repeated and j != 0 and index == args[j - 1]:
                     continue
                 decode.append(index.item())
-        decodes.append(text_transform.int_to_text(decode))
+        decodes.append(text_transform.decode(decode))
     return decodes, targets


@@ -407,10 +326,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
     # device = torch.device("mps")

     train_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+        "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False
     )
     test_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+        "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False
     )

     kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
--
cgit v1.2.3
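
The patch above assumes that data/tokenizers/bpe_tokenizer_german_3000.json already exists on disk. As a rough sketch of how such a file could be produced with the Hugging Face tokenizers library (the transcript path, pre-tokenizer, and trainer settings below are illustrative assumptions, not values taken from this repository; swr2_asr/tokenizer.py provides its own train_bpe_tokenizer for the real recipe):

    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    # Whitespace pre-tokenization keeps the sketch simple; the project's recipe may differ.
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    trainer = trainers.BpeTrainer(vocab_size=3000, special_tokens=["[UNK]"])

    # "transcripts.txt" is a placeholder for whatever transcript files are used for training.
    tokenizer.train(files=["transcripts.txt"], trainer=trainer)
    tokenizer.save("data/tokenizers/bpe_tokenizer_german_3000.json")

    # Loading and using the saved file then matches the calls introduced above.
    tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
    ids = tokenizer.encode("ein beispiel").ids
    print(tokenizer.decode(ids))
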
From 3ae21cbc432113531aa15e0cebd8a34c3767ba35 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 20 Aug 2023 14:52:15 +0200
Subject: added todos

---
 swr2_asr/tokenizer.py | 6 +++++-
 swr2_asr/train.py     | 5 ++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

(limited to 'swr2_asr/train.py')

diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d9cd622..79d6727 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -26,7 +26,7 @@ class CharTokenizer:
     Simply checks what characters are in the dataset and uses them as tokens.

     Exposes the same interface as tokenizers from the huggingface library, i.e.
-    encode, decode, decode_batch, save, from_file and train.
+    encode, decode, decode_batch, get_vocab_size, save, from_file and train.
     """

     def __init__(self):
@@ -140,6 +140,10 @@ class CharTokenizer:
             strings.append("".join(string).replace("<SPACE>", " "))
         return strings

+    def get_vocab_size(self):
+        """Get the size of the vocabulary"""
+        return len(self.char_map)
+
     def save(self, path: str):
         """Save the tokenizer to a file"""
         with open(path, "w", encoding="utf-8") as file:
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index d13683f..8943f71 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -57,6 +57,7 @@ def data_processing(data, data_type="train"):
 def greedy_decoder(
     output, labels, label_lengths, blank_label=28, collapse_repeated=True
 ):
+    # TODO: adopt to support both tokenizers
     """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
     decodes = []
@@ -77,6 +78,7 @@ def greedy_decoder(
     return decodes, targets


+# TODO: restructure into own file / class
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""

@@ -280,6 +282,7 @@ def train(
     return loss.item()


+# TODO: check how dataloader can be made more efficient
 def test(model, device, test_loader, criterion):
     """Test"""
     print("\nevaluating...")
@@ -327,7 +330,7 @@ def run(
         "n_cnn_layers": 3,
         "n_rnn_layers": 5,
         "rnn_dim": 512,
-        "n_class": 36,
+        "n_class": 36,  # TODO: dynamically determine this from vocab size
         "n_feats": 128,
         "stride": 2,
         "dropout": 0.1,
--
cgit v1.2.3
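
The TODO on "n_class" above could be resolved with the get_vocab_size method this commit adds, since both the CharTokenizer and the Hugging Face Tokenizer expose that call. A minimal sketch (only the hparams keys visible in this patch are shown; the tokenizer path matches the one loaded in train.py):

    from tokenizers import Tokenizer

    text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")

    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        # Vocabulary size taken from the tokenizer instead of the hard-coded 36.
        "n_class": text_transform.get_vocab_size(),
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
    }
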
From 33744072d8c0906950cdc9cd00fc1f345a51d9d4 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 20 Aug 2023 15:23:02 +0200
Subject: please the linters

---
 .github/workflows/format.yml | 38 +++++++++++++++++++++-----------------
 Makefile                     |  3 +++
 mypy.ini                     |  6 ++++++
 poetry.lock                  | 25 +++++++++++++++++++++----
 pyproject.toml               |  1 +
 swr2_asr/loss_scores.py      | 36 +++++++++++++++++++-----------------
 swr2_asr/tokenizer.py        | 18 ++++++++++--------
 swr2_asr/train.py            | 12 ++++++------
 8 files changed, 87 insertions(+), 52 deletions(-)

(limited to 'swr2_asr/train.py')

diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 4a5a509..8411d60 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -6,20 +6,24 @@ jobs:
   build:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@master
-    - name: Set up Python
-      uses: actions/setup-python@v3
-      with:
-        python-version: "3.10"
-    - name: Install dependencies
-      run: |
-        python -m pip install -U pip poetry
-        poetry --version
-        poetry check --no-interaction
-        poetry config virtualenvs.in-project true
-        poetry install --no-interaction
-    - name: Run CI
-      run: |
-        make lint
-
-
+      - uses: actions/checkout@master
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install -U pip poetry
+          poetry --version
+          poetry check --no-interaction
+          poetry config virtualenvs.in-project true
+          poetry install --no-interaction
+      - name: Check for format issues
+        run: |
+          make format-check
+      - name: Run pylint
+        run: |
+          poetry run pylint swr2_asr
+      - name: Run mypy
+        run: |
+          poetry run mypy --strict swr2_asr
diff --git a/Makefile b/Makefile
index 4f0ea9c..a37644c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,9 @@
 format:
 	@poetry run black .

+format-check:
+	@poetry run black --check .
+
 lint:
 	@poetry run mypy --strict swr2_asr
 	@poetry run pylint swr2_asr
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index f7cfc59..c13aa05 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -9,3 +9,9 @@ ignore_missing_imports = true

 [mypy-click.*]
 ignore_missing_imports = true
+
+[mypy-tokenizers.*]
+ignore_missing_imports = true
+
+[mypy-tqmd.*]
+ignore_missing_imports = true
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 49d37d1..1f3609a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -14,7 +14,10 @@ files = [
 [package.dependencies]
 lazy-object-proxy = ">=1.4.0"
 typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
-wrapt = {version = ">=1.11,<2", markers = "python_version < \"3.11\""}
+wrapt = [
+    {version = ">=1.11,<2", markers = "python_version < \"3.11\""},
+    {version = ">=1.14,<2", markers = "python_version >= \"3.11\""},
+]

 [[package]]
 name = "AudioLoader"
@@ -680,7 +683,10 @@ files = [
 [package.dependencies]
 astroid = ">=2.15.6,<=2.17.0-dev0"
 colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
-dill = {version = ">=0.2", markers = "python_version < \"3.11\""}
+dill = [
+    {version = ">=0.2", markers = "python_version < \"3.11\""},
+    {version = ">=0.3.6", markers = "python_version >= \"3.11\""},
+]
 isort = ">=4.2.5,<6"
 mccabe = ">=0.6,<0.8"
 platformdirs = ">=2.2.0"
@@ -967,6 +973,17 @@ torch = "*"
 tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
 tutorials = ["matplotlib", "pandas", "tabulate"]

+[[package]]
+name = "types-tqdm"
+version = "4.66.0.1"
+description = "Typing stubs for tqdm"
+optional = false
+python-versions = "*"
+files = [
+    {file = "types-tqdm-4.66.0.1.tar.gz", hash = "sha256:6457c90f03cc5a0fe8dd11839c8cbf5572bf542b438b1af74233801728b5dfbc"},
+    {file = "types_tqdm-4.66.0.1-py3-none-any.whl", hash = "sha256:6a1516788cbb33d725803439b79c25bfed7e8176b8d782020b5c24aedac1649b"},
+]
+
 [[package]]
 name = "typing-extensions"
 version = "4.7.1"
@@ -1078,5 +1095,5 @@ files = [

 [metadata]
 lock-version = "2.0"
-python-versions = "~3.10"
-content-hash = "a72b4e5791a6216b58b53a72bf68d97dbdbc95978b3974fddd9e5f9b76e36321"
+python-versions = "^3.10"
+content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2"
diff --git a/pyproject.toml b/pyproject.toml
index eb17479..fabe364 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ black = "^23.7.0"
 mypy = "^1.5.1"
 pylint = "^2.17.5"
 ruff = "^0.0.285"
+types-tqdm = "^4.66.0.1"

 [tool.poetry.scripts]
 train = "swr2_asr.train:run_cli"
"swr2_asr.train:run_cli" diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py index 977462d..c49cc15 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/loss_scores.py @@ -1,7 +1,9 @@ +"""Methods for determining the loss and scores of the model.""" import numpy as np def avg_wer(wer_scores, combined_ref_len): + """Calculate the average word error rate (WER) of the model.""" return float(sum(wer_scores)) / float(combined_ref_len) @@ -13,34 +15,34 @@ def _levenshtein_distance(ref, hyp): extend the edits to word level when calculate levenshtein disctance for two sentences. """ - m = len(ref) - n = len(hyp) + len_ref = len(ref) + len_hyp = len(hyp) # special case if ref == hyp: return 0 - if m == 0: - return n - if n == 0: - return m + if len_ref == 0: + return len_hyp + if len_hyp == 0: + return len_ref - if m < n: + if len_ref < len_hyp: ref, hyp = hyp, ref - m, n = n, m + len_ref, len_hyp = len_hyp, len_ref # use O(min(m, n)) space - distance = np.zeros((2, n + 1), dtype=np.int32) + distance = np.zeros((2, len_hyp + 1), dtype=np.int32) # initialize distance matrix - for j in range(0, n + 1): + for j in range(0, len_hyp + 1): distance[0][j] = j # calculate levenshtein distance - for i in range(1, m + 1): + for i in range(1, len_ref + 1): prev_row_idx = (i - 1) % 2 cur_row_idx = i % 2 distance[cur_row_idx][0] = i - for j in range(1, n + 1): + for j in range(1, len_hyp + 1): if ref[i - 1] == hyp[j - 1]: distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] else: @@ -49,7 +51,7 @@ def _levenshtein_distance(ref, hyp): d_num = distance[prev_row_idx][j] + 1 distance[cur_row_idx][j] = min(s_num, i_num, d_num) - return distance[m % 2][n] + return distance[len_ref % 2][len_hyp] def word_errors( @@ -143,8 +145,8 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): if ref_len == 0: raise ValueError("Reference's word number should be greater than 0.") - wer = float(edit_distance) / ref_len - return wer + word_error_rate = float(edit_distance) / ref_len + return word_error_rate def cer(reference, hypothesis, ignore_case=False, remove_space=False): @@ -181,5 +183,5 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): if ref_len == 0: raise ValueError("Length of reference should be greater than 0.") - cer = float(edit_distance) / ref_len - return cer + char_error_rate = float(edit_distance) / ref_len + return char_error_rate diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 79d6727..a665159 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -63,15 +63,15 @@ class CharTokenizer: else: splits = [split] - chars = set() - for sp in splits: + chars: set = set() + for s_plit in splits: transcript_path = os.path.join( - dataset_path, language, sp, "transcripts.txt" + dataset_path, language, s_plit, "transcripts.txt" ) # check if dataset is downloaded, download if not if download and not os.path.exists(transcript_path): - MultilingualLibriSpeech(dataset_path, language, sp, download=True) + MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) with open( transcript_path, @@ -82,7 +82,7 @@ class CharTokenizer: lines = [line.split(" ", 1)[1] for line in lines] lines = [line.strip() for line in lines] - for line in tqdm(lines, desc=f"Training tokenizer on {sp} split"): + for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"): chars.update(line) offset = len(self.char_map) for i, char in enumerate(chars): @@ -205,10 +205,12 @@ def train_bpe_tokenizer( lines = [] - for sp in splits: - 
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8943f71..6af1e80 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -83,7 +83,7 @@ class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""

     def __init__(self, n_feats: int):
-        super(CNNLayerNorm, self).__init__()
+        super().__init__()
         self.layer_norm = nn.LayerNorm(n_feats)

     def forward(self, data):
@@ -105,7 +105,7 @@ class ResidualCNN(nn.Module):
         dropout: float,
         n_feats: int,
     ):
-        super(ResidualCNN, self).__init__()
+        super().__init__()

         self.cnn1 = nn.Conv2d(
             in_channels, out_channels, kernel, stride, padding=kernel // 2
@@ -147,7 +147,7 @@ class BidirectionalGRU(nn.Module):
         dropout: float,
         batch_first: bool,
     ):
-        super(BidirectionalGRU, self).__init__()
+        super().__init__()

         self.bi_gru = nn.GRU(
             input_size=rnn_dim,
@@ -181,7 +181,7 @@ class SpeechRecognitionModel(nn.Module):
         stride: int = 2,
         dropout: float = 0.1,
     ):
-        super(SpeechRecognitionModel, self).__init__()
+        super().__init__()
         n_feats //= 2
         self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
         # n residual cnn layers with filter size of 32
@@ -227,7 +227,7 @@ class SpeechRecognitionModel(nn.Module):
         return data


-class IterMeter(object):
+class IterMeter:
     """keeps track of total iterations"""

     def __init__(self):
@@ -381,7 +381,7 @@ def run(
     ).to(device)

     print(
-        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+        "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
     )
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
     criterion = nn.CTCLoss(blank=28).to(device)
--
cgit v1.2.3
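
For context on the criterion in the last hunk: nn.CTCLoss, as constructed above with blank=28, expects log-probabilities shaped (time, batch, classes) plus per-sample input and target lengths. The snippet below is a self-contained shape check with random tensors; the blank index 28 and n_class 36 mirror the values in train.py, everything else is illustrative:

    import torch
    from torch import nn

    criterion = nn.CTCLoss(blank=28)

    batch, time, n_class, target_len = 4, 50, 36, 12
    log_probs = torch.randn(batch, time, n_class).log_softmax(2)
    log_probs = log_probs.transpose(0, 1)  # CTCLoss wants (time, batch, classes)

    # Random label ids, kept below 28 so the blank index never appears in a target.
    targets = torch.randint(1, 28, (batch, target_len), dtype=torch.long)
    input_lengths = torch.full((batch,), time, dtype=torch.long)
    target_lengths = torch.full((batch,), target_len, dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    print(loss.item())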