-rw-r--r--  .gitignore                   |   4
-rw-r--r--  Dockerfile                   |  13
-rw-r--r--  Makefile                     |   9
-rw-r--r--  config.cluster.yaml          |  34
-rw-r--r--  config.philipp.yaml          |   2
-rw-r--r--  data/own/Philipp_HerrK.flac  | bin 0 -> 595064 bytes
-rw-r--r--  lm_decoder_hparams.ipynb     | 245
-rw-r--r--  metrics.csv                  |  69
-rw-r--r--  plots.ipynb                  | 131
-rw-r--r--  swr2_asr/train.py            |   6
-rw-r--r--  swr2_asr/utils/decoder.py    |   3
11 files changed, 456 insertions, 60 deletions
diff --git a/.gitignore b/.gitignore
index 485df5b..061bfca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,11 @@
+# pictures
+**/*.png
+
# Training files
data/*
!data/tokenizers
!data/own
+!data/metrics.csv
# Mac
**/.DS_Store
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index ca7463f..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,13 +0,0 @@
-FROM python:3.10
-
-# install python poetry
-RUN curl -sSL https://install.python-poetry.org | python3 -
-
-WORKDIR /app
-
-COPY readme.md mypy.ini poetry.lock pyproject.toml ./
-COPY swr2_asr ./swr2_asr
-ENV POETRY_VIRTUALENVS_IN_PROJECT=true
-RUN /root/.local/bin/poetry --no-interaction install --without dev
-
-ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ]
diff --git a/Makefile b/Makefile
deleted file mode 100644
index a37644c..0000000
--- a/Makefile
+++ /dev/null
@@ -1,9 +0,0 @@
-format:
- @poetry run black .
-
-format-check:
- @poetry run black --check .
-
-lint:
- @poetry run mypy --strict swr2_asr
- @poetry run pylint swr2_asr
\ No newline at end of file
diff --git a/config.cluster.yaml b/config.cluster.yaml
deleted file mode 100644
index 7af0aca..0000000
--- a/config.cluster.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-model:
-  n_cnn_layers: 3
-  n_rnn_layers: 5
-  rnn_dim: 512
-  n_feats: 128 # number of mel features
-  stride: 2
-  dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
-
-training:
-  learning_rate: 0.0005
-  batch_size: 400 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
-  epochs: 150
-  eval_every_n: 5 # evaluate every n epochs
-  num_workers: 12 # number of workers for dataloader
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-dataset:
-  download: False
-  dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir
-  language_name: "mls_german_opus"
-  limited_supervision: False # set to True if you want to use limited supervision
-  dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
-  shuffle: True
-
-tokenizer:
-  tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
-
-checkpoints:
-  model_load_path: "data/runs/epoch50" # path to load model from
-  model_save_path: "data/runs/epoch" # path to save model to
-
-inference:
-  model_load_path: ~ # path to load model from
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 608720f..7a93d05 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -18,7 +18,7 @@ tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
 
 decoder:
-  type: "lm" # greedy, or lm (beam search)
+  type: "greedy" # greedy, or lm (beam search)
 
   lm: # config for lm decoder
     language_model_path: "data" # path where model and supplementary files are stored
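
The decoder.type switch above only changes which decoder is built at inference time; the lm block stays in place for when beam search is re-enabled. Below is a minimal sketch of how the type maps to a decoder, modelled on the decoder_factory call in lm_decoder_hparams.ipynb further down; the exact config keys consumed inside swr2_asr are assumptions, not taken from this diff.

    # Sketch only: mirrors the decoder_factory call used in lm_decoder_hparams.ipynb;
    # the config keys passed here are assumed, not confirmed by this diff.
    from swr2_asr.utils.decoder import decoder_factory
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    decoder_config = {
        "type": "greedy",  # "greedy" or "lm" (beam search), as in config.philipp.yaml
        "lm": {"language_model_path": "data", "n_gram": 3},  # only read when type == "lm"
    }
    decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
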
diff --git a/data/own/Philipp_HerrK.flac b/data/own/Philipp_HerrK.flac
new file mode 100644
index 0000000..dec59e3
--- /dev/null
+++ b/data/own/Philipp_HerrK.flac
Binary files differ
diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb
new file mode 100644
index 0000000..5e56312
--- /dev/null
+++ b/lm_decoder_hparams.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lm_weights = [0, 1.0, 2.5,]\n",
+ "word_score = [-1.5, 0.0, 1.5]\n",
+ "beam_sizes = [50, 500]\n",
+ "beam_thresholds = [50]\n",
+ "beam_size_token = [10, 38]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+ " from tqdm.autonotebook import tqdm\n",
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm.autonotebook import tqdm\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from swr2_asr.utils.decoder import decoder_factory\n",
+ "from swr2_asr.utils.tokenizer import CharTokenizer\n",
+ "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n",
+ "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n",
+ "from swr2_asr.utils.loss_scores import cer, wer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "34aafd9aca2541748dc41d8550334536",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Download flag not set, skipping download\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "New best WER: 0.8266228565397248 CER: 0.6048691547202959\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7900706123452581 CER: 0.49197597466135945\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7877685082828738 CER: 0.48660732878914315\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n",
+ "\n",
+ "# manually increment tqdm progress bar\n",
+ "pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n",
+ "\n",
+ "base_config = {\n",
+ " \"language\": \"german\",\n",
+ " \"language_model_path\": \"data\", # path where model and supplementary files are stored\n",
+ " \"n_gram\": 3, # n-gram size of ,the language model, 3 or 5\n",
+ " \"beam_size\": 50 ,\n",
+ " \"beam_threshold\": 50,\n",
+ " \"n_best\": 1,\n",
+ " \"lm_weight\": 2,\n",
+ " \"word_score\": 0,\n",
+ " }\n",
+ "\n",
+ "dataset_params = {\n",
+ " \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n",
+ " \"language\": \"mls_german_opus\",\n",
+ " \"split\": Split.DEV,\n",
+ " \"limited\": True,\n",
+ " \"download\": False,\n",
+ " \"size\": 0.01,\n",
+ "}\n",
+ " \n",
+ "\n",
+ "model_params = {\n",
+ " \"n_cnn_layers\": 3,\n",
+ " \"n_rnn_layers\": 5,\n",
+ " \"rnn_dim\": 512,\n",
+ " \"n_class\": tokenizer.get_vocab_size(),\n",
+ " \"n_feats\": 128,\n",
+ " \"stride\": 2,\n",
+ " \"dropout\": 0.1,\n",
+ "}\n",
+ "\n",
+ "model = SpeechRecognitionModel(**model_params)\n",
+ "\n",
+ "checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n",
+ "\n",
+ "state_dict = {\n",
+ " k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n",
+ " for k, v in checkpoint[\"model_state_dict\"].items()\n",
+ "}\n",
+ "model.load_state_dict(state_dict, strict=True)\n",
+ "model.eval()\n",
+ "\n",
+ "\n",
+ "dataset = MLSDataset(**dataset_params,)\n",
+ "\n",
+ "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ " dataset=dataset,\n",
+ " batch_size=16,\n",
+ " shuffle = False,\n",
+ " collate_fn=data_processing,\n",
+ " num_workers=8,\n",
+ " pin_memory=True,\n",
+ ")\n",
+ "\n",
+ "best_wer = 1.0\n",
+ "best_cer = 1.0\n",
+ "best_config = None\n",
+ "\n",
+ "for lm_weight in lm_weights:\n",
+ " for ws in word_score:\n",
+ " for beam_size in beam_sizes:\n",
+ " for beam_threshold in beam_thresholds:\n",
+ " for beam_size_t in beam_size_token:\n",
+ " config = base_config.copy()\n",
+ " config[\"lm_weight\"] = lm_weight\n",
+ " config[\"word_score\"] = ws\n",
+ " config[\"beam_size\"] = beam_size\n",
+ " config[\"beam_threshold\"] = beam_threshold\n",
+ " config[\"beam_size_token\"] = beam_size_t\n",
+ " \n",
+ " decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": config})\n",
+ " \n",
+ " test_cer, test_wer = [], []\n",
+ " with torch.no_grad():\n",
+ " model.eval()\n",
+ " for batch in dataloader:\n",
+ " # perform inference, decode, compute WER and CER\n",
+ " spectrograms, labels, input_lengths, label_lengths = batch\n",
+ " \n",
+ " output = model(spectrograms)\n",
+ " output = F.log_softmax(output, dim=2)\n",
+ " \n",
+ " decoded_preds = decoder(output)\n",
+ " decoded_targets = tokenizer.decode_batch(labels)\n",
+ " \n",
+ " for j, _ in enumerate(decoded_preds):\n",
+ " if j >= len(decoded_targets):\n",
+ " break\n",
+ " pred = \" \".join(decoded_preds[j][0].words).strip()\n",
+ " target = decoded_targets[j]\n",
+ " \n",
+ " test_cer.append(cer(pred, target))\n",
+ " test_wer.append(wer(pred, target))\n",
+ "\n",
+ " avg_cer = sum(test_cer) / len(test_cer)\n",
+ " avg_wer = sum(test_wer) / len(test_wer)\n",
+ " \n",
+ " if avg_wer < best_wer:\n",
+ " best_wer = avg_wer\n",
+ " best_cer = avg_cer\n",
+ " best_config = config\n",
+ " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n",
+ " print(\"Config: \", best_config)\n",
+ " print(\"LM Weight: \", lm_weight, \n",
+ " \" Word Score: \", ws, \n",
+ " \" Beam Size: \", beam_size, \n",
+ " \" Beam Threshold: \", beam_threshold, \n",
+ " \" Beam Size Token: \", beam_size_t)\n",
+ " print(\"--------------------------------------------------------------\")\n",
+ " \n",
+ " pbar.update(1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
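
The notebook above scores every hyperparameter combination with cer and wer from swr2_asr.utils.loss_scores, whose implementation is not part of this diff. For reference, a word error rate in the same spirit is the word-level Levenshtein distance normalised by the reference length; the sketch below is illustrative only and is not the project's implementation.

    def wer_sketch(reference: str, hypothesis: str) -> float:
        """Word error rate as word-level Levenshtein distance over reference length (illustrative only)."""
        ref, hyp = reference.split(), hypothesis.split()
        # dp[i][j] = minimum edits to turn ref[:i] into hyp[:j]
        dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
        for i in range(len(ref) + 1):
            dp[i][0] = i
        for j in range(len(hyp) + 1):
            dp[0][j] = j
        for i in range(1, len(ref) + 1):
            for j in range(1, len(hyp) + 1):
                cost = 0 if ref[i - 1] == hyp[j - 1] else 1
                dp[i][j] = min(
                    dp[i - 1][j] + 1,         # deletion
                    dp[i][j - 1] + 1,         # insertion
                    dp[i - 1][j - 1] + cost,  # substitution
                )
        return dp[-1][-1] / max(len(ref), 1)

    # Example: one substituted word out of four reference words -> WER 0.25
    print(wer_sketch("das ist ein test", "das ist kein test"))  # 0.25
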
diff --git a/metrics.csv b/metrics.csv
new file mode 100644
index 0000000..22b8cec
--- /dev/null
+++ b/metrics.csv
@@ -0,0 +1,69 @@
+epoch,train_loss,test_loss,cer,wer
+0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454
+1.0,2.791025161743164,0.0,0.0,0.0
+2.0,1.5954065322875977,0.0,0.0,0.0
+3.0,1.3106564283370972,0.0,0.0,0.0
+4.0,1.206541895866394,0.0,0.0,0.0
+5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183
+6.0,1.0295032262802124,0.0,0.0,0.0
+7.0,0.957234263420105,0.0,0.0,0.0
+8.0,0.8958202004432678,0.0,0.0,0.0
+9.0,0.8403098583221436,0.0,0.0,0.0
+10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198
+11.0,0.7537956833839417,0.0,0.0,0.0
+12.0,0.7180628776550293,0.0,0.0,0.0
+13.0,0.6870554089546204,0.0,0.0,0.0
+14.0,0.6595032811164856,0.0,0.0,0.0
+15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556
+16.0,0.6134707927703857,0.0,0.0,0.0
+17.0,0.5946973562240601,0.0,0.0,0.0
+18.0,0.577201783657074,0.0,0.0,0.0
+19.0,0.5612062811851501,0.0,0.0,0.0
+20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307
+21.0,0.5190389752388,0.0,0.0,0.0
+22.0,0.5163558721542358,0.0,0.0,0.0
+23.0,0.5132778286933899,0.0,0.0,0.0
+24.0,0.5090991854667664,0.0,0.0,0.0
+25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658
+26.0,0.5023046731948853,0.0,0.0,0.0
+27.0,0.4994561970233917,0.0,0.0,0.0
+28.0,0.4942632019519806,0.0,0.0,0.0
+29.0,0.4906529486179352,0.0,0.0,0.0
+30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594
+31.0,0.4822919964790344,0.0,0.0,0.0
+32.0,0.4456436336040497,0.0,0.0,0.0
+33.0,0.4389857053756714,0.0,0.0,0.0
+34.0,0.43762147426605225,0.0,0.0,0.0
+35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124
+36.0,0.43377435207366943,0.0,0.0,0.0
+37.0,0.4318349063396454,0.0,0.0,0.0
+38.0,0.43010208010673523,0.0,0.0,0.0
+39.0,0.4276123046875,0.0,0.0,0.0
+40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734
+41.0,0.4236880838871002,0.0,0.0,0.0
+42.0,0.42077934741973877,0.0,0.0,0.0
+43.0,0.4181424081325531,0.0,0.0,0.0
+44.0,0.4154696464538574,0.0,0.0,0.0
+45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078
+46.0,0.4099026024341583,0.0,0.0,0.0
+47.0,0.4078012704849243,0.0,0.0,0.0
+48.0,0.40490180253982544,0.0,0.0,0.0
+49.0,0.4024839699268341,0.0,0.0,0.0
+50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995
+51.0,0.36624056100845337,0.0,0.0,0.0
+52.0,0.36418089270591736,0.0,0.0,0.0
+53.0,0.36366793513298035,0.0,0.0,0.0
+54.0,0.36317530274391174,0.0,0.0,0.0
+55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951
+56.0,0.36174166202545166,0.0,0.0,0.0
+57.0,0.36113062500953674,0.0,0.0,0.0
+58.0,0.36098596453666687,0.0,0.0,0.0
+59.0,0.35909315943717957,0.0,0.0,0.0
+60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114
+61.0,0.35837724804878235,0.0,0.0,0.0
+62.0,0.3567410409450531,0.0,0.0,0.0
+63.0,0.3565385341644287,0.0,0.0,0.0
+64.0,0.35535314679145813,0.0,0.0,0.0
+65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726
+66.0,0.35215333104133606,0.0,0.0,0.0
+67.0,0.35401859879493713,0.0,0.0,0.0
\ No newline at end of file
diff --git a/plots.ipynb b/plots.ipynb
new file mode 100644
index 0000000..716834a
--- /dev/null
+++ b/plots.ipynb
@@ -0,0 +1,131 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n",
+ "# test_loss, cer, wer should not be plotted if they are 0.0\n",
+ "# plot train_loss and test_loss in one plot\n",
+ "# plot cer and wer in one plot\n",
+ " \n",
+ "# save plots as png\n",
+ "\n",
+ "csv_path = \"metrics.csv\"\n",
+ "\n",
+ "# load csv\n",
+ "df = pd.read_csv(csv_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot train_loss and test_loss\n",
+ "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n",
+ "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n",
+ "\n",
+ "# create zip with epoch and test_loss for all epochs\n",
+ "# filter out all test_loss with value 0.0\n",
+ "# plot test_loss\n",
+ "epoch_loss = zip(df['epoch'], df['test_loss'])\n",
+ "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n",
+ "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for test_loss\n",
+ "for x, y in epoch_loss:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ "\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('loss')\n",
+ "plt.legend()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# set y limits to 0\n",
+ "plt.ylim(bottom=0)\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "# increase resolution\n",
+ "plt.savefig('train_test_loss.png', dpi=300)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "epoch_cer = zip(df['epoch'], df['cer'])\n",
+ "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n",
+ "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n",
+ "\n",
+ "# add markers for cer\n",
+ "for x, y in epoch_cer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "epoch_wer = zip(df['epoch'], df['wer'])\n",
+ "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n",
+ "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for wer\n",
+ "for x, y in epoch_wer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "# set y limits to 0 and 1\n",
+ "plt.ylim(bottom=0, top=1)\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('error rate')\n",
+ "plt.legend()\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# add ticks every 0.1 \n",
+ "plt.yticks([x/10 for x in range(0, 11, 1)])\n",
+ "\n",
+ "# increase resolution\n",
+ "plt.savefig('cer_wer.png', dpi=300)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 1e57ba0..5277c16 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -142,8 +142,10 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
             for j, _ in enumerate(decoded_preds):
                 if j >= len(decoded_targets):
                     break
-                test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0]))
-                test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0]))
+                pred = " ".join(decoded_preds[j][0].words).strip()  # batch, top, words
+                target = decoded_targets[j]
+                test_cer.append(cer(target, pred))
+                test_wer.append(wer(target, pred))
 
     avg_cer = sum(test_cer) / len(test_cer)
     avg_wer = sum(test_wer) / len(test_wer)
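
The change above scores the full top hypothesis instead of only its first word: decoded_preds[j][0].words is the word list of the best hypothesis for sample j, and previously only words[0] was compared against the whole target transcript. A small illustration of the difference, with invented values:

    # Hypothetical decoder output for one sample (values invented for illustration).
    words = ["heute", "ist", "ein", "guter", "tag"]  # decoded_preds[j][0].words
    target = "heute ist ein guter tag"               # decoded_targets[j]

    old_pred = words[0]                  # "heute": a single word scored against the whole sentence
    new_pred = " ".join(words).strip()   # "heute ist ein guter tag": the full hypothesis is scored
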
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index 6ffdef2..1fd002a 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -98,7 +98,8 @@ class GreedyDecoder:
         if greedy_type == "inference":
             res = self.inference(output)
-            res = [[DecoderOutput(words=res)]]
+            res = [x.split(" ") for x in res]
+            res = [[DecoderOutput(x)] for x in res]
 
         return res
 
     def train(self, output, labels, label_lengths):
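
With this change the greedy decoder returns the same structure as the LM beam-search decoder: one list of hypotheses per sample, each wrapped in DecoderOutput with a words attribute, so callers such as train.py can index decoded_preds[sample][n_best].words regardless of decoder type. A sketch of the resulting shape, with invented batch contents; the import path and the assumption that words is DecoderOutput's first field come from how the class is used in this diff, not from its definition.

    # DecoderOutput as used in swr2_asr/utils/decoder.py (import path assumed for this sketch).
    from swr2_asr.utils.decoder import DecoderOutput

    # res as returned by self.inference(output): one decoded string per sample (invented values).
    res = ["das ist ein test", "hallo welt"]
    res = [x.split(" ") for x in res]        # [["das", "ist", "ein", "test"], ["hallo", "welt"]]
    res = [[DecoderOutput(x)] for x in res]  # one single-hypothesis list per sample

    # Callers now use the same access pattern for both decoder types:
    pred = " ".join(res[0][0].words).strip()  # "das ist ein test"
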