diff options
-rwxr-xr-x | hpc.sh | 19 | ||||
-rwxr-xr-x | hpc_train.sh | 3 | ||||
-rw-r--r-- | pyproject.toml | 3 | ||||
-rw-r--r-- | swr2_asr/tokenizer.py | 41 | ||||
-rw-r--r-- | swr2_asr/train.py | 89 |
5 files changed, 120 insertions, 35 deletions
@@ -0,0 +1,19 @@ +#!/bin/bash + +#SBATCH --job-name=swr-teamprojekt +#SBATCH --partition=a100 +#SBATCH --time=00:30:00 + +### Note: --gres=gpu:x should equal to ntasks-per-node +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:a100:1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64gb +#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/ +#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out + +source venv/bin/activate + +### the command to run +srun ./hpc_train.sh diff --git a/hpc_train.sh b/hpc_train.sh new file mode 100755 index 0000000..c7d1636 --- /dev/null +++ b/hpc_train.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +yes no | python -m swr2_asr.train --epochs=100 --batch_size=30 --dataset_path=/mnt/lustre/mladm/mfa252/data diff --git a/pyproject.toml b/pyproject.toml index b7e6ffb..eb17479 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ readme = "readme.md" packages = [{include = "swr2_asr"}] [tool.poetry.dependencies] -python = "~3.10" +python = "^3.10" torch = "2.0.0" torchaudio = "2.0.1" audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"} @@ -16,6 +16,7 @@ tqdm = "^4.66.1" numpy = "^1.25.2" mido = "^1.3.0" tokenizers = "^0.13.3" +click = "^8.1.7" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index d32e60d..d9cd622 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,4 +1,6 @@ """Tokenizer for use with Multilingual Librispeech""" +from dataclasses import dataclass +import json import os import click from tqdm import tqdm @@ -11,6 +13,13 @@ from tokenizers.trainers import BpeTrainer from tokenizers.pre_tokenizers import Whitespace +@dataclass +class Encoding: + """Simple dataclass to represent an encoding""" + + ids: list[int] + + class CharTokenizer: """Very simple tokenizer for use with Multilingual Librispeech @@ -98,7 +107,7 @@ class CharTokenizer: else: mapped_char = self.char_map[char] int_sequence.append(mapped_char) - return int_sequence + return Encoding(ids=int_sequence) def decode(self, labels: list[int], remove_special_tokens: bool = True): """Use a character map and convert integer labels to an text sequence @@ -110,11 +119,11 @@ class CharTokenizer: """ string = [] for i in labels: - if remove_special_tokens and self.index_map[i] == "<UNK>": + if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>": continue - if remove_special_tokens and self.index_map[i] == "<SPACE>": + if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>": string.append(" ") - string.append(self.index_map[i]) + string.append(self.index_map[f"{i}"]) return "".join(string).replace("<SPACE>", " ") def decode_batch(self, labels: list[list[int]]): @@ -134,16 +143,22 @@ class CharTokenizer: def save(self, path: str): """Save the tokenizer to a file""" with open(path, "w", encoding="utf-8") as file: - for char, index in self.char_map.items(): - file.write(f"{char} {index}\n") + # save it in the following format: + # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} + json.dump( + {"char_map": self.char_map, "index_map": self.index_map}, + file, + ensure_ascii=False, + ) def from_file(self, path: str): """Load the tokenizer from a file""" with open(path, "r", encoding="utf-8") as file: - for line in file.readlines(): - char, index = line.split(" ") - self.char_map[char] = int(index) - self.index_map[int(index)] = char + # load it in the following format: + # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} + saved_file = json.load(file) + self.char_map = saved_file["char_map"] + self.index_map = saved_file["index_map"] @click.command() @@ -303,8 +318,6 @@ def train_char_tokenizer( if __name__ == "__main__": tokenizer = CharTokenizer() - tokenizer.train("/Volumes/pherkel 2/SWR2-ASR", "mls_german_opus", "all") - - print(tokenizer.decode(tokenizer.encode("Fichier non trouvé"))) + tokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - tokenizer.save("tokenizer_chars.txt") + print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids)) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 2628028..d13683f 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,4 +1,5 @@ """Training script for the ASR model.""" +import os import click import torch import torch.nn.functional as F @@ -7,6 +8,7 @@ from AudioLoader.speech import MultilingualLibriSpeech from torch import nn, optim from torch.utils.data import DataLoader from tokenizers import Tokenizer +from .tokenizer import CharTokenizer from .loss_scores import cer, wer @@ -18,7 +20,9 @@ train_audio_transforms = nn.Sequential( valid_audio_transforms = torchaudio.transforms.MelSpectrogram() -text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") +# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") +text_transform = CharTokenizer() +text_transform.from_file("data/tokenizers/char_tokenizer_german.json") def data_processing(data, data_type="train"): @@ -59,7 +63,11 @@ def greedy_decoder( targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist())) + targets.append( + text_transform.decode( + [int(x) for x in labels[i][: label_lengths[i]].tolist()] + ) + ) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: @@ -269,6 +277,7 @@ def train( ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" ) + return loss.item() def test(model, device, test_loader, criterion): @@ -305,13 +314,20 @@ def test(model, device, test_loader, criterion): ) -def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: +def run( + learning_rate: float, + batch_size: int, + epochs: int, + load: bool, + path: str, + dataset_path: str, +) -> None: """Runs the training script.""" hparams = { "n_cnn_layers": 3, "n_rnn_layers": 5, "rnn_dim": 512, - "n_class": 46, + "n_class": 36, "n_feats": 128, "stride": 2, "dropout": 0.1, @@ -325,21 +341,19 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") + download_dataset = not os.path.isdir(path) train_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False + dataset_path, "mls_german_opus", split="dev", download=download_dataset ) test_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False + dataset_path, "mls_german_opus", split="test", download=False ) - kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} - train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) test_loader = DataLoader( @@ -347,9 +361,12 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) + # enable flag to find the most compatible algorithms in advance + if use_cuda: + torch.backends.cudnn.benchmark = True + model = SpeechRecognitionModel( hparams["n_cnn_layers"], hparams["n_rnn_layers"], @@ -363,10 +380,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No print( "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) ) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) - + if load: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + epoch = checkpoint["epoch"] + loss = checkpoint["loss"] scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], @@ -377,7 +398,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No iter_meter = IterMeter() for epoch in range(1, epochs + 1): - train( + loss = train( model, device, train_loader, @@ -387,17 +408,45 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No epoch, iter_meter, ) + test(model=model, device=device, test_loader=test_loader, criterion=criterion) + print("saving epoch", str(epoch)) + torch.save( + {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, + path + str(epoch), + ) @click.command() -@click.option("--learning-rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=1, help="Batch size") +@click.option("--learning_rate", default=1e-3, help="Learning rate") +@click.option("--batch_size", default=10, help="Batch size") @click.option("--epochs", default=1, help="Number of epochs") -def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None: +@click.option("--load", default=False, help="Do you want to load a model?") +@click.option( + "--path", + default="model", + help="Path where the model will be saved to/loaded from", +) +@click.option( + "--dataset_path", + default="data/", + help="Path for the dataset directory", +) +def run_cli( + learning_rate: float, + batch_size: int, + epochs: int, + load: bool, + path: str, + dataset_path: str, +) -> None: """Runs the training script.""" - run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs) - -if __name__ == "__main__": - run(learning_rate=5e-4, batch_size=16, epochs=1) + run( + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, + load=load, + path=path, + dataset_path=dataset_path, + ) |