about summary refs log tree commit diff
path: root/swr2_asr/train_2.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train_2.py')
-rw-r--r--  swr2_asr/train_2.py  36
1 files changed, 25 insertions, 11 deletions
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index 2e690e2..b1b597a 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -48,6 +48,20 @@ class TextTransform:
ö 29
ü 30
ß 31
+ - 32
+ é 33
+ è 34
+ à 35
+ ù 36
+ ç 37
+ â 38
+ ê 39
+ î 40
+ ô 41
+ û 42
+ ë 43
+ ï 44
+ ü 45
"""
self.char_map = {}
self.index_map = {}
@@ -93,15 +107,15 @@ def data_processing(data, data_type="train"):
labels = []
input_lengths = []
label_lengths = []
- for waveform, _, utterance, _, _, _ in data:
+ for x in data:
if data_type == "train":
- spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
elif data_type == "valid":
- spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
else:
raise Exception("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
+ label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
@@ -364,12 +378,12 @@ def test(model, device, test_loader, criterion, epoch, iter_meter):
)
-def run(lr: float, batch_size: int, epochs: int) -> None:
+def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
hparams = {
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 33,
+ "n_class": 46,
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
@@ -381,13 +395,13 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu")
- device = torch.device("mps")
+ # device = torch.device("mps")
train_dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=False
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
)
test_dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="test", download=False
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
)
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
@@ -401,7 +415,7 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
)
test_loader = DataLoader(
- train_dataset,
+ test_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
@@ -449,4 +463,4 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
if __name__ == "__main__":
- run(lr=5e-4, batch_size=20, epochs=10)
+ run(lr=5e-4, batch_size=16, epochs=1)