diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | Dockerfile | 13 | ||||
-rw-r--r-- | Makefile | 9 | ||||
-rw-r--r-- | config.cluster.yaml | 34 | ||||
-rw-r--r-- | config.philipp.yaml | 2 | ||||
-rw-r--r-- | data/own/Philipp_HerrK.flac | bin | 0 -> 595064 bytes | |||
-rw-r--r-- | lm_decoder_hparams.ipynb | 245 | ||||
-rw-r--r-- | metrics.csv | 69 | ||||
-rw-r--r-- | plots.ipynb | 131 | ||||
-rw-r--r-- | swr2_asr/train.py | 6 | ||||
-rw-r--r-- | swr2_asr/utils/decoder.py | 3 |
11 files changed, 456 insertions, 60 deletions
@@ -1,7 +1,11 @@ +# pictures +**/*.png + # Training files data/* !data/tokenizers !data/own +!data/metrics.csv # Mac **/.DS_Store diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index ca7463f..0000000 --- a/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM python:3.10 - -# install python poetry -RUN curl -sSL https://install.python-poetry.org | python3 - - -WORKDIR /app - -COPY readme.md mypy.ini poetry.lock pyproject.toml ./ -COPY swr2_asr ./swr2_asr -ENV POETRY_VIRTUALENVS_IN_PROJECT=true -RUN /root/.local/bin/poetry --no-interaction install --without dev - -ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ] diff --git a/Makefile b/Makefile deleted file mode 100644 index a37644c..0000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -format: - @poetry run black . - -format-check: - @poetry run black --check . - -lint: - @poetry run mypy --strict swr2_asr - @poetry run pylint swr2_asr
\ No newline at end of file diff --git a/config.cluster.yaml b/config.cluster.yaml deleted file mode 100644 index 7af0aca..0000000 --- a/config.cluster.yaml +++ /dev/null @@ -1,34 +0,0 @@ -model: - n_cnn_layers: 3 - n_rnn_layers: 5 - rnn_dim: 512 - n_feats: 128 # number of mel features - stride: 2 - dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 0.0005 - batch_size: 400 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 150 - eval_every_n: 5 # evaluate every n epochs - num_workers: 12 # number of workers for dataloader - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically - -dataset: - download: False - dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - shuffle: True - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.json" - -checkpoints: - model_load_path: "data/runs/epoch50" # path to load model from - model_save_path: "data/runs/epoch" # path to save model to - -inference: - model_load_path: ~ # path to load model from - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically diff --git a/config.philipp.yaml b/config.philipp.yaml index 608720f..7a93d05 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -18,7 +18,7 @@ tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" decoder: - type: "lm" # greedy, or lm (beam search) + type: "greedy" # greedy, or lm (beam search) lm: # config for lm decoder language_model_path: "data" # path where model and supplementary files are stored diff --git a/data/own/Philipp_HerrK.flac b/data/own/Philipp_HerrK.flac Binary files 
differnew file mode 100644 index 0000000..dec59e3 --- /dev/null +++ b/data/own/Philipp_HerrK.flac diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb new file mode 100644 index 0000000..5e56312 --- /dev/null +++ b/lm_decoder_hparams.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "lm_weights = [0, 1.0, 2.5,]\n", + "word_score = [-1.5, 0.0, 1.5]\n", + "beam_sizes = [50, 500]\n", + "beam_thresholds = [50]\n", + "beam_size_token = [10, 38]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. 
https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from tqdm.autonotebook import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn.functional as F\n", + "\n", + "from swr2_asr.utils.decoder import decoder_factory\n", + "from swr2_asr.utils.tokenizer import CharTokenizer\n", + "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", + "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", + "from swr2_asr.utils.loss_scores import cer, wer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34aafd9aca2541748dc41d8550334536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Download flag not set, skipping download\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. 
Or, the value for `n_freqs` (201) may be set too low.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best WER: 0.8266228565397248 CER: 0.6048691547202959\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n", + "--------------------------------------------------------------\n", + "New best WER: 0.7900706123452581 CER: 0.49197597466135945\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n", + "--------------------------------------------------------------\n", + "New best WER: 0.7877685082828738 CER: 0.48660732878914315\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n", + "--------------------------------------------------------------\n" + ] + } + ], + "source": [ + "\n", + "\n", + "tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n", + "\n", + "# manually increment tqdm progress bar\n", + "pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n", + "\n", + "base_config = {\n", + " \"language\": \"german\",\n", + " \"language_model_path\": \"data\", # path where model and supplementary files are stored\n", + " \"n_gram\": 3, # n-gram size of the language model, 3 or 5\n", + " \"beam_size\": 50 ,\n", + " \"beam_threshold\": 
50,\n", + " \"n_best\": 1,\n", + " \"lm_weight\": 2,\n", + " \"word_score\": 0,\n", + " }\n", + "\n", + "dataset_params = {\n", + " \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n", + " \"language\": \"mls_german_opus\",\n", + " \"split\": Split.DEV,\n", + " \"limited\": True,\n", + " \"download\": False,\n", + " \"size\": 0.01,\n", + "}\n", + " \n", + "\n", + "model_params = {\n", + " \"n_cnn_layers\": 3,\n", + " \"n_rnn_layers\": 5,\n", + " \"rnn_dim\": 512,\n", + " \"n_class\": tokenizer.get_vocab_size(),\n", + " \"n_feats\": 128,\n", + " \"stride\": 2,\n", + " \"dropout\": 0.1,\n", + "}\n", + "\n", + "model = SpeechRecognitionModel(**model_params)\n", + "\n", + "checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n", + "\n", + "state_dict = {\n", + " k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n", + " for k, v in checkpoint[\"model_state_dict\"].items()\n", + "}\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model.eval()\n", + "\n", + "\n", + "dataset = MLSDataset(**dataset_params,)\n", + "\n", + "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n", + "\n", + "dataloader = DataLoader(\n", + " dataset=dataset,\n", + " batch_size=16,\n", + " shuffle = False,\n", + " collate_fn=data_processing,\n", + " num_workers=8,\n", + " pin_memory=True,\n", + ")\n", + "\n", + "best_wer = 1.0\n", + "best_cer = 1.0\n", + "best_config = None\n", + "\n", + "for lm_weight in lm_weights:\n", + " for ws in word_score:\n", + " for beam_size in beam_sizes:\n", + " for beam_threshold in beam_thresholds:\n", + " for beam_size_t in beam_size_token:\n", + " config = base_config.copy()\n", + " config[\"lm_weight\"] = lm_weight\n", + " config[\"word_score\"] = ws\n", + " config[\"beam_size\"] = beam_size\n", + " config[\"beam_threshold\"] = beam_threshold\n", + " config[\"beam_size_token\"] = beam_size_t\n", + " \n", + " decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": 
config})\n", + " \n", + " test_cer, test_wer = [], []\n", + " with torch.no_grad():\n", + " model.eval()\n", + " for batch in dataloader:\n", + " # perform inference, decode, compute WER and CER\n", + " spectrograms, labels, input_lengths, label_lengths = batch\n", + " \n", + " output = model(spectrograms)\n", + " output = F.log_softmax(output, dim=2)\n", + " \n", + " decoded_preds = decoder(output)\n", + " decoded_targets = tokenizer.decode_batch(labels)\n", + " \n", + " for j, _ in enumerate(decoded_preds):\n", + " if j >= len(decoded_targets):\n", + " break\n", + " pred = \" \".join(decoded_preds[j][0].words).strip()\n", + " target = decoded_targets[j]\n", + " \n", + " test_cer.append(cer(pred, target))\n", + " test_wer.append(wer(pred, target))\n", + "\n", + " avg_cer = sum(test_cer) / len(test_cer)\n", + " avg_wer = sum(test_wer) / len(test_wer)\n", + " \n", + " if avg_wer < best_wer:\n", + " best_wer = avg_wer\n", + " best_cer = avg_cer\n", + " best_config = config\n", + " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", + " print(\"Config: \", best_config)\n", + " print(\"LM Weight: \", lm_weight, \n", + " \" Word Score: \", ws, \n", + " \" Beam Size: \", beam_size, \n", + " \" Beam Threshold: \", beam_threshold, \n", + " \" Beam Size Token: \", beam_size_t)\n", + " print(\"--------------------------------------------------------------\")\n", + " \n", + " pbar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/metrics.csv b/metrics.csv new file mode 100644 index 0000000..22b8cec --- /dev/null +++ b/metrics.csv @@ -0,0 +1,69 @@ 
+epoch,train_loss,test_loss,cer,wer +0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454 +1.0,2.791025161743164,0.0,0.0,0.0 +2.0,1.5954065322875977,0.0,0.0,0.0 +3.0,1.3106564283370972,0.0,0.0,0.0 +4.0,1.206541895866394,0.0,0.0,0.0 +5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183 +6.0,1.0295032262802124,0.0,0.0,0.0 +7.0,0.957234263420105,0.0,0.0,0.0 +8.0,0.8958202004432678,0.0,0.0,0.0 +9.0,0.8403098583221436,0.0,0.0,0.0 +10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198 +11.0,0.7537956833839417,0.0,0.0,0.0 +12.0,0.7180628776550293,0.0,0.0,0.0 +13.0,0.6870554089546204,0.0,0.0,0.0 +14.0,0.6595032811164856,0.0,0.0,0.0 +15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556 +16.0,0.6134707927703857,0.0,0.0,0.0 +17.0,0.5946973562240601,0.0,0.0,0.0 +18.0,0.577201783657074,0.0,0.0,0.0 +19.0,0.5612062811851501,0.0,0.0,0.0 +20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307 +21.0,0.5190389752388,0.0,0.0,0.0 +22.0,0.5163558721542358,0.0,0.0,0.0 +23.0,0.5132778286933899,0.0,0.0,0.0 +24.0,0.5090991854667664,0.0,0.0,0.0 +25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658 +26.0,0.5023046731948853,0.0,0.0,0.0 +27.0,0.4994561970233917,0.0,0.0,0.0 +28.0,0.4942632019519806,0.0,0.0,0.0 +29.0,0.4906529486179352,0.0,0.0,0.0 +30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594 +31.0,0.4822919964790344,0.0,0.0,0.0 +32.0,0.4456436336040497,0.0,0.0,0.0 +33.0,0.4389857053756714,0.0,0.0,0.0 +34.0,0.43762147426605225,0.0,0.0,0.0 +35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124 +36.0,0.43377435207366943,0.0,0.0,0.0 +37.0,0.4318349063396454,0.0,0.0,0.0 +38.0,0.43010208010673523,0.0,0.0,0.0 +39.0,0.4276123046875,0.0,0.0,0.0 +40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734 +41.0,0.4236880838871002,0.0,0.0,0.0 
+42.0,0.42077934741973877,0.0,0.0,0.0 +43.0,0.4181424081325531,0.0,0.0,0.0 +44.0,0.4154696464538574,0.0,0.0,0.0 +45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078 +46.0,0.4099026024341583,0.0,0.0,0.0 +47.0,0.4078012704849243,0.0,0.0,0.0 +48.0,0.40490180253982544,0.0,0.0,0.0 +49.0,0.4024839699268341,0.0,0.0,0.0 +50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995 +51.0,0.36624056100845337,0.0,0.0,0.0 +52.0,0.36418089270591736,0.0,0.0,0.0 +53.0,0.36366793513298035,0.0,0.0,0.0 +54.0,0.36317530274391174,0.0,0.0,0.0 +55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951 +56.0,0.36174166202545166,0.0,0.0,0.0 +57.0,0.36113062500953674,0.0,0.0,0.0 +58.0,0.36098596453666687,0.0,0.0,0.0 +59.0,0.35909315943717957,0.0,0.0,0.0 +60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114 +61.0,0.35837724804878235,0.0,0.0,0.0 +62.0,0.3567410409450531,0.0,0.0,0.0 +63.0,0.3565385341644287,0.0,0.0,0.0 +64.0,0.35535314679145813,0.0,0.0,0.0 +65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726 +66.0,0.35215333104133606,0.0,0.0,0.0 +67.0,0.35401859879493713,0.0,0.0,0.0
\ No newline at end of file diff --git a/plots.ipynb b/plots.ipynb new file mode 100644 index 0000000..716834a --- /dev/null +++ b/plots.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load csv with columns epoch, train_loss, test_loss, cer, wer\n", + "# test_loss, cer, wer should not be plotted if they are 0.0\n", + "# plot train_loss and test_loss in one plot\n", + "# plot cer and wer in one plot\n", + " \n", + "# save plots as png\n", + "\n", + "csv_path = \"metrics.csv\"\n", + "\n", + "# load csv\n", + "df = pd.read_csv(csv_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot train_loss and test_loss\n", + "# do not use colors, distinguish by line style. 
use solid for train_loss and dashed for test_loss\n", + "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n", + "\n", + "# create zip with epoch and test_loss for all epochs\n", + "# filter out all test_loss with value 0.0\n", + "# plot test_loss\n", + "epoch_loss = zip(df['epoch'], df['test_loss'])\n", + "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n", + "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n", + "\n", + "# add markers for test_loss\n", + "for x, y in epoch_loss:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + "\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('loss')\n", + "plt.legend()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# set y limits to 0\n", + "plt.ylim(bottom=0)\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "# increase resolution\n", + "plt.savefig('train_test_loss.png', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epoch_cer = zip(df['epoch'], df['cer'])\n", + "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n", + "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n", + "\n", + "# add markers for cer\n", + "for x, y in epoch_cer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "epoch_wer = zip(df['epoch'], df['wer'])\n", + "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n", + "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n", + "\n", + "# add markers for wer\n", + "for x, y in epoch_wer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "# set y limits to 0 and 1\n", + "plt.ylim(bottom=0, top=1)\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('error 
rate')\n", + "plt.legend()\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# add ticks every 0.1 \n", + "plt.yticks([x/10 for x in range(0, 11, 1)])\n", + "\n", + "# increase resolution\n", + "plt.savefig('cer_wer.png', dpi=300)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 1e57ba0..5277c16 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -142,8 +142,10 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: for j, _ in enumerate(decoded_preds): if j >= len(decoded_targets): break - test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) + pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words + target = decoded_targets[j] + test_cer.append(cer(target, pred)) + test_wer.append(wer(target, pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 6ffdef2..1fd002a 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -98,7 +98,8 @@ class GreedyDecoder: if greedy_type == "inference": res = self.inference(output) - res = [[DecoderOutput(words=res)]] + res = [x.split(" ") for x in res] + res = [[DecoderOutput(x)] for x in res] return res def train(self, output, labels, label_lengths): |