aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.philipp.yaml29
-rw-r--r--config.train.yaml28
-rw-r--r--poetry.lock51
-rw-r--r--pyproject.toml1
-rw-r--r--requirements.txt1
-rw-r--r--swr2_asr/__main__.py12
-rw-r--r--swr2_asr/inference.py16
-rw-r--r--swr2_asr/model_deep_speech.py17
-rw-r--r--swr2_asr/train.py192
-rw-r--r--swr2_asr/utils/data.py7
-rw-r--r--swr2_asr/utils/tokenizer.py8
-rw-r--r--swr2_asr/utils/visualization.py8
12 files changed, 218 insertions, 152 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
new file mode 100644
index 0000000..638b5ef
--- /dev/null
+++ b/config.philipp.yaml
@@ -0,0 +1,29 @@
+model:
+ n_cnn_layers: 3
+ n_rnn_layers: 5
+ rnn_dim: 512
+ n_feats: 128 # number of mel features
+ stride: 2
+ dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+
+training:
+ learning_rate: 0.0005
+ batch_size: 2 # recommended: the largest batch size that fits on the GPU (a batch size of 32 fits on a 12GB GPU)
+ epochs: 3
+ eval_every_n: 1 # evaluate every n epochs
+ num_workers: 4 # number of workers for dataloader
+
+dataset:
+ download: True
+ dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: True # set to True if you want to use limited supervision
+ dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
+tokenizer:
+ tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+checkpoints:
+ model_load_path: ~ # path to load model from
+ model_save_path: ~ # path to save model to \ No newline at end of file
diff --git a/config.train.yaml b/config.train.yaml
new file mode 100644
index 0000000..c82439d
--- /dev/null
+++ b/config.train.yaml
@@ -0,0 +1,28 @@
+model:
+ n_cnn_layers: 3
+ n_rnn_layers: 5
+ rnn_dim: 512
+ n_feats: 128 # number of mel features
+ stride: 2
+ dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+
+training:
+ learning_rate: 5e-4
+ batch_size: 8 # recommended: the largest batch size that fits on the GPU (a batch size of 32 fits on a 12GB GPU)
+ epochs: 3
+ eval_every_n: 3 # evaluate every n epochs
+ num_workers: 8 # number of workers for dataloader
+
+dataset:
+ download: True
+ dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: False # set to True if you want to use limited supervision
+ dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+
+tokenizer:
+ tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
+
+checkpoints:
+ model_load_path: "YOUR/PATH" # path to load model from
+ model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 3901b8c..a1f916b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1084,6 +1084,55 @@ files = [
six = ">=1.5"
[[package]]
+name = "pyyaml"
+version = "6.0.1"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
+ {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
+ {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
+ {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
+ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
+ {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
+[[package]]
name = "ruff"
version = "0.0.285"
description = "An extremely fast Python linter, written in Rust."
@@ -1482,4 +1531,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5"
+content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf"
diff --git a/pyproject.toml b/pyproject.toml
index 38cc51a..f6d19dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ mido = "^1.3.0"
tokenizers = "^0.13.3"
click = "^8.1.7"
matplotlib = "^3.7.2"
+pyyaml = "^6.0.1"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
diff --git a/requirements.txt b/requirements.txt
index 3b39b56..040fed0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,6 +26,7 @@ platformdirs==3.10.0
pylint==2.17.5
pyparsing==3.0.9
python-dateutil==2.8.2
+PyYAML==6.0.1
ruff==0.0.285
six==1.16.0
sympy==1.12
diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py
deleted file mode 100644
index be294fb..0000000
--- a/swr2_asr/__main__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Main entrypoint for swr2-asr."""
-import torch
-import torchaudio
-
-if __name__ == "__main__":
- # test if GPU is available
- print("GPU available: ", torch.cuda.is_available())
-
- # test if torchaudio is installed correctly
- print("torchaudio version: ", torchaudio.__version__)
- print("torchaudio backend: ", torchaudio.get_audio_backend())
- print("torchaudio info: ", torchaudio.get_audio_backend())
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index c3eec42..f8342f7 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
+from typing import TypedDict
+
import torch
-import torchaudio
import torch.nn.functional as F
-from typing import TypedDict
+import torchaudio
-from swr2_asr.tokenizer import CharTokenizer
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.tokenizer import CharTokenizer
class HParams(TypedDict):
@@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
blank_label = tokenizer.encode(" ").ids[0]
decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
+ for _i, args in enumerate(arg_maxes):
decode = []
for j, index in enumerate(args):
if index != blank_label:
@@ -44,7 +44,7 @@ def main() -> None:
"""inference function."""
device = "cuda" if torch.cuda.is_available() else "cpu"
- device = torch.device(device)
+ device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
@@ -90,7 +90,7 @@ def main() -> None:
model.load_state_dict(state_dict)
# waveform, sample_rate = torchaudio.load("test.opus")
- waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+ waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member
if sample_rate != spectrogram_hparams["sample_rate"]:
resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
waveform = resampler(waveform)
@@ -103,7 +103,7 @@ def main() -> None:
specs = [spec]
specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
- output = model(specs)
+ output = model(specs) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
decodes = greedy_decoder(output, tokenizer)
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 8ddbd99..77f4c8a 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -3,27 +3,10 @@
Following definition by Assembly AI
(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
"""
-from typing import TypedDict
-
import torch.nn.functional as F
from torch import nn
-class HParams(TypedDict):
- """Type for the hyperparameters of the model."""
-
- n_cnn_layers: int
- n_rnn_layers: int
- rnn_dim: int
- n_class: int
- n_feats: int
- stride: int
- dropout: float
- learning_rate: float
- batch_size: int
- epochs: int
-
-
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ac7666b..eb79ee2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,11 +5,12 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+import yaml
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
-from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import greedy_decoder
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer
from .utils.loss_scores import cer, wer
-class IterMeter(object):
+class IterMeter:
"""keeps track of total iterations"""
def __init__(self):
@@ -116,6 +117,7 @@ class TestArgs(TypedDict):
def test(test_args: TestArgs) -> tuple[float, float, float]:
+ """Evaluate the model on the test loader and return (test loss, average CER, average WER)."""
print("\nevaluating...")
# get values from test_args:
@@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")):
+ for _data in tqdm(test_loader, desc="Validation Batches"):
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
@@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
decoded_preds, decoded_targets = greedy_decoder(
output.transpose(0, 1), labels, label_lengths, tokenizer
)
- if i == 1:
- print(f"decoding first sample: {decoded_preds}")
for j, _ in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
@@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
return test_loss, avg_cer, avg_wer
-def main(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- dataset_path: str,
- language: str,
- limited_supervision: bool,
- model_load_path: str,
- model_save_path: str,
- dataset_percentage: float,
- eval_every: int,
- num_workers: int,
-):
+@click.command()
+@click.option(
+ "--config_path",
+ default="config.yaml",
+ help="Path to yaml config file",
+ type=click.Path(exists=True),
+)
+def main(config_path: str):
"""Main function for training the model.
- Args:
- learning_rate: learning rate for the optimizer
- batch_size: batch size
- epochs: number of epochs to train
- dataset_path: path for the dataset
- language: language of the dataset
- limited_supervision: whether to use only limited supervision
- model_load_path: path to load a model from
- model_save_path: path to save the model to
- dataset_percentage: percentage of the dataset to use
- eval_every: evaluate every n epochs
- num_workers: number of workers for the dataloader
+ Gets all configuration arguments from yaml config file.
"""
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
torch.manual_seed(7)
- if not os.path.isdir(dataset_path):
- os.makedirs(dataset_path)
+ with open(config_path, "r", encoding="utf-8") as yaml_file:
+ config_dict = yaml.safe_load(yaml_file)
+
+ # Create separate dictionaries for each top-level key
+ model_config = config_dict.get("model", {})
+ training_config = config_dict.get("training", {})
+ dataset_config = config_dict.get("dataset", {})
+ tokenizer_config = config_dict.get("tokenizer", {})
+ checkpoints_config = config_dict.get("checkpoints", {})
+
+ print(training_config["learning_rate"])
+
+ if not os.path.isdir(dataset_config["dataset_root_path"]):
+ os.makedirs(dataset_config["dataset_root_path"])
train_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TEST,
- download=True,
- limited=limited_supervision,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
valid_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TRAIN,
- download=False,
- limited=Falimited_supervisionlse,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
- # TODO: initialize and possibly train tokenizer if none found
-
- kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
-
- hparams = HParams(
- n_cnn_layers=3,
- n_rnn_layers=5,
- rnn_dim=512,
- n_class=tokenizer.get_vocab_size(),
- n_feats=128,
- stride=2,
- dropout=0.1,
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- )
+ kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {}
+
+ if tokenizer_config["tokenizer_path"] is None:
+ print("Tokenizer not found!")
+ if click.confirm("Do you want to train a new tokenizer?", default=True):
+ pass
+ else:
+ return
+ tokenizer = CharTokenizer.train(
+ dataset_config["dataset_root_path"], dataset_config["language_name"]
+ )
+ tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
train_data_processing = DataProcessing("train", tokenizer)
valid_data_processing = DataProcessing("valid", tokenizer)
train_loader = DataLoader(
dataset=train_dataset,
- batch_size=hparams["batch_size"],
- shuffle=True,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=train_data_processing,
**kwargs,
)
valid_loader = DataLoader(
dataset=valid_dataset,
- batch_size=hparams["batch_size"],
- shuffle=False,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=valid_data_processing,
**kwargs,
)
model = SpeechRecognitionModel(
- hparams["n_cnn_layers"],
- hparams["n_rnn_layers"],
- hparams["rnn_dim"],
- hparams["n_class"],
- hparams["n_feats"],
- hparams["stride"],
- hparams["dropout"],
+ model_config["n_cnn_layers"],
+ model_config["n_rnn_layers"],
+ model_config["rnn_dim"],
+ tokenizer.get_vocab_size(),
+ model_config["n_feats"],
+ model_config["stride"],
+ model_config["dropout"],
).to(device)
- optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+ optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"])
criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
- max_lr=hparams["learning_rate"],
+ max_lr=training_config["learning_rate"],
steps_per_epoch=int(len(train_loader)),
- epochs=hparams["epochs"],
+ epochs=training_config["epochs"],
anneal_strategy="linear",
)
prev_epoch = 0
- if model_load_path is not None:
- checkpoint = torch.load(model_load_path)
+ if checkpoints_config["model_load_path"] is not None:
+ checkpoint = torch.load(checkpoints_config["model_load_path"])
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
- for epoch in range(prev_epoch + 1, epochs + 1):
- train_args: TrainArgs = dict(
- model=model,
- device=device,
- train_loader=train_loader,
- criterion=criterion,
- optimizer=optimizer,
- scheduler=scheduler,
- epoch=epoch,
- iter_meter=iter_meter,
- )
+
+ for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
+ train_args: TrainArgs = {
+ "model": model,
+ "device": device,
+ "train_loader": train_loader,
+ "criterion": criterion,
+ "optimizer": optimizer,
+ "scheduler": scheduler,
+ "epoch": epoch,
+ "iter_meter": iter_meter,
+ }
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
- test_args: TestArgs = dict(
- model=model,
- device=device,
- test_loader=valid_loader,
- criterion=criterion,
- tokenizer=tokenizer,
- decoder="greedy",
- )
+ test_args: TestArgs = {
+ "model": model,
+ "device": device,
+ "test_loader": valid_loader,
+ "criterion": criterion,
+ "tokenizer": tokenizer,
+ "decoder": "greedy",
+ }
- if epoch % eval_every == 0:
+ if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
test_loss, test_cer, test_wer = test(test_args)
- if model_save_path is None:
+ if checkpoints_config["model_save_path"] is None:
continue
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
+ if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])):
+ os.makedirs(os.path.dirname(checkpoints_config["model_save_path"]))
+
torch.save(
{
"epoch": epoch,
@@ -322,7 +314,7 @@ def main(
"avg_cer": test_cer,
"avg_wer": test_wer,
},
- model_save_path + str(epoch),
+ checkpoints_config["model_save_path"] + str(epoch),
)
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index e939e1d..0e06eec 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -1,13 +1,12 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
-from typing import TypedDict
import numpy as np
import torch
import torchaudio
from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets.utils import _extract_tar
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -125,7 +124,7 @@ class MLSDataset(Dataset):
self._handle_download_dataset(download)
self._validate_local_directory()
- if limited and (split == Split.TRAIN or split == Split.VALID):
+ if limited and split in (Split.TRAIN, Split.VALID):
self.initialize_limited()
else:
self.initialize()
@@ -351,8 +350,6 @@ class MLSDataset(Dataset):
if __name__ == "__main__":
- from torch.utils.data import DataLoader
-
DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
LANGUAGE = "mls_german_opus"
split = Split.DEV
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 5482bbe..22569eb 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
"""Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
import os
+from datetime import datetime
from tqdm.autonotebook import tqdm
@@ -119,8 +117,8 @@ class CharTokenizer:
line = line.strip()
if line:
char, index = line.split()
- tokenizer.char_map[char] = int(index)
- tokenizer.index_map[int(index)] = char
+ load_tokenizer.char_map[char] = int(index)
+ load_tokenizer.index_map[int(index)] = char
return load_tokenizer
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
def plot(epochs, path):
"""Plots the losses over the epochs"""
- losses = list()
- test_losses = list()
- cers = list()
- wers = list()
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
for epoch in range(1, epochs + 1):
current_state = torch.load(path + str(epoch))
losses.append(current_state["loss"])