Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py | 89
1 file changed, 69 insertions(+), 20 deletions(-)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 2628028..d13683f 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,4 +1,5 @@
"""Training script for the ASR model."""
+import os
import click
import torch
import torch.nn.functional as F
@@ -7,6 +8,7 @@ from AudioLoader.speech import MultilingualLibriSpeech
from torch import nn, optim
from torch.utils.data import DataLoader
from tokenizers import Tokenizer
+from .tokenizer import CharTokenizer
from .loss_scores import cer, wer
@@ -18,7 +20,9 @@ train_audio_transforms = nn.Sequential(
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+text_transform = CharTokenizer()
+text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
def data_processing(data, data_type="train"):
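
Note on the hunk above: the commit replaces the 3000-symbol BPE tokenizer with the project's own CharTokenizer. A minimal sketch of the new text transform, assuming the from_file and decode methods visible in this diff plus a hypothetical encode counterpart that returns a list of ids:

    from swr2_asr.tokenizer import CharTokenizer

    text_transform = CharTokenizer()
    text_transform.from_file("data/tokenizers/char_tokenizer_german.json")

    # decode() expects plain Python ints, hence the explicit casts in the
    # greedy decoder below; encode() is an assumed counterpart, not shown
    # in this diff
    ids = text_transform.encode("hallo welt")
    print(text_transform.decode([int(x) for x in ids]))
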
@@ -59,7 +63,11 @@ def greedy_decoder(
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist()))
+ targets.append(
+ text_transform.decode(
+ [int(x) for x in labels[i][: label_lengths[i]].tolist()]
+ )
+ )
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
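
For context, the loop above is standard greedy CTC decoding: take the per-frame argmax, drop blank symbols, and collapse repeated labels. A self-contained sketch of the same idea (blank index 28 matches the CTCLoss construction later in this file):

    import torch

    def greedy_ctc(log_probs: torch.Tensor, blank_label: int = 28) -> list:
        """Collapse a (time, n_class) log-probability matrix to label ids."""
        previous = blank_label
        decoded = []
        for index in torch.argmax(log_probs, dim=1).tolist():
            # drop blanks and collapse runs of the same label
            if index != blank_label and index != previous:
                decoded.append(index)
            previous = index
        return decoded
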
@@ -269,6 +277,7 @@ def train(
({100.0 * batch_idx / len(train_loader)}%)]\t \
Loss: {loss.item()}"
)
+ return loss.item()
def test(model, device, test_loader, criterion):
@@ -305,13 +314,20 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+def run(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
hparams = {
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 46,
+ "n_class": 36,
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
@@ -325,21 +341,19 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
+ download_dataset = not os.path.isdir(path)
train_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+ dataset_path, "mls_german_opus", split="dev", download=download_dataset
)
test_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False
+ dataset_path, "mls_german_opus", split="test", download=False
)
- kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
-
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
test_loader = DataLoader(
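
Two notes on the hunk above: the download flag keys off the model path, so a machine without a saved model re-fetches the corpus even if dataset_path is already populated; and the CUDA-only DataLoader kwargs are dropped. If input throughput becomes a bottleneck again, the deleted behavior can be restored; a sketch reconstructed from the removed lines:

    # re-enable worker processes and pinned host memory on CUDA machines
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=lambda x: data_processing(x, "train"),
        **kwargs,
    )
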
@@ -347,9 +361,12 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
+ # benchmark cuDNN algorithms up front and cache the fastest choice
+ if use_cuda:
+ torch.backends.cudnn.benchmark = True
+
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
hparams["n_rnn_layers"],
@@ -363,10 +380,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
-
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
-
+ if load:
+ checkpoint = torch.load(path)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ epoch = checkpoint["epoch"]
+ loss = checkpoint["loss"]
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
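
The load branch above restores training state from a single checkpoint dictionary. A sketch of the full round trip it relies on, with key names following this diff; note that resuming only works if the optimizer state was saved alongside the weights:

    import torch

    # save: bundle everything needed to resume into one dictionary
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        path + str(epoch),
    )

    # load: restore weights and optimizer state in place, then resume
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
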
@@ -377,7 +398,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
- train(
+ loss = train(
model,
device,
train_loader,
@@ -387,17 +408,45 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
epoch,
iter_meter,
)
+
test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ print("saving epoch", str(epoch))
+ torch.save(
+     # include optimizer state; the load branch above reads it back
+     {"epoch": epoch, "model_state_dict": model.state_dict(),
+      "optimizer_state_dict": optimizer.state_dict(), "loss": loss},
+     path + str(epoch),
+ )
@click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
+@click.option("--batch_size", default=10, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
-def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+ "--path",
+ default="model",
+ help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+ "--dataset_path",
+ default="data/",
+ help="Path for the dataset directory",
+)
+def run_cli(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
-
-if __name__ == "__main__":
- run(learning_rate=5e-4, batch_size=16, epochs=1)
+ run(
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
+ load=load,
+ path=path,
+ dataset_path=dataset_path,
+ )