diff options
Diffstat (limited to 'swr2_asr/train_2.py')
-rw-r--r-- | swr2_asr/train_2.py | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py index 2e690e2..b1b597a 100644 --- a/swr2_asr/train_2.py +++ b/swr2_asr/train_2.py @@ -48,6 +48,20 @@ class TextTransform: ö 29 ü 30 ß 31 + - 32 + é 33 + è 34 + à 35 + ù 36 + ç 37 + â 38 + ê 39 + î 40 + ô 41 + û 42 + ë 43 + ï 44 + ü 45 """ self.char_map = {} self.index_map = {} @@ -93,15 +107,15 @@ def data_processing(data, data_type="train"): labels = [] input_lengths = [] label_lengths = [] - for waveform, _, utterance, _, _, _ in data: + for x in data: if data_type == "train": - spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1) + spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) elif data_type == "valid": - spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1) + spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) else: raise Exception("data_type should be train or valid") spectrograms.append(spec) - label = torch.Tensor(text_transform.text_to_int(utterance.lower())) + label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower())) labels.append(label) input_lengths.append(spec.shape[0] // 2) label_lengths.append(len(label)) @@ -364,12 +378,12 @@ def test(model, device, test_loader, criterion, epoch, iter_meter): ) -def run(lr: float, batch_size: int, epochs: int) -> None: +def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: hparams = { "n_cnn_layers": 3, "n_rnn_layers": 5, "rnn_dim": 512, - "n_class": 33, + "n_class": 46, "n_feats": 128, "stride": 2, "dropout": 0.1, @@ -381,13 +395,13 @@ def run(lr: float, batch_size: int, epochs: int) -> None: use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") - device = torch.device("mps") + # device = torch.device("mps") train_dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=False + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False ) test_dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="test", download=False + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False ) kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} @@ -401,7 +415,7 @@ def run(lr: float, batch_size: int, epochs: int) -> None: ) test_loader = DataLoader( - train_dataset, + test_dataset, batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), @@ -449,4 +463,4 @@ def run(lr: float, batch_size: int, epochs: int) -> None: if __name__ == "__main__": - run(lr=5e-4, batch_size=20, epochs=10) + run(lr=5e-4, batch_size=16, epochs=1) |