aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py103
1 files changed, 11 insertions, 92 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 29f9372..2628028 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,93 +1,14 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import click
import torch
-import torch.nn as nn
-import torch.optim as optim
import torch.nn.functional as F
-from torch.utils.data import DataLoader
import torchaudio
-from .loss_scores import cer, wer
-
-
-class TextTransform:
- """Maps characters to integers and vice versa"""
-
- def __init__(self):
- char_map_str = """
- ' 0
- <SPACE> 1
- a 2
- b 3
- c 4
- d 5
- e 6
- f 7
- g 8
- h 9
- i 10
- j 11
- k 12
- l 13
- m 14
- n 15
- o 16
- p 17
- q 18
- r 19
- s 20
- t 21
- u 22
- v 23
- w 24
- x 25
- y 26
- z 27
- ä 28
- ö 29
- ü 30
- ß 31
- - 32
- é 33
- è 34
- à 35
- ù 36
- ç 37
- â 38
- ê 39
- î 40
- ô 41
- û 42
- ë 43
- ï 44
- ü 45
- """
- self.char_map = {}
- self.index_map = {}
- for line in char_map_str.strip().split("\n"):
- char, index = line.split()
- self.char_map[char] = int(index)
- self.index_map[int(index)] = char
- self.index_map[1] = " "
-
- def text_to_int(self, text):
- """Use a character map and convert text to an integer sequence"""
- int_sequence = []
- for char in text:
- if char == " ":
- mapped_char = self.char_map["<SPACE>"]
- else:
- mapped_char = self.char_map[char]
- int_sequence.append(mapped_char)
- return int_sequence
-
- def int_to_text(self, labels):
- """Use a character map and convert integer labels to an text sequence"""
- string = []
- for i in labels:
- string.append(self.index_map[i])
- return "".join(string).replace("<SPACE>", " ")
+from AudioLoader.speech import MultilingualLibriSpeech
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tokenizers import Tokenizer
+from .loss_scores import cer, wer
train_audio_transforms = nn.Sequential(
torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -97,7 +18,7 @@ train_audio_transforms = nn.Sequential(
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-text_transform = TextTransform()
+text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
def data_processing(data, data_type="train"):
@@ -114,7 +35,7 @@ def data_processing(data, data_type="train"):
else:
raise ValueError("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
+ label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
@@ -138,15 +59,13 @@ def greedy_decoder(
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(
- text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
- )
+ targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist()))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.int_to_text(decode))
+ decodes.append(text_transform.decode(decode))
return decodes, targets
@@ -407,10 +326,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
# device = torch.device("mps")
train_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+ "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False
)
test_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+ "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False
)
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}