diff options
author | Pherkel | 2023-08-18 23:41:09 +0200 |
---|---|---|
committer | Pherkel | 2023-08-18 23:41:09 +0200 |
commit | 8159a0b8a4519dced2490d77b7e1ae7fd1bbadef (patch) | |
tree | 586e7304ce46cbd17d07ebc3c6fec1e46f1a4a0f /swr2_asr | |
parent | 13a608530eba90cea4c003566e331938fbf34bda (diff) |
made linter changes (will still fail)
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train_2.py | 16 |
1 files changed, 7 insertions, 9 deletions
```diff
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index b1b597a..bea5bf4 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -1,13 +1,11 @@
 """Training script for the ASR model."""
 from AudioLoader.speech.mls import MultilingualLibriSpeech
-import click
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 import torchaudio
-import torchaudio.functional as AF
 
 from .loss_scores import cer, wer
 
@@ -113,7 +111,7 @@ def data_processing(data, data_type="train"):
         elif data_type == "valid":
             spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
         else:
-            raise Exception("data_type should be train or valid")
+            raise ValueError("data_type should be train or valid")
         spectrograms.append(spec)
         label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
         labels.append(label)
@@ -133,6 +131,7 @@ def data_processing(data, data_type="train"):
 def GreedyDecoder(
     output, labels, label_lengths, blank_label=28, collapse_repeated=True
 ):
+    """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)
     decodes = []
     targets = []
@@ -344,13 +343,13 @@ def train(
     )
 
 
-def test(model, device, test_loader, criterion, epoch, iter_meter):
+def test(model, device, test_loader, criterion):
     print("\nevaluating...")
     model.eval()
     test_loss = 0
     test_cer, test_wer = [], []
     with torch.no_grad():
-        for i, _data in enumerate(test_loader):
+        for _data in test_loader:
             spectrograms, labels, input_lengths, label_lengths = _data
             spectrograms, labels = spectrograms.to(device), labels.to(device)
 
@@ -372,9 +371,8 @@ def test(model, device, test_loader, criterion, epoch, iter_meter):
     avg_wer = sum(test_wer) / len(test_wer)
 
     print(
-        "Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n".format(
-            test_loss, avg_cer, avg_wer
-        )
+        f"Test set: Average loss:\
+            {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
     )
 
 
@@ -459,7 +457,7 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
             epoch,
             iter_meter,
         )
-        test(model, device, test_loader, criterion, epoch, iter_meter)
+        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
 
 
 if __name__ == "__main__":
```