author     Pherkel  2023-08-18 23:41:09 +0200
committer  Pherkel  2023-08-18 23:41:09 +0200
commit     8159a0b8a4519dced2490d77b7e1ae7fd1bbadef (patch)
tree       586e7304ce46cbd17d07ebc3c6fec1e46f1a4a0f /swr2_asr
parent     13a608530eba90cea4c003566e331938fbf34bda (diff)
made linter changes (will still fail)
Diffstat (limited to 'swr2_asr')
-rw-r--r--  swr2_asr/train_2.py  16
1 files changed, 7 insertions, 9 deletions
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index b1b597a..bea5bf4 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -1,13 +1,11 @@
"""Training script for the ASR model."""
from AudioLoader.speech.mls import MultilingualLibriSpeech
-import click
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchaudio
-import torchaudio.functional as AF
from .loss_scores import cer, wer
@@ -113,7 +111,7 @@ def data_processing(data, data_type="train"):
elif data_type == "valid":
spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
else:
- raise Exception("data_type should be train or valid")
+ raise ValueError("data_type should be train or valid")
spectrograms.append(spec)
label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
labels.append(label)
@@ -133,6 +131,7 @@ def data_processing(data, data_type="train"):
def GreedyDecoder(
output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
+ """Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2)
decodes = []
targets = []
@@ -344,13 +343,13 @@ def train(
)
-def test(model, device, test_loader, criterion, epoch, iter_meter):
+def test(model, device, test_loader, criterion):
print("\nevaluating...")
model.eval()
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for i, _data in enumerate(test_loader):
+ for _data in test_loader:
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
@@ -372,9 +371,8 @@ def test(model, device, test_loader, criterion, epoch, iter_meter):
avg_wer = sum(test_wer) / len(test_wer)
print(
- "Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n".format(
- test_loss, avg_cer, avg_wer
- )
+ f"Test set: Average loss:\
+ {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
)
@@ -459,7 +457,7 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
epoch,
iter_meter,
)
- test(model, device, test_loader, criterion, epoch, iter_meter)
+ test(model=model, device=device, test_loader=test_loader, criterion=criterion)
if __name__ == "__main__":
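
For context on the function that gains a docstring in the hunk above: GreedyDecoder performs plain greedy CTC decoding, i.e. take the argmax class per time step, collapse runs of repeated indices, and drop the blank token. Below is a minimal sketch of that idea; it assumes the output tensor is shaped (batch, time, n_classes) and uses the blank index 28 from the diff's default, and the helper name greedy_ctc_decode is illustrative rather than part of the repository.

import torch

def greedy_ctc_decode(output: torch.Tensor, blank_label: int = 28,
                      collapse_repeated: bool = True) -> list[list[int]]:
    """Return the most likely token indices per utterance (no language model)."""
    # Most likely class per time step: (batch, time, n_classes) -> (batch, time)
    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    for args in arg_maxes:
        decode = []
        prev = None
        for idx in args.tolist():
            # Collapse consecutive repeats, then drop the CTC blank.
            if not (collapse_repeated and idx == prev) and idx != blank_label:
                decode.append(idx)
            prev = idx
        decodes.append(decode)
    return decodes

The resulting index lists still have to be mapped back to characters (via the inverse of the script's text_transform.text_to_int) before being scored with cer and wer from .loss_scores, imported at the top of the diff.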