aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xhpc.sh19
-rwxr-xr-xhpc_train.sh3
-rw-r--r--pyproject.toml3
-rw-r--r--swr2_asr/tokenizer.py41
-rw-r--r--swr2_asr/train.py89
5 files changed, 120 insertions, 35 deletions
diff --git a/hpc.sh b/hpc.sh
new file mode 100755
index 0000000..ba0c5eb
--- /dev/null
+++ b/hpc.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+#SBATCH --job-name=swr-teamprojekt
+#SBATCH --partition=a100
+#SBATCH --time=00:30:00
+
+### Note: --gres=gpu:x should equal to ntasks-per-node
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:a100:1
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=64gb
+#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/
+#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out
+
+source venv/bin/activate
+
+### the command to run
+srun ./hpc_train.sh
diff --git a/hpc_train.sh b/hpc_train.sh
new file mode 100755
index 0000000..c7d1636
--- /dev/null
+++ b/hpc_train.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+
+yes no | python -m swr2_asr.train --epochs=100 --batch_size=30 --dataset_path=/mnt/lustre/mladm/mfa252/data
diff --git a/pyproject.toml b/pyproject.toml
index b7e6ffb..eb17479 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ readme = "readme.md"
packages = [{include = "swr2_asr"}]
[tool.poetry.dependencies]
-python = "~3.10"
+python = "^3.10"
torch = "2.0.0"
torchaudio = "2.0.1"
audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
@@ -16,6 +16,7 @@ tqdm = "^4.66.1"
numpy = "^1.25.2"
mido = "^1.3.0"
tokenizers = "^0.13.3"
+click = "^8.1.7"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d32e60d..d9cd622 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,4 +1,6 @@
"""Tokenizer for use with Multilingual Librispeech"""
+from dataclasses import dataclass
+import json
import os
import click
from tqdm import tqdm
@@ -11,6 +13,13 @@ from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
+@dataclass
+class Encoding:
+ """Simple dataclass to represent an encoding"""
+
+ ids: list[int]
+
+
class CharTokenizer:
"""Very simple tokenizer for use with Multilingual Librispeech
@@ -98,7 +107,7 @@ class CharTokenizer:
else:
mapped_char = self.char_map[char]
int_sequence.append(mapped_char)
- return int_sequence
+ return Encoding(ids=int_sequence)
def decode(self, labels: list[int], remove_special_tokens: bool = True):
"""Use a character map and convert integer labels to an text sequence
@@ -110,11 +119,11 @@ class CharTokenizer:
"""
string = []
for i in labels:
- if remove_special_tokens and self.index_map[i] == "<UNK>":
+ if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
continue
- if remove_special_tokens and self.index_map[i] == "<SPACE>":
+ if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
string.append(" ")
- string.append(self.index_map[i])
+ string.append(self.index_map[f"{i}"])
return "".join(string).replace("<SPACE>", " ")
def decode_batch(self, labels: list[list[int]]):
@@ -134,16 +143,22 @@ class CharTokenizer:
def save(self, path: str):
"""Save the tokenizer to a file"""
with open(path, "w", encoding="utf-8") as file:
- for char, index in self.char_map.items():
- file.write(f"{char} {index}\n")
+ # save it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ json.dump(
+ {"char_map": self.char_map, "index_map": self.index_map},
+ file,
+ ensure_ascii=False,
+ )
def from_file(self, path: str):
"""Load the tokenizer from a file"""
with open(path, "r", encoding="utf-8") as file:
- for line in file.readlines():
- char, index = line.split(" ")
- self.char_map[char] = int(index)
- self.index_map[int(index)] = char
+ # load it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ saved_file = json.load(file)
+ self.char_map = saved_file["char_map"]
+ self.index_map = saved_file["index_map"]
@click.command()
@@ -303,8 +318,6 @@ def train_char_tokenizer(
if __name__ == "__main__":
tokenizer = CharTokenizer()
- tokenizer.train("/Volumes/pherkel 2/SWR2-ASR", "mls_german_opus", "all")
-
- print(tokenizer.decode(tokenizer.encode("Fichier non trouvé")))
+ tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
- tokenizer.save("tokenizer_chars.txt")
+ print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 2628028..d13683f 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,4 +1,5 @@
"""Training script for the ASR model."""
+import os
import click
import torch
import torch.nn.functional as F
@@ -7,6 +8,7 @@ from AudioLoader.speech import MultilingualLibriSpeech
from torch import nn, optim
from torch.utils.data import DataLoader
from tokenizers import Tokenizer
+from .tokenizer import CharTokenizer
from .loss_scores import cer, wer
@@ -18,7 +20,9 @@ train_audio_transforms = nn.Sequential(
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+text_transform = CharTokenizer()
+text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
def data_processing(data, data_type="train"):
@@ -59,7 +63,11 @@ def greedy_decoder(
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist()))
+ targets.append(
+ text_transform.decode(
+ [int(x) for x in labels[i][: label_lengths[i]].tolist()]
+ )
+ )
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
@@ -269,6 +277,7 @@ def train(
({100.0 * batch_idx / len(train_loader)}%)]\t \
Loss: {loss.item()}"
)
+ return loss.item()
def test(model, device, test_loader, criterion):
@@ -305,13 +314,20 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+def run(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
hparams = {
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 46,
+ "n_class": 36,
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
@@ -325,21 +341,19 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
+ download_dataset = not os.path.isdir(path)
train_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+ dataset_path, "mls_german_opus", split="dev", download=download_dataset
)
test_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False
+ dataset_path, "mls_german_opus", split="test", download=False
)
- kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
-
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
test_loader = DataLoader(
@@ -347,9 +361,12 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
+ # enable flag to find the most compatible algorithms in advance
+ if use_cuda:
+ torch.backends.cudnn.benchmark = True
+
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
hparams["n_rnn_layers"],
@@ -363,10 +380,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
-
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
-
+ if load:
+ checkpoint = torch.load(path)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ epoch = checkpoint["epoch"]
+ loss = checkpoint["loss"]
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
@@ -377,7 +398,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
- train(
+ loss = train(
model,
device,
train_loader,
@@ -387,17 +408,45 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
epoch,
iter_meter,
)
+
test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ print("saving epoch", str(epoch))
+ torch.save(
+ {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+ path + str(epoch),
+ )
@click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
+@click.option("--batch_size", default=10, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
-def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+ "--path",
+ default="model",
+ help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+ "--dataset_path",
+ default="data/",
+ help="Path for the dataset directory",
+)
+def run_cli(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
-
-if __name__ == "__main__":
- run(learning_rate=5e-4, batch_size=16, epochs=1)
+ run(
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
+ load=load,
+ path=path,
+ dataset_path=dataset_path,
+ )