From c09ff76ba6f4c5dd5de64a401efcd27449150aec Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 15:05:21 +0200 Subject: added support for lm decoder during training --- swr2_asr/train.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..1e57ba0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +184,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +260,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +288,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: -- cgit v1.2.3 From d44cf7b1cab683a8aa3876619c82226f4e6d6f3b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 18:11:33 +0200 Subject: fix --- .gitignore | 4 + Dockerfile | 13 --- Makefile | 9 -- config.cluster.yaml | 34 ------ config.philipp.yaml | 2 +- data/own/Philipp_HerrK.flac | Bin 0 -> 595064 bytes lm_decoder_hparams.ipynb | 245 ++++++++++++++++++++++++++++++++++++++++++++ metrics.csv | 69 +++++++++++++ plots.ipynb | 131 +++++++++++++++++++++++ swr2_asr/train.py | 6 +- swr2_asr/utils/decoder.py | 3 +- 11 files changed, 456 insertions(+), 60 deletions(-) delete mode 100644 Dockerfile delete mode 100644 Makefile delete mode 100644 config.cluster.yaml create mode 100644 data/own/Philipp_HerrK.flac create mode 100644 lm_decoder_hparams.ipynb create mode 100644 metrics.csv create mode 100644 plots.ipynb (limited to 'swr2_asr/train.py') diff --git a/.gitignore b/.gitignore index 485df5b..061bfca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,11 @@ +# pictures +**/*.png + # Training files data/* !data/tokenizers !data/own +!data/metrics.csv # Mac **/.DS_Store diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index ca7463f..0000000 --- a/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM python:3.10 - -# install python poetry -RUN curl -sSL https://install.python-poetry.org | python3 - - -WORKDIR /app - -COPY readme.md mypy.ini poetry.lock pyproject.toml ./ -COPY swr2_asr ./swr2_asr -ENV POETRY_VIRTUALENVS_IN_PROJECT=true -RUN /root/.local/bin/poetry --no-interaction install --without dev - -ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ] diff --git a/Makefile b/Makefile deleted file mode 100644 index a37644c..0000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -format: - @poetry run black . - -format-check: - @poetry run black --check . - -lint: - @poetry run mypy --strict swr2_asr - @poetry run pylint swr2_asr \ No newline at end of file diff --git a/config.cluster.yaml b/config.cluster.yaml deleted file mode 100644 index 7af0aca..0000000 --- a/config.cluster.yaml +++ /dev/null @@ -1,34 +0,0 @@ -model: - n_cnn_layers: 3 - n_rnn_layers: 5 - rnn_dim: 512 - n_feats: 128 # number of mel features - stride: 2 - dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 0.0005 - batch_size: 400 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 150 - eval_every_n: 5 # evaluate every n epochs - num_workers: 12 # number of workers for dataloader - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically - -dataset: - download: False - dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - shuffle: True - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.json" - -checkpoints: - model_load_path: "data/runs/epoch50" # path to load model from - model_save_path: "data/runs/epoch" # path to save model to - -inference: - model_load_path: ~ # path to load model from - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically diff --git a/config.philipp.yaml b/config.philipp.yaml index 608720f..7a93d05 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -18,7 +18,7 @@ tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" decoder: - type: "lm" # greedy, or lm (beam search) + type: "greedy" # greedy, or lm (beam search) lm: # config for lm decoder language_model_path: "data" # path where model and supplementary files are stored diff --git a/data/own/Philipp_HerrK.flac b/data/own/Philipp_HerrK.flac new file mode 100644 index 0000000..dec59e3 Binary files /dev/null and b/data/own/Philipp_HerrK.flac differ diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb new file mode 100644 index 0000000..5e56312 --- /dev/null +++ b/lm_decoder_hparams.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "lm_weights = [0, 1.0, 2.5,]\n", + "word_score = [-1.5, 0.0, 1.5]\n", + "beam_sizes = [50, 500]\n", + "beam_thresholds = [50]\n", + "beam_size_token = [10, 38]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from tqdm.autonotebook import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn.functional as F\n", + "\n", + "from swr2_asr.utils.decoder import decoder_factory\n", + "from swr2_asr.utils.tokenizer import CharTokenizer\n", + "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", + "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", + "from swr2_asr.utils.loss_scores import cer, wer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34aafd9aca2541748dc41d8550334536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00= len(decoded_targets):\n", + " break\n", + " pred = \" \".join(decoded_preds[j][0].words).strip()\n", + " target = decoded_targets[j]\n", + " \n", + " test_cer.append(cer(pred, target))\n", + " test_wer.append(wer(pred, target))\n", + "\n", + " avg_cer = sum(test_cer) / len(test_cer)\n", + " avg_wer = sum(test_wer) / len(test_wer)\n", + " \n", + " if avg_wer < best_wer:\n", + " best_wer = avg_wer\n", + " best_cer = avg_cer\n", + " best_config = config\n", + " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", + " print(\"Config: \", best_config)\n", + " print(\"LM Weight: \", lm_weight, \n", + " \" Word Score: \", ws, \n", + " \" Beam Size: \", beam_size, \n", + " \" Beam Threshold: \", beam_threshold, \n", + " \" Beam Size Token: \", beam_size_t)\n", + " print(\"--------------------------------------------------------------\")\n", + " \n", + " pbar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/metrics.csv b/metrics.csv new file mode 100644 index 0000000..22b8cec --- /dev/null +++ b/metrics.csv @@ -0,0 +1,69 @@ +epoch,train_loss,test_loss,cer,wer +0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454 +1.0,2.791025161743164,0.0,0.0,0.0 +2.0,1.5954065322875977,0.0,0.0,0.0 +3.0,1.3106564283370972,0.0,0.0,0.0 +4.0,1.206541895866394,0.0,0.0,0.0 +5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183 +6.0,1.0295032262802124,0.0,0.0,0.0 +7.0,0.957234263420105,0.0,0.0,0.0 +8.0,0.8958202004432678,0.0,0.0,0.0 +9.0,0.8403098583221436,0.0,0.0,0.0 +10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198 +11.0,0.7537956833839417,0.0,0.0,0.0 +12.0,0.7180628776550293,0.0,0.0,0.0 +13.0,0.6870554089546204,0.0,0.0,0.0 +14.0,0.6595032811164856,0.0,0.0,0.0 +15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556 +16.0,0.6134707927703857,0.0,0.0,0.0 +17.0,0.5946973562240601,0.0,0.0,0.0 +18.0,0.577201783657074,0.0,0.0,0.0 +19.0,0.5612062811851501,0.0,0.0,0.0 +20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307 +21.0,0.5190389752388,0.0,0.0,0.0 +22.0,0.5163558721542358,0.0,0.0,0.0 +23.0,0.5132778286933899,0.0,0.0,0.0 +24.0,0.5090991854667664,0.0,0.0,0.0 +25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658 +26.0,0.5023046731948853,0.0,0.0,0.0 +27.0,0.4994561970233917,0.0,0.0,0.0 +28.0,0.4942632019519806,0.0,0.0,0.0 +29.0,0.4906529486179352,0.0,0.0,0.0 +30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594 +31.0,0.4822919964790344,0.0,0.0,0.0 +32.0,0.4456436336040497,0.0,0.0,0.0 +33.0,0.4389857053756714,0.0,0.0,0.0 +34.0,0.43762147426605225,0.0,0.0,0.0 +35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124 +36.0,0.43377435207366943,0.0,0.0,0.0 +37.0,0.4318349063396454,0.0,0.0,0.0 +38.0,0.43010208010673523,0.0,0.0,0.0 +39.0,0.4276123046875,0.0,0.0,0.0 +40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734 +41.0,0.4236880838871002,0.0,0.0,0.0 +42.0,0.42077934741973877,0.0,0.0,0.0 +43.0,0.4181424081325531,0.0,0.0,0.0 +44.0,0.4154696464538574,0.0,0.0,0.0 +45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078 +46.0,0.4099026024341583,0.0,0.0,0.0 +47.0,0.4078012704849243,0.0,0.0,0.0 +48.0,0.40490180253982544,0.0,0.0,0.0 +49.0,0.4024839699268341,0.0,0.0,0.0 +50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995 +51.0,0.36624056100845337,0.0,0.0,0.0 +52.0,0.36418089270591736,0.0,0.0,0.0 +53.0,0.36366793513298035,0.0,0.0,0.0 +54.0,0.36317530274391174,0.0,0.0,0.0 +55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951 +56.0,0.36174166202545166,0.0,0.0,0.0 +57.0,0.36113062500953674,0.0,0.0,0.0 +58.0,0.36098596453666687,0.0,0.0,0.0 +59.0,0.35909315943717957,0.0,0.0,0.0 +60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114 +61.0,0.35837724804878235,0.0,0.0,0.0 +62.0,0.3567410409450531,0.0,0.0,0.0 +63.0,0.3565385341644287,0.0,0.0,0.0 +64.0,0.35535314679145813,0.0,0.0,0.0 +65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726 +66.0,0.35215333104133606,0.0,0.0,0.0 +67.0,0.35401859879493713,0.0,0.0,0.0 \ No newline at end of file diff --git a/plots.ipynb b/plots.ipynb new file mode 100644 index 0000000..716834a --- /dev/null +++ b/plots.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n", + "# test_loss, cer, wer should not be plotted if they are 0.0\n", + "# plot train_loss and test_loss in one plot\n", + "# plot cer and wer in one plot\n", + " \n", + "# save plots as png\n", + "\n", + "csv_path = \"metrics.csv\"\n", + "\n", + "# load csv\n", + "df = pd.read_csv(csv_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot train_loss and test_loss\n", + "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n", + "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n", + "\n", + "# create zip with epoch and test_loss for all epochs\n", + "# filter out all test_loss with value 0.0\n", + "# plot test_loss\n", + "epoch_loss = zip(df['epoch'], df['test_loss'])\n", + "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n", + "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n", + "\n", + "# add markers for test_loss\n", + "for x, y in epoch_loss:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + "\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('loss')\n", + "plt.legend()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# set y limits to 0\n", + "plt.ylim(bottom=0)\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "# increase resolution\n", + "plt.savefig('train_test_loss.png', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epoch_cer = zip(df['epoch'], df['cer'])\n", + "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n", + "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n", + "\n", + "# add markers for cer\n", + "for x, y in epoch_cer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "epoch_wer = zip(df['epoch'], df['wer'])\n", + "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n", + "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n", + "\n", + "# add markers for wer\n", + "for x, y in epoch_wer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "# set y limits to 0 and 1\n", + "plt.ylim(bottom=0, top=1)\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('error rate')\n", + "plt.legend()\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# add ticks every 0.1 \n", + "plt.yticks([x/10 for x in range(0, 11, 1)])\n", + "\n", + "# increase resolution\n", + "plt.savefig('cer_wer.png', dpi=300)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 1e57ba0..5277c16 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -142,8 +142,10 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: for j, _ in enumerate(decoded_preds): if j >= len(decoded_targets): break - test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) + pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words + target = decoded_targets[j] + test_cer.append(cer(target, pred)) + test_wer.append(wer(target, pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 6ffdef2..1fd002a 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -98,7 +98,8 @@ class GreedyDecoder: if greedy_type == "inference": res = self.inference(output) - res = [[DecoderOutput(words=res)]] + res = [x.split(" ") for x in res] + res = [[DecoderOutput(x)] for x in res] return res def train(self, output, labels, label_lengths): -- cgit v1.2.3