aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/__main__.py12
-rw-r--r--swr2_asr/inference.py16
-rw-r--r--swr2_asr/model_deep_speech.py17
-rw-r--r--swr2_asr/train.py192
-rw-r--r--swr2_asr/utils/data.py7
-rw-r--r--swr2_asr/utils/tokenizer.py8
-rw-r--r--swr2_asr/utils/visualization.py8
7 files changed, 109 insertions, 151 deletions
diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py
deleted file mode 100644
index be294fb..0000000
--- a/swr2_asr/__main__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Main entrypoint for swr2-asr."""
-import torch
-import torchaudio
-
-if __name__ == "__main__":
- # test if GPU is available
- print("GPU available: ", torch.cuda.is_available())
-
- # test if torchaudio is installed correctly
- print("torchaudio version: ", torchaudio.__version__)
- print("torchaudio backend: ", torchaudio.get_audio_backend())
- print("torchaudio info: ", torchaudio.get_audio_backend())
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index c3eec42..f8342f7 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
+from typing import TypedDict
+
import torch
-import torchaudio
import torch.nn.functional as F
-from typing import TypedDict
+import torchaudio
-from swr2_asr.tokenizer import CharTokenizer
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.tokenizer import CharTokenizer
class HParams(TypedDict):
@@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
blank_label = tokenizer.encode(" ").ids[0]
decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
+ for _i, args in enumerate(arg_maxes):
decode = []
for j, index in enumerate(args):
if index != blank_label:
@@ -44,7 +44,7 @@ def main() -> None:
"""inference function."""
device = "cuda" if torch.cuda.is_available() else "cpu"
- device = torch.device(device)
+ device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
@@ -90,7 +90,7 @@ def main() -> None:
model.load_state_dict(state_dict)
# waveform, sample_rate = torchaudio.load("test.opus")
- waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+ waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member
if sample_rate != spectrogram_hparams["sample_rate"]:
resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
waveform = resampler(waveform)
@@ -103,7 +103,7 @@ def main() -> None:
specs = [spec]
specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
- output = model(specs)
+ output = model(specs) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
decodes = greedy_decoder(output, tokenizer)
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 8ddbd99..77f4c8a 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -3,27 +3,10 @@
Following definition by Assembly AI
(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
"""
-from typing import TypedDict
-
import torch.nn.functional as F
from torch import nn
-class HParams(TypedDict):
- """Type for the hyperparameters of the model."""
-
- n_cnn_layers: int
- n_rnn_layers: int
- rnn_dim: int
- n_class: int
- n_feats: int
- stride: int
- dropout: float
- learning_rate: float
- batch_size: int
- epochs: int
-
-
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ac7666b..eb79ee2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,11 +5,12 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+import yaml
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
-from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import greedy_decoder
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer
from .utils.loss_scores import cer, wer
-class IterMeter(object):
+class IterMeter:
"""keeps track of total iterations"""
def __init__(self):
@@ -116,6 +117,7 @@ class TestArgs(TypedDict):
def test(test_args: TestArgs) -> tuple[float, float, float]:
+ """Test"""
print("\nevaluating...")
# get values from test_args:
@@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")):
+ for _data in tqdm(test_loader, desc="Validation Batches"):
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
@@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
decoded_preds, decoded_targets = greedy_decoder(
output.transpose(0, 1), labels, label_lengths, tokenizer
)
- if i == 1:
- print(f"decoding first sample: {decoded_preds}")
for j, _ in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
@@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
return test_loss, avg_cer, avg_wer
-def main(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- dataset_path: str,
- language: str,
- limited_supervision: bool,
- model_load_path: str,
- model_save_path: str,
- dataset_percentage: float,
- eval_every: int,
- num_workers: int,
-):
+@click.command()
+@click.option(
+ "--config_path",
+ default="config.yaml",
+ help="Path to yaml config file",
+ type=click.Path(exists=True),
+)
+def main(config_path: str):
"""Main function for training the model.
- Args:
- learning_rate: learning rate for the optimizer
- batch_size: batch size
- epochs: number of epochs to train
- dataset_path: path for the dataset
- language: language of the dataset
- limited_supervision: whether to use only limited supervision
- model_load_path: path to load a model from
- model_save_path: path to save the model to
- dataset_percentage: percentage of the dataset to use
- eval_every: evaluate every n epochs
- num_workers: number of workers for the dataloader
+ Gets all configuration arguments from yaml config file.
"""
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
torch.manual_seed(7)
- if not os.path.isdir(dataset_path):
- os.makedirs(dataset_path)
+ with open(config_path, "r", encoding="utf-8") as yaml_file:
+ config_dict = yaml.safe_load(yaml_file)
+
+ # Create separate dictionaries for each top-level key
+ model_config = config_dict.get("model", {})
+ training_config = config_dict.get("training", {})
+ dataset_config = config_dict.get("dataset", {})
+ tokenizer_config = config_dict.get("tokenizer", {})
+ checkpoints_config = config_dict.get("checkpoints", {})
+
+ print(training_config["learning_rate"])
+
+ if not os.path.isdir(dataset_config["dataset_root_path"]):
+ os.makedirs(dataset_config["dataset_root_path"])
train_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TEST,
- download=True,
- limited=limited_supervision,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
valid_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TRAIN,
- download=False,
- limited=Falimited_supervisionlse,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
- # TODO: initialize and possibly train tokenizer if none found
-
- kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
-
- hparams = HParams(
- n_cnn_layers=3,
- n_rnn_layers=5,
- rnn_dim=512,
- n_class=tokenizer.get_vocab_size(),
- n_feats=128,
- stride=2,
- dropout=0.1,
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- )
+ kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {}
+
+ if tokenizer_config["tokenizer_path"] is None:
+ print("Tokenizer not found!")
+ if click.confirm("Do you want to train a new tokenizer?", default=True):
+ pass
+ else:
+ return
+ tokenizer = CharTokenizer.train(
+ dataset_config["dataset_root_path"], dataset_config["language_name"]
+ )
+ tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
train_data_processing = DataProcessing("train", tokenizer)
valid_data_processing = DataProcessing("valid", tokenizer)
train_loader = DataLoader(
dataset=train_dataset,
- batch_size=hparams["batch_size"],
- shuffle=True,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=train_data_processing,
**kwargs,
)
valid_loader = DataLoader(
dataset=valid_dataset,
- batch_size=hparams["batch_size"],
- shuffle=False,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=valid_data_processing,
**kwargs,
)
model = SpeechRecognitionModel(
- hparams["n_cnn_layers"],
- hparams["n_rnn_layers"],
- hparams["rnn_dim"],
- hparams["n_class"],
- hparams["n_feats"],
- hparams["stride"],
- hparams["dropout"],
+ model_config["n_cnn_layers"],
+ model_config["n_rnn_layers"],
+ model_config["rnn_dim"],
+ tokenizer.get_vocab_size(),
+ model_config["n_feats"],
+ model_config["stride"],
+ model_config["dropout"],
).to(device)
- optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+ optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"])
criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
- max_lr=hparams["learning_rate"],
+ max_lr=training_config["learning_rate"],
steps_per_epoch=int(len(train_loader)),
- epochs=hparams["epochs"],
+ epochs=training_config["epochs"],
anneal_strategy="linear",
)
prev_epoch = 0
- if model_load_path is not None:
- checkpoint = torch.load(model_load_path)
+ if checkpoints_config["model_load_path"] is not None:
+ checkpoint = torch.load(checkpoints_config["model_load_path"])
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
- for epoch in range(prev_epoch + 1, epochs + 1):
- train_args: TrainArgs = dict(
- model=model,
- device=device,
- train_loader=train_loader,
- criterion=criterion,
- optimizer=optimizer,
- scheduler=scheduler,
- epoch=epoch,
- iter_meter=iter_meter,
- )
+
+ for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
+ train_args: TrainArgs = {
+ "model": model,
+ "device": device,
+ "train_loader": train_loader,
+ "criterion": criterion,
+ "optimizer": optimizer,
+ "scheduler": scheduler,
+ "epoch": epoch,
+ "iter_meter": iter_meter,
+ }
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
- test_args: TestArgs = dict(
- model=model,
- device=device,
- test_loader=valid_loader,
- criterion=criterion,
- tokenizer=tokenizer,
- decoder="greedy",
- )
+ test_args: TestArgs = {
+ "model": model,
+ "device": device,
+ "test_loader": valid_loader,
+ "criterion": criterion,
+ "tokenizer": tokenizer,
+ "decoder": "greedy",
+ }
- if epoch % eval_every == 0:
+ if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
test_loss, test_cer, test_wer = test(test_args)
- if model_save_path is None:
+ if checkpoints_config["model_save_path"] is None:
continue
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
+ if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])):
+ os.makedirs(os.path.dirname(checkpoints_config["model_save_path"]))
+
torch.save(
{
"epoch": epoch,
@@ -322,7 +314,7 @@ def main(
"avg_cer": test_cer,
"avg_wer": test_wer,
},
- model_save_path + str(epoch),
+ checkpoints_config["model_save_path"] + str(epoch),
)
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index e939e1d..0e06eec 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -1,13 +1,12 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
-from typing import TypedDict
import numpy as np
import torch
import torchaudio
from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets.utils import _extract_tar
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -125,7 +124,7 @@ class MLSDataset(Dataset):
self._handle_download_dataset(download)
self._validate_local_directory()
- if limited and (split == Split.TRAIN or split == Split.VALID):
+ if limited and split in (Split.TRAIN, Split.VALID):
self.initialize_limited()
else:
self.initialize()
@@ -351,8 +350,6 @@ class MLSDataset(Dataset):
if __name__ == "__main__":
- from torch.utils.data import DataLoader
-
DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
LANGUAGE = "mls_german_opus"
split = Split.DEV
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 5482bbe..22569eb 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
"""Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
import os
+from datetime import datetime
from tqdm.autonotebook import tqdm
@@ -119,8 +117,8 @@ class CharTokenizer:
line = line.strip()
if line:
char, index = line.split()
- tokenizer.char_map[char] = int(index)
- tokenizer.index_map[int(index)] = char
+ load_tokenizer.char_map[char] = int(index)
+ load_tokenizer.index_map[int(index)] = char
return load_tokenizer
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
def plot(epochs, path):
"""Plots the losses over the epochs"""
- losses = list()
- test_losses = list()
- cers = list()
- wers = list()
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
for epoch in range(1, epochs + 1):
current_state = torch.load(path + str(epoch))
losses.append(current_state["loss"])