diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 103 |
1 files changed, 11 insertions, 92 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 29f9372..2628028 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,93 +1,14 @@ """Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech import click import torch -import torch.nn as nn -import torch.optim as optim import torch.nn.functional as F -from torch.utils.data import DataLoader import torchaudio -from .loss_scores import cer, wer - - -class TextTransform: - """Maps characters to integers and vice versa""" - - def __init__(self): - char_map_str = """ - ' 0 - <SPACE> 1 - a 2 - b 3 - c 4 - d 5 - e 6 - f 7 - g 8 - h 9 - i 10 - j 11 - k 12 - l 13 - m 14 - n 15 - o 16 - p 17 - q 18 - r 19 - s 20 - t 21 - u 22 - v 23 - w 24 - x 25 - y 26 - z 27 - ä 28 - ö 29 - ü 30 - ß 31 - - 32 - é 33 - è 34 - à 35 - ù 36 - ç 37 - â 38 - ê 39 - î 40 - ô 41 - û 42 - ë 43 - ï 44 - ü 45 - """ - self.char_map = {} - self.index_map = {} - for line in char_map_str.strip().split("\n"): - char, index = line.split() - self.char_map[char] = int(index) - self.index_map[int(index)] = char - self.index_map[1] = " " - - def text_to_int(self, text): - """Use a character map and convert text to an integer sequence""" - int_sequence = [] - for char in text: - if char == " ": - mapped_char = self.char_map["<SPACE>"] - else: - mapped_char = self.char_map[char] - int_sequence.append(mapped_char) - return int_sequence - - def int_to_text(self, labels): - """Use a character map and convert integer labels to an text sequence""" - string = [] - for i in labels: - string.append(self.index_map[i]) - return "".join(string).replace("<SPACE>", " ") +from AudioLoader.speech import MultilingualLibriSpeech +from torch import nn, optim +from torch.utils.data import DataLoader +from tokenizers import Tokenizer +from .loss_scores import cer, wer train_audio_transforms = nn.Sequential( torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), @@ -97,7 +18,7 @@ train_audio_transforms = nn.Sequential( valid_audio_transforms = torchaudio.transforms.MelSpectrogram() -text_transform = TextTransform() +text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") def data_processing(data, data_type="train"): @@ -114,7 +35,7 @@ def data_processing(data, data_type="train"): else: raise ValueError("data_type should be train or valid") spectrograms.append(spec) - label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower())) + label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) labels.append(label) input_lengths.append(spec.shape[0] // 2) label_lengths.append(len(label)) @@ -138,15 +59,13 @@ def greedy_decoder( targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.int_to_text(labels[i][: label_lengths[i]].tolist()) - ) + targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist())) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.int_to_text(decode)) + decodes.append(text_transform.decode(decode)) return decodes, targets @@ -407,10 +326,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No # device = torch.device("mps") train_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False + "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False ) test_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False + "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False ) kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} |