aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore7
-rw-r--r--Dockerfile13
-rw-r--r--Makefile9
-rw-r--r--config.cluster.yaml34
-rw-r--r--config.philipp.yaml55
-rw-r--r--config.yaml42
-rw-r--r--data/own/Philipp_HerrK.flacbin0 -> 595064 bytes
-rw-r--r--data/tokenizers/tokens_german.txt38
-rw-r--r--lm_decoder_hparams.ipynb245
-rw-r--r--metrics.csv69
-rw-r--r--plots.ipynb131
-rw-r--r--poetry.lock768
-rw-r--r--pyproject.toml7
-rw-r--r--swr2_asr/inference.py35
-rw-r--r--swr2_asr/train.py34
-rw-r--r--swr2_asr/utils/data.py20
-rw-r--r--swr2_asr/utils/decoder.py169
-rw-r--r--swr2_asr/utils/tokenizer.py14
18 files changed, 1535 insertions, 155 deletions
diff --git a/.gitignore b/.gitignore
index 6e14acc..061bfca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,11 @@
+# pictures
+**/*.png
+
# Training files
data/*
!data/tokenizers
!data/own
+!data/metrics.csv
# Mac
**/.DS_Store
@@ -67,8 +71,7 @@ cover/
*.mo
*.pot
-#Model
-YOUR
+
# Django stuff:
*.log
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index ca7463f..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,13 +0,0 @@
-FROM python:3.10
-
-# install python poetry
-RUN curl -sSL https://install.python-poetry.org | python3 -
-
-WORKDIR /app
-
-COPY readme.md mypy.ini poetry.lock pyproject.toml ./
-COPY swr2_asr ./swr2_asr
-ENV POETRY_VIRTUALENVS_IN_PROJECT=true
-RUN /root/.local/bin/poetry --no-interaction install --without dev
-
-ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ]
diff --git a/Makefile b/Makefile
deleted file mode 100644
index a37644c..0000000
--- a/Makefile
+++ /dev/null
@@ -1,9 +0,0 @@
-format:
- @poetry run black .
-
-format-check:
- @poetry run black --check .
-
-lint:
- @poetry run mypy --strict swr2_asr
- @poetry run pylint swr2_asr \ No newline at end of file
diff --git a/config.cluster.yaml b/config.cluster.yaml
deleted file mode 100644
index 7af0aca..0000000
--- a/config.cluster.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-model:
- n_cnn_layers: 3
- n_rnn_layers: 5
- rnn_dim: 512
- n_feats: 128 # number of mel features
- stride: 2
- dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
-
-training:
- learning_rate: 0.0005
- batch_size: 400 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
- epochs: 150
- eval_every_n: 5 # evaluate every n epochs
- num_workers: 12 # number of workers for dataloader
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-dataset:
- download: False
- dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir
- language_name: "mls_german_opus"
- limited_supervision: False # set to True if you want to use limited supervision
- dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
- shuffle: True
-
-tokenizer:
- tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
-
-checkpoints:
- model_load_path: "data/runs/epoch50" # path to load model from
- model_save_path: "data/runs/epoch" # path to save model to
-
-inference:
- model_load_path: ~ # path to load model from
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
diff --git a/config.philipp.yaml b/config.philipp.yaml
index f72ce2e..7a93d05 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -1,34 +1,45 @@
+dataset:
+ download: True
+ dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: True # set to True if you want to use limited supervision
+ dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
model:
n_cnn_layers: 3
n_rnn_layers: 5
rnn_dim: 512
n_feats: 128 # number of mel features
stride: 2
- dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
-
-training:
- learning_rate: 0.0005
- batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
- epochs: 150
- eval_every_n: 5 # evaluate every n epochs
- num_workers: 4 # number of workers for dataloader
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-dataset:
- download: true
- dataset_root_path: "data" # files will be downloaded into this dir
- language_name: "mls_german_opus"
- limited_supervision: false # set to True if you want to use limited supervision
- dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
- shuffle: true
+ dropout: 0.6 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
tokenizer:
tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
-checkpoints:
- model_load_path: "data/runs/epoch31" # path to load model from
- model_save_path: "data/runs/epoch" # path to save model to
+decoder:
+ type: "greedy" # greedy, or lm (beam search)
+
+ lm: # config for lm decoder
+ language_model_path: "data" # path where model and supplementary files are stored
+ language: "german"
+ n_gram: 3 # n-gram size of the language model, 3 or 5
+ beam_size: 50
+ beam_threshold: 50
+ n_best: 1
+ lm_weight: 2
+ word_score: 0
+
+training:
+ learning_rate: 0.0005
+ batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
+ epochs: 100
+ eval_every_n: 1 # evaluate every n epochs
+ num_workers: 8 # number of workers for dataloader
+
+checkpoints: # use "~" to disable saving/loading
+ model_load_path: "data/epoch67" # path to load model from
+ model_save_path: ~ # path to save model to
inference:
- model_load_path: "data/runs/epoch30" # path to load model from
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file
+ model_load_path: "data/epoch67" # path to load model from \ No newline at end of file
diff --git a/config.yaml b/config.yaml
index e5ff43a..d248d43 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,3 +1,11 @@
+dataset:
+ download: True
+ dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: False # set to True if you want to use limited supervision
+ dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
model:
n_cnn_layers: 3
n_rnn_layers: 5
@@ -6,29 +14,33 @@ model:
stride: 2
dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+tokenizer:
+ tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+decoder:
+ type: "greedy" # greedy, or lm (beam search)
+
+ lm: # config for lm decoder
+ language_model_path: "data" # path where model and supplementary files are stored
+ language: "german"
+ n_gram: 3 # n-gram size of the language model, 3 or 5
+ beam_size: 50
+ beam_threshold: 50
+ n_best: 1
+ lm_weight: 2,
+ word_score: 0,
+
training:
- learning_rate: 5e-4
+ learning_rate: 0.0005
batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
epochs: 3
eval_every_n: 3 # evaluate every n epochs
num_workers: 8 # number of workers for dataloader
-dataset:
- download: True
- dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
- language_name: "mls_german_opus"
- limited_supervision: False # set to True if you want to use limited supervision
- dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
- shuffle: True
-
-tokenizer:
- tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
-
-checkpoints:
+checkpoints: # use "~" to disable saving/loading
model_load_path: "YOUR/PATH" # path to load model from
model_save_path: "YOUR/PATH" # path to save model to
inference:
model_load_path: "YOUR/PATH" # path to load model from
- beam_width: 10 # beam width for beam search
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file
+
diff --git a/data/own/Philipp_HerrK.flac b/data/own/Philipp_HerrK.flac
new file mode 100644
index 0000000..dec59e3
--- /dev/null
+++ b/data/own/Philipp_HerrK.flac
Binary files differ
diff --git a/data/tokenizers/tokens_german.txt b/data/tokenizers/tokens_german.txt
new file mode 100644
index 0000000..57f2c3a
--- /dev/null
+++ b/data/tokenizers/tokens_german.txt
@@ -0,0 +1,38 @@
+_
+<BLANK>
+<UNK>
+<SPACE>
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+-
+'
diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb
new file mode 100644
index 0000000..5e56312
--- /dev/null
+++ b/lm_decoder_hparams.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lm_weights = [0, 1.0, 2.5,]\n",
+ "word_score = [-1.5, 0.0, 1.5]\n",
+ "beam_sizes = [50, 500]\n",
+ "beam_thresholds = [50]\n",
+ "beam_size_token = [10, 38]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+ " from tqdm.autonotebook import tqdm\n",
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm.autonotebook import tqdm\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from swr2_asr.utils.decoder import decoder_factory\n",
+ "from swr2_asr.utils.tokenizer import CharTokenizer\n",
+ "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n",
+ "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n",
+ "from swr2_asr.utils.loss_scores import cer, wer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "34aafd9aca2541748dc41d8550334536",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Download flag not set, skipping download\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "New best WER: 0.8266228565397248 CER: 0.6048691547202959\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7900706123452581 CER: 0.49197597466135945\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7877685082828738 CER: 0.48660732878914315\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n",
+ "\n",
+ "# manually increment tqdm progress bar\n",
+ "pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n",
+ "\n",
+ "base_config = {\n",
+ " \"language\": \"german\",\n",
+ " \"language_model_path\": \"data\", # path where model and supplementary files are stored\n",
+ " \"n_gram\": 3, # n-gram size of ,the language model, 3 or 5\n",
+ " \"beam_size\": 50 ,\n",
+ " \"beam_threshold\": 50,\n",
+ " \"n_best\": 1,\n",
+ " \"lm_weight\": 2,\n",
+ " \"word_score\": 0,\n",
+ " }\n",
+ "\n",
+ "dataset_params = {\n",
+ " \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n",
+ " \"language\": \"mls_german_opus\",\n",
+ " \"split\": Split.DEV,\n",
+ " \"limited\": True,\n",
+ " \"download\": False,\n",
+ " \"size\": 0.01,\n",
+ "}\n",
+ " \n",
+ "\n",
+ "model_params = {\n",
+ " \"n_cnn_layers\": 3,\n",
+ " \"n_rnn_layers\": 5,\n",
+ " \"rnn_dim\": 512,\n",
+ " \"n_class\": tokenizer.get_vocab_size(),\n",
+ " \"n_feats\": 128,\n",
+ " \"stride\": 2,\n",
+ " \"dropout\": 0.1,\n",
+ "}\n",
+ "\n",
+ "model = SpeechRecognitionModel(**model_params)\n",
+ "\n",
+ "checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n",
+ "\n",
+ "state_dict = {\n",
+ " k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n",
+ " for k, v in checkpoint[\"model_state_dict\"].items()\n",
+ "}\n",
+ "model.load_state_dict(state_dict, strict=True)\n",
+ "model.eval()\n",
+ "\n",
+ "\n",
+ "dataset = MLSDataset(**dataset_params,)\n",
+ "\n",
+ "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ " dataset=dataset,\n",
+ " batch_size=16,\n",
+ " shuffle = False,\n",
+ " collate_fn=data_processing,\n",
+ " num_workers=8,\n",
+ " pin_memory=True,\n",
+ ")\n",
+ "\n",
+ "best_wer = 1.0\n",
+ "best_cer = 1.0\n",
+ "best_config = None\n",
+ "\n",
+ "for lm_weight in lm_weights:\n",
+ " for ws in word_score:\n",
+ " for beam_size in beam_sizes:\n",
+ " for beam_threshold in beam_thresholds:\n",
+ " for beam_size_t in beam_size_token:\n",
+ " config = base_config.copy()\n",
+ " config[\"lm_weight\"] = lm_weight\n",
+ " config[\"word_score\"] = ws\n",
+ " config[\"beam_size\"] = beam_size\n",
+ " config[\"beam_threshold\"] = beam_threshold\n",
+ " config[\"beam_size_token\"] = beam_size_t\n",
+ " \n",
+ " decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": config})\n",
+ " \n",
+ " test_cer, test_wer = [], []\n",
+ " with torch.no_grad():\n",
+ " model.eval()\n",
+ " for batch in dataloader:\n",
+ " # perform inference, decode, compute WER and CER\n",
+ " spectrograms, labels, input_lengths, label_lengths = batch\n",
+ " \n",
+ " output = model(spectrograms)\n",
+ " output = F.log_softmax(output, dim=2)\n",
+ " \n",
+ " decoded_preds = decoder(output)\n",
+ " decoded_targets = tokenizer.decode_batch(labels)\n",
+ " \n",
+ " for j, _ in enumerate(decoded_preds):\n",
+ " if j >= len(decoded_targets):\n",
+ " break\n",
+ " pred = \" \".join(decoded_preds[j][0].words).strip()\n",
+ " target = decoded_targets[j]\n",
+ " \n",
+ " test_cer.append(cer(pred, target))\n",
+ " test_wer.append(wer(pred, target))\n",
+ "\n",
+ " avg_cer = sum(test_cer) / len(test_cer)\n",
+ " avg_wer = sum(test_wer) / len(test_wer)\n",
+ " \n",
+ " if avg_wer < best_wer:\n",
+ " best_wer = avg_wer\n",
+ " best_cer = avg_cer\n",
+ " best_config = config\n",
+ " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n",
+ " print(\"Config: \", best_config)\n",
+ " print(\"LM Weight: \", lm_weight, \n",
+ " \" Word Score: \", ws, \n",
+ " \" Beam Size: \", beam_size, \n",
+ " \" Beam Threshold: \", beam_threshold, \n",
+ " \" Beam Size Token: \", beam_size_t)\n",
+ " print(\"--------------------------------------------------------------\")\n",
+ " \n",
+ " pbar.update(1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/metrics.csv b/metrics.csv
new file mode 100644
index 0000000..22b8cec
--- /dev/null
+++ b/metrics.csv
@@ -0,0 +1,69 @@
+epoch,train_loss,test_loss,cer,wer
+0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454
+1.0,2.791025161743164,0.0,0.0,0.0
+2.0,1.5954065322875977,0.0,0.0,0.0
+3.0,1.3106564283370972,0.0,0.0,0.0
+4.0,1.206541895866394,0.0,0.0,0.0
+5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183
+6.0,1.0295032262802124,0.0,0.0,0.0
+7.0,0.957234263420105,0.0,0.0,0.0
+8.0,0.8958202004432678,0.0,0.0,0.0
+9.0,0.8403098583221436,0.0,0.0,0.0
+10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198
+11.0,0.7537956833839417,0.0,0.0,0.0
+12.0,0.7180628776550293,0.0,0.0,0.0
+13.0,0.6870554089546204,0.0,0.0,0.0
+14.0,0.6595032811164856,0.0,0.0,0.0
+15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556
+16.0,0.6134707927703857,0.0,0.0,0.0
+17.0,0.5946973562240601,0.0,0.0,0.0
+18.0,0.577201783657074,0.0,0.0,0.0
+19.0,0.5612062811851501,0.0,0.0,0.0
+20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307
+21.0,0.5190389752388,0.0,0.0,0.0
+22.0,0.5163558721542358,0.0,0.0,0.0
+23.0,0.5132778286933899,0.0,0.0,0.0
+24.0,0.5090991854667664,0.0,0.0,0.0
+25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658
+26.0,0.5023046731948853,0.0,0.0,0.0
+27.0,0.4994561970233917,0.0,0.0,0.0
+28.0,0.4942632019519806,0.0,0.0,0.0
+29.0,0.4906529486179352,0.0,0.0,0.0
+30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594
+31.0,0.4822919964790344,0.0,0.0,0.0
+32.0,0.4456436336040497,0.0,0.0,0.0
+33.0,0.4389857053756714,0.0,0.0,0.0
+34.0,0.43762147426605225,0.0,0.0,0.0
+35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124
+36.0,0.43377435207366943,0.0,0.0,0.0
+37.0,0.4318349063396454,0.0,0.0,0.0
+38.0,0.43010208010673523,0.0,0.0,0.0
+39.0,0.4276123046875,0.0,0.0,0.0
+40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734
+41.0,0.4236880838871002,0.0,0.0,0.0
+42.0,0.42077934741973877,0.0,0.0,0.0
+43.0,0.4181424081325531,0.0,0.0,0.0
+44.0,0.4154696464538574,0.0,0.0,0.0
+45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078
+46.0,0.4099026024341583,0.0,0.0,0.0
+47.0,0.4078012704849243,0.0,0.0,0.0
+48.0,0.40490180253982544,0.0,0.0,0.0
+49.0,0.4024839699268341,0.0,0.0,0.0
+50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995
+51.0,0.36624056100845337,0.0,0.0,0.0
+52.0,0.36418089270591736,0.0,0.0,0.0
+53.0,0.36366793513298035,0.0,0.0,0.0
+54.0,0.36317530274391174,0.0,0.0,0.0
+55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951
+56.0,0.36174166202545166,0.0,0.0,0.0
+57.0,0.36113062500953674,0.0,0.0,0.0
+58.0,0.36098596453666687,0.0,0.0,0.0
+59.0,0.35909315943717957,0.0,0.0,0.0
+60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114
+61.0,0.35837724804878235,0.0,0.0,0.0
+62.0,0.3567410409450531,0.0,0.0,0.0
+63.0,0.3565385341644287,0.0,0.0,0.0
+64.0,0.35535314679145813,0.0,0.0,0.0
+65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726
+66.0,0.35215333104133606,0.0,0.0,0.0
+67.0,0.35401859879493713,0.0,0.0,0.0 \ No newline at end of file
diff --git a/plots.ipynb b/plots.ipynb
new file mode 100644
index 0000000..716834a
--- /dev/null
+++ b/plots.ipynb
@@ -0,0 +1,131 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n",
+ "# test_loss, cer, wer should not be plotted if they are 0.0\n",
+ "# plot train_loss and test_loss in one plot\n",
+ "# plot cer and wer in one plot\n",
+ " \n",
+ "# save plots as png\n",
+ "\n",
+ "csv_path = \"metrics.csv\"\n",
+ "\n",
+ "# load csv\n",
+ "df = pd.read_csv(csv_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot train_loss and test_loss\n",
+ "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n",
+ "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n",
+ "\n",
+ "# create zip with epoch and test_loss for all epochs\n",
+ "# filter out all test_loss with value 0.0\n",
+ "# plot test_loss\n",
+ "epoch_loss = zip(df['epoch'], df['test_loss'])\n",
+ "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n",
+ "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for test_loss\n",
+ "for x, y in epoch_loss:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ "\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('loss')\n",
+ "plt.legend()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# set y limits to 0\n",
+ "plt.ylim(bottom=0)\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "# increase resolution\n",
+ "plt.savefig('train_test_loss.png', dpi=300)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "epoch_cer = zip(df['epoch'], df['cer'])\n",
+ "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n",
+ "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n",
+ "\n",
+ "# add markers for cer\n",
+ "for x, y in epoch_cer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "epoch_wer = zip(df['epoch'], df['wer'])\n",
+ "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n",
+ "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for wer\n",
+ "for x, y in epoch_wer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "# set y limits to 0 and 1\n",
+ "plt.ylim(bottom=0, top=1)\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('error rate')\n",
+ "plt.legend()\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# add ticks every 0.1 \n",
+ "plt.yticks([x/10 for x in range(0, 11, 1)])\n",
+ "\n",
+ "# increase resolution\n",
+ "plt.savefig('cer_wer.png', dpi=300)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/poetry.lock b/poetry.lock
index 5930525..f6fae21 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,15 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+
+[[package]]
+name = "appnope"
+version = "0.1.3"
+description = "Disable App Nap on macOS >= 10.9"
+optional = false
+python-versions = "*"
+files = [
+ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"},
+ {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"},
+]
[[package]]
name = "astroid"
@@ -20,6 +31,34 @@ wrapt = [
]
[[package]]
+name = "asttokens"
+version = "2.4.0"
+description = "Annotate AST trees with source code positions"
+optional = false
+python-versions = "*"
+files = [
+ {file = "asttokens-2.4.0-py2.py3-none-any.whl", hash = "sha256:cf8fc9e61a86461aa9fb161a14a0841a03c405fa829ac6b202670b3495d2ce69"},
+ {file = "asttokens-2.4.0.tar.gz", hash = "sha256:2e0171b991b2c959acc6c49318049236844a5da1d65ba2672c4880c1c894834e"},
+]
+
+[package.dependencies]
+six = ">=1.12.0"
+
+[package.extras]
+test = ["astroid", "pytest"]
+
+[[package]]
+name = "backcall"
+version = "0.2.0"
+description = "Specifications for callback functions passed in to an API"
+optional = false
+python-versions = "*"
+files = [
+ {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
+ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
+]
+
+[[package]]
name = "black"
version = "23.9.1"
description = "The uncompromising code formatter."
@@ -66,6 +105,82 @@ jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
+name = "cffi"
+version = "1.15.1"
+description = "Foreign Function Interface for Python calling C code."
+optional = false
+python-versions = "*"
+files = [
+ {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914"},
+ {file = "cffi-1.15.1-cp27-cp27m-win32.whl", hash = "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3"},
+ {file = "cffi-1.15.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e"},
+ {file = "cffi-1.15.1-cp310-cp310-win32.whl", hash = "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2"},
+ {file = "cffi-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8"},
+ {file = "cffi-1.15.1-cp311-cp311-win32.whl", hash = "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d"},
+ {file = "cffi-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104"},
+ {file = "cffi-1.15.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e"},
+ {file = "cffi-1.15.1-cp36-cp36m-win32.whl", hash = "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf"},
+ {file = "cffi-1.15.1-cp36-cp36m-win_amd64.whl", hash = "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497"},
+ {file = "cffi-1.15.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426"},
+ {file = "cffi-1.15.1-cp37-cp37m-win32.whl", hash = "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9"},
+ {file = "cffi-1.15.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045"},
+ {file = "cffi-1.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192"},
+ {file = "cffi-1.15.1-cp38-cp38-win32.whl", hash = "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314"},
+ {file = "cffi-1.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3"},
+ {file = "cffi-1.15.1-cp39-cp39-win32.whl", hash = "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee"},
+ {file = "cffi-1.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c"},
+ {file = "cffi-1.15.1.tar.gz", hash = "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9"},
+]
+
+[package.dependencies]
+pycparser = "*"
+
+[[package]]
name = "click"
version = "8.1.7"
description = "Composable command line interface toolkit"
@@ -120,6 +235,25 @@ files = [
]
[[package]]
+name = "comm"
+version = "0.1.4"
+description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "comm-0.1.4-py3-none-any.whl", hash = "sha256:6d52794cba11b36ed9860999cd10fd02d6b2eac177068fdd585e1e2f8a96e67a"},
+ {file = "comm-0.1.4.tar.gz", hash = "sha256:354e40a59c9dd6db50c5cc6b4acc887d82e9603787f83b68c01a80a923984d15"},
+]
+
+[package.dependencies]
+traitlets = ">=4"
+
+[package.extras]
+lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff (>=0.0.156)"]
+test = ["pytest"]
+typing = ["mypy (>=0.990)"]
+
+[[package]]
name = "contourpy"
version = "1.1.0"
description = "Python library for calculating contours of 2D quadrilateral grids"
@@ -189,6 +323,44 @@ files = [
]
[[package]]
+name = "debugpy"
+version = "1.8.0"
+description = "An implementation of the Debug Adapter Protocol for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"},
+ {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"},
+ {file = "debugpy-1.8.0-cp310-cp310-win32.whl", hash = "sha256:a8b7a2fd27cd9f3553ac112f356ad4ca93338feadd8910277aff71ab24d8775f"},
+ {file = "debugpy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9de202f5d42e62f932507ee8b21e30d49aae7e46d5b1dd5c908db1d7068637"},
+ {file = "debugpy-1.8.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:ef54404365fae8d45cf450d0544ee40cefbcb9cb85ea7afe89a963c27028261e"},
+ {file = "debugpy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60009b132c91951354f54363f8ebdf7457aeb150e84abba5ae251b8e9f29a8a6"},
+ {file = "debugpy-1.8.0-cp311-cp311-win32.whl", hash = "sha256:8cd0197141eb9e8a4566794550cfdcdb8b3db0818bdf8c49a8e8f8053e56e38b"},
+ {file = "debugpy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:a64093656c4c64dc6a438e11d59369875d200bd5abb8f9b26c1f5f723622e153"},
+ {file = "debugpy-1.8.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b05a6b503ed520ad58c8dc682749113d2fd9f41ffd45daec16e558ca884008cd"},
+ {file = "debugpy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c6fb41c98ec51dd010d7ed650accfd07a87fe5e93eca9d5f584d0578f28f35f"},
+ {file = "debugpy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:46ab6780159eeabb43c1495d9c84cf85d62975e48b6ec21ee10c95767c0590aa"},
+ {file = "debugpy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:bdc5ef99d14b9c0fcb35351b4fbfc06ac0ee576aeab6b2511702e5a648a2e595"},
+ {file = "debugpy-1.8.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:61eab4a4c8b6125d41a34bad4e5fe3d2cc145caecd63c3fe953be4cc53e65bf8"},
+ {file = "debugpy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:125b9a637e013f9faac0a3d6a82bd17c8b5d2c875fb6b7e2772c5aba6d082332"},
+ {file = "debugpy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:57161629133113c97b387382045649a2b985a348f0c9366e22217c87b68b73c6"},
+ {file = "debugpy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:e3412f9faa9ade82aa64a50b602544efcba848c91384e9f93497a458767e6926"},
+ {file = "debugpy-1.8.0-py2.py3-none-any.whl", hash = "sha256:9c9b0ac1ce2a42888199df1a1906e45e6f3c9555497643a85e0bf2406e3ffbc4"},
+ {file = "debugpy-1.8.0.zip", hash = "sha256:12af2c55b419521e33d5fb21bd022df0b5eb267c3e178f1d374a63a2a6bdccd0"},
+]
+
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
+[[package]]
name = "dill"
version = "0.3.7"
description = "serialize all of Python"
@@ -203,6 +375,34 @@ files = [
graph = ["objgraph (>=1.7.2)"]
[[package]]
+name = "exceptiongroup"
+version = "1.1.3"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
+ {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "executing"
+version = "1.2.0"
+description = "Get the currently executing AST node of a frame, and other information"
+optional = false
+python-versions = "*"
+files = [
+ {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"},
+ {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"},
+]
+
+[package.extras]
+tests = ["asttokens", "littleutils", "pytest", "rich"]
+
+[[package]]
name = "filelock"
version = "3.12.4"
description = "A platform independent file lock."
@@ -276,6 +476,78 @@ unicode = ["unicodedata2 (>=15.0.0)"]
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
[[package]]
+name = "ipykernel"
+version = "6.25.2"
+description = "IPython Kernel for Jupyter"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "ipykernel-6.25.2-py3-none-any.whl", hash = "sha256:2e2ee359baba19f10251b99415bb39de1e97d04e1fab385646f24f0596510b77"},
+ {file = "ipykernel-6.25.2.tar.gz", hash = "sha256:f468ddd1f17acb48c8ce67fcfa49ba6d46d4f9ac0438c1f441be7c3d1372230b"},
+]
+
+[package.dependencies]
+appnope = {version = "*", markers = "platform_system == \"Darwin\""}
+comm = ">=0.1.1"
+debugpy = ">=1.6.5"
+ipython = ">=7.23.1"
+jupyter-client = ">=6.1.12"
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+matplotlib-inline = ">=0.1"
+nest-asyncio = "*"
+packaging = "*"
+psutil = "*"
+pyzmq = ">=20"
+tornado = ">=6.1"
+traitlets = ">=5.4.0"
+
+[package.extras]
+cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"]
+pyqt5 = ["pyqt5"]
+pyside6 = ["pyside6"]
+test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "ipython"
+version = "8.15.0"
+description = "IPython: Productive Interactive Computing"
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "ipython-8.15.0-py3-none-any.whl", hash = "sha256:45a2c3a529296870a97b7de34eda4a31bee16bc7bf954e07d39abe49caf8f887"},
+ {file = "ipython-8.15.0.tar.gz", hash = "sha256:2baeb5be6949eeebf532150f81746f8333e2ccce02de1c7eedde3f23ed5e9f1e"},
+]
+
+[package.dependencies]
+appnope = {version = "*", markers = "sys_platform == \"darwin\""}
+backcall = "*"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+decorator = "*"
+exceptiongroup = {version = "*", markers = "python_version < \"3.11\""}
+jedi = ">=0.16"
+matplotlib-inline = "*"
+pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""}
+pickleshare = "*"
+prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0"
+pygments = ">=2.4.0"
+stack-data = "*"
+traitlets = ">=5"
+
+[package.extras]
+all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"]
+black = ["black"]
+doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"]
+kernel = ["ipykernel"]
+nbconvert = ["nbconvert"]
+nbformat = ["nbformat"]
+notebook = ["ipywidgets", "notebook"]
+parallel = ["ipyparallel"]
+qtconsole = ["qtconsole"]
+test = ["pytest (<7.1)", "pytest-asyncio", "testpath"]
+test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"]
+
+[[package]]
name = "isort"
version = "5.12.0"
description = "A Python utility / library to sort Python imports."
@@ -293,6 +565,25 @@ plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"]
[[package]]
+name = "jedi"
+version = "0.19.0"
+description = "An autocompletion tool for Python that can be used for text editors."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "jedi-0.19.0-py2.py3-none-any.whl", hash = "sha256:cb8ce23fbccff0025e9386b5cf85e892f94c9b822378f8da49970471335ac64e"},
+ {file = "jedi-0.19.0.tar.gz", hash = "sha256:bcf9894f1753969cbac8022a8c2eaee06bfa3724e4192470aaffe7eb6272b0c4"},
+]
+
+[package.dependencies]
+parso = ">=0.8.3,<0.9.0"
+
+[package.extras]
+docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"]
+qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
+testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
+
+[[package]]
name = "jinja2"
version = "3.1.2"
description = "A very fast and expressive template engine."
@@ -310,6 +601,48 @@ MarkupSafe = ">=2.0"
i18n = ["Babel (>=2.7)"]
[[package]]
+name = "jupyter-client"
+version = "8.3.1"
+description = "Jupyter protocol implementation and client libraries"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_client-8.3.1-py3-none-any.whl", hash = "sha256:5eb9f55eb0650e81de6b7e34308d8b92d04fe4ec41cd8193a913979e33d8e1a5"},
+ {file = "jupyter_client-8.3.1.tar.gz", hash = "sha256:60294b2d5b869356c893f57b1a877ea6510d60d45cf4b38057f1672d85699ac9"},
+]
+
+[package.dependencies]
+jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0"
+python-dateutil = ">=2.8.2"
+pyzmq = ">=23.0"
+tornado = ">=6.2"
+traitlets = ">=5.3"
+
+[package.extras]
+docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
+test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"]
+
+[[package]]
+name = "jupyter-core"
+version = "5.3.1"
+description = "Jupyter core package. A base package on which Jupyter projects rely."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "jupyter_core-5.3.1-py3-none-any.whl", hash = "sha256:ae9036db959a71ec1cac33081eeb040a79e681f08ab68b0883e9a676c7a90dce"},
+ {file = "jupyter_core-5.3.1.tar.gz", hash = "sha256:5ba5c7938a7f97a6b0481463f7ff0dbac7c15ba48cf46fa4035ca6e838aa1aba"},
+]
+
+[package.dependencies]
+platformdirs = ">=2.5"
+pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""}
+traitlets = ">=5.3"
+
+[package.extras]
+docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
+test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
+
+[[package]]
name = "kiwisolver"
version = "1.4.5"
description = "A fast implementation of the Cassowary constraint solver"
@@ -586,6 +919,20 @@ python-dateutil = ">=2.7"
setuptools_scm = ">=7"
[[package]]
+name = "matplotlib-inline"
+version = "0.1.6"
+description = "Inline Matplotlib backend for Jupyter"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"},
+ {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"},
+]
+
+[package.dependencies]
+traitlets = "*"
+
+[[package]]
name = "mccabe"
version = "0.7.0"
description = "McCabe checker, plugin for flake8"
@@ -698,6 +1045,17 @@ files = [
]
[[package]]
+name = "nest-asyncio"
+version = "1.5.8"
+description = "Patch asyncio to allow nested event loops"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"},
+ {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"},
+]
+
+[[package]]
name = "networkx"
version = "3.1"
description = "Python package for creating and manipulating graphs and networks"
@@ -919,6 +1277,82 @@ files = [
]
[[package]]
+name = "pandas"
+version = "2.1.0"
+description = "Powerful data structures for data analysis, time series, and statistics"
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "pandas-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:40dd20439ff94f1b2ed55b393ecee9cb6f3b08104c2c40b0cb7186a2f0046242"},
+ {file = "pandas-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d4f38e4fedeba580285eaac7ede4f686c6701a9e618d8a857b138a126d067f2f"},
+ {file = "pandas-2.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e6a0fe052cf27ceb29be9429428b4918f3740e37ff185658f40d8702f0b3e09"},
+ {file = "pandas-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d81e1813191070440d4c7a413cb673052b3b4a984ffd86b8dd468c45742d3cc"},
+ {file = "pandas-2.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eb20252720b1cc1b7d0b2879ffc7e0542dd568f24d7c4b2347cb035206936421"},
+ {file = "pandas-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:38f74ef7ebc0ffb43b3d633e23d74882bce7e27bfa09607f3c5d3e03ffd9a4a5"},
+ {file = "pandas-2.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cda72cc8c4761c8f1d97b169661f23a86b16fdb240bdc341173aee17e4d6cedd"},
+ {file = "pandas-2.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d97daeac0db8c993420b10da4f5f5b39b01fc9ca689a17844e07c0a35ac96b4b"},
+ {file = "pandas-2.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8c58b1113892e0c8078f006a167cc210a92bdae23322bb4614f2f0b7a4b510f"},
+ {file = "pandas-2.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:629124923bcf798965b054a540f9ccdfd60f71361255c81fa1ecd94a904b9dd3"},
+ {file = "pandas-2.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:70cf866af3ab346a10debba8ea78077cf3a8cd14bd5e4bed3d41555a3280041c"},
+ {file = "pandas-2.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:d53c8c1001f6a192ff1de1efe03b31a423d0eee2e9e855e69d004308e046e694"},
+ {file = "pandas-2.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:86f100b3876b8c6d1a2c66207288ead435dc71041ee4aea789e55ef0e06408cb"},
+ {file = "pandas-2.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28f330845ad21c11db51e02d8d69acc9035edfd1116926ff7245c7215db57957"},
+ {file = "pandas-2.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9a6ccf0963db88f9b12df6720e55f337447aea217f426a22d71f4213a3099a6"},
+ {file = "pandas-2.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d99e678180bc59b0c9443314297bddce4ad35727a1a2656dbe585fd78710b3b9"},
+ {file = "pandas-2.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b31da36d376d50a1a492efb18097b9101bdbd8b3fbb3f49006e02d4495d4c644"},
+ {file = "pandas-2.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0164b85937707ec7f70b34a6c3a578dbf0f50787f910f21ca3b26a7fd3363437"},
+ {file = "pandas-2.1.0.tar.gz", hash = "sha256:62c24c7fc59e42b775ce0679cfa7b14a5f9bfb7643cfbe708c960699e05fb918"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+ {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+]
+python-dateutil = ">=2.8.2"
+pytz = ">=2020.1"
+tzdata = ">=2022.1"
+
+[package.extras]
+all = ["PyQt5 (>=5.15.6)", "SQLAlchemy (>=1.4.36)", "beautifulsoup4 (>=4.11.1)", "bottleneck (>=1.3.4)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=0.8.1)", "fsspec (>=2022.05.0)", "gcsfs (>=2022.05.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.8.0)", "matplotlib (>=3.6.1)", "numba (>=0.55.2)", "numexpr (>=2.8.0)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pandas-gbq (>=0.17.5)", "psycopg2 (>=2.9.3)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.5)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "pyxlsb (>=1.0.9)", "qtpy (>=2.2.0)", "s3fs (>=2022.05.0)", "scipy (>=1.8.1)", "tables (>=3.7.0)", "tabulate (>=0.8.10)", "xarray (>=2022.03.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)", "zstandard (>=0.17.0)"]
+aws = ["s3fs (>=2022.05.0)"]
+clipboard = ["PyQt5 (>=5.15.6)", "qtpy (>=2.2.0)"]
+compression = ["zstandard (>=0.17.0)"]
+computation = ["scipy (>=1.8.1)", "xarray (>=2022.03.0)"]
+consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
+excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pyxlsb (>=1.0.9)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)"]
+feather = ["pyarrow (>=7.0.0)"]
+fss = ["fsspec (>=2022.05.0)"]
+gcp = ["gcsfs (>=2022.05.0)", "pandas-gbq (>=0.17.5)"]
+hdf5 = ["tables (>=3.7.0)"]
+html = ["beautifulsoup4 (>=4.11.1)", "html5lib (>=1.1)", "lxml (>=4.8.0)"]
+mysql = ["SQLAlchemy (>=1.4.36)", "pymysql (>=1.0.2)"]
+output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.8.10)"]
+parquet = ["pyarrow (>=7.0.0)"]
+performance = ["bottleneck (>=1.3.4)", "numba (>=0.55.2)", "numexpr (>=2.8.0)"]
+plot = ["matplotlib (>=3.6.1)"]
+postgresql = ["SQLAlchemy (>=1.4.36)", "psycopg2 (>=2.9.3)"]
+spss = ["pyreadstat (>=1.1.5)"]
+sql-other = ["SQLAlchemy (>=1.4.36)"]
+test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
+xml = ["lxml (>=4.8.0)"]
+
+[[package]]
+name = "parso"
+version = "0.8.3"
+description = "A Python Parser"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"},
+ {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"},
+]
+
+[package.extras]
+qa = ["flake8 (==3.8.3)", "mypy (==0.782)"]
+testing = ["docopt", "pytest (<6.0.0)"]
+
+[[package]]
name = "pathspec"
version = "0.11.2"
description = "Utility library for gitignore style pattern matching of file paths."
@@ -930,6 +1364,31 @@ files = [
]
[[package]]
+name = "pexpect"
+version = "4.8.0"
+description = "Pexpect allows easy control of interactive console applications."
+optional = false
+python-versions = "*"
+files = [
+ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"},
+ {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"},
+]
+
+[package.dependencies]
+ptyprocess = ">=0.5"
+
+[[package]]
+name = "pickleshare"
+version = "0.7.5"
+description = "Tiny 'shelve'-like database with concurrency support"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"},
+ {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
+]
+
+[[package]]
name = "pillow"
version = "10.0.1"
description = "Python Imaging Library (Fork)"
@@ -1012,6 +1471,96 @@ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
[[package]]
+name = "prompt-toolkit"
+version = "3.0.39"
+description = "Library for building powerful interactive command lines in Python"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+ {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
+ {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
+]
+
+[package.dependencies]
+wcwidth = "*"
+
+[[package]]
+name = "psutil"
+version = "5.9.5"
+description = "Cross-platform lib for process and system monitoring in Python."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"},
+ {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"},
+ {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"},
+ {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"},
+ {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"},
+ {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"},
+ {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"},
+ {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"},
+ {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"},
+ {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"},
+ {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"},
+ {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"},
+ {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"},
+ {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"},
+]
+
+[package.extras]
+test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
+
+[[package]]
+name = "ptyprocess"
+version = "0.7.0"
+description = "Run a subprocess in a pseudo terminal"
+optional = false
+python-versions = "*"
+files = [
+ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
+ {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
+]
+
+[[package]]
+name = "pure-eval"
+version = "0.2.2"
+description = "Safely evaluate AST nodes without side effects"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"},
+ {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"},
+]
+
+[package.extras]
+tests = ["pytest"]
+
+[[package]]
+name = "pycparser"
+version = "2.21"
+description = "C parser in Python"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
+ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+]
+
+[[package]]
+name = "pygments"
+version = "2.16.1"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"},
+ {file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"},
+]
+
+[package.extras]
+plugins = ["importlib-metadata"]
+
+[[package]]
name = "pylint"
version = "2.17.5"
description = "python code static checker"
@@ -1068,6 +1617,40 @@ files = [
six = ">=1.5"
[[package]]
+name = "pytz"
+version = "2023.3.post1"
+description = "World timezone definitions, modern and historical"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"},
+ {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
+]
+
+[[package]]
+name = "pywin32"
+version = "306"
+description = "Python for Window Extensions"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"},
+ {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"},
+ {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"},
+ {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"},
+ {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"},
+ {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"},
+ {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"},
+ {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"},
+ {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"},
+ {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"},
+ {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"},
+ {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"},
+ {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"},
+ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"},
+]
+
+[[package]]
name = "pyyaml"
version = "6.0.1"
description = "YAML parser and emitter for Python"
@@ -1117,6 +1700,111 @@ files = [
]
[[package]]
+name = "pyzmq"
+version = "25.1.1"
+description = "Python bindings for 0MQ"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"},
+ {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"},
+ {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"},
+ {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"},
+ {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"},
+ {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"},
+ {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"},
+ {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"},
+ {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"},
+ {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"},
+ {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"},
+ {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"},
+ {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"},
+ {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"},
+ {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"},
+ {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"},
+ {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"},
+ {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"},
+ {file = "pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"},
+ {file = "pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"},
+ {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"},
+ {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"},
+ {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"},
+ {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"},
+ {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"},
+ {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"},
+ {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"},
+ {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"},
+ {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"},
+ {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"},
+ {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"},
+ {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"},
+ {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"},
+ {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"},
+ {file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"},
+ {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"},
+ {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"},
+]
+
+[package.dependencies]
+cffi = {version = "*", markers = "implementation_name == \"pypy\""}
+
+[[package]]
name = "ruff"
version = "0.0.285"
description = "An extremely fast Python linter, written in Rust."
@@ -1191,6 +1879,25 @@ files = [
]
[[package]]
+name = "stack-data"
+version = "0.6.2"
+description = "Extract data from python stack frames and tracebacks for informative displays"
+optional = false
+python-versions = "*"
+files = [
+ {file = "stack_data-0.6.2-py3-none-any.whl", hash = "sha256:cbb2a53eb64e5785878201a97ed7c7b94883f48b87bfb0bbe8b623c74679e4a8"},
+ {file = "stack_data-0.6.2.tar.gz", hash = "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815"},
+]
+
+[package.dependencies]
+asttokens = ">=2.1.0"
+executing = ">=1.2.0"
+pure-eval = "*"
+
+[package.extras]
+tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
+
+[[package]]
name = "sympy"
version = "1.12"
description = "Computer algebra system (CAS) in Python"
@@ -1368,6 +2075,26 @@ files = [
torch = "2.0.0"
[[package]]
+name = "tornado"
+version = "6.3.3"
+description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+optional = false
+python-versions = ">= 3.8"
+files = [
+ {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:502fba735c84450974fec147340016ad928d29f1e91f49be168c0a4c18181e1d"},
+ {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:805d507b1f588320c26f7f097108eb4023bbaa984d63176d1652e184ba24270a"},
+ {file = "tornado-6.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bd19ca6c16882e4d37368e0152f99c099bad93e0950ce55e71daed74045908f"},
+ {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ac51f42808cca9b3613f51ffe2a965c8525cb1b00b7b2d56828b8045354f76a"},
+ {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71a8db65160a3c55d61839b7302a9a400074c9c753040455494e2af74e2501f2"},
+ {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ceb917a50cd35882b57600709dd5421a418c29ddc852da8bcdab1f0db33406b0"},
+ {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:7d01abc57ea0dbb51ddfed477dfe22719d376119844e33c661d873bf9c0e4a16"},
+ {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:9dc4444c0defcd3929d5c1eb5706cbe1b116e762ff3e0deca8b715d14bf6ec17"},
+ {file = "tornado-6.3.3-cp38-abi3-win32.whl", hash = "sha256:65ceca9500383fbdf33a98c0087cb975b2ef3bfb874cb35b8de8740cf7f41bd3"},
+ {file = "tornado-6.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:22d3c2fa10b5793da13c807e6fc38ff49a4f6e1e3868b0a6f4164768bb8e20f5"},
+ {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"},
+]
+
+[[package]]
name = "tqdm"
version = "4.66.1"
description = "Fast, Extensible Progress Meter"
@@ -1388,6 +2115,21 @@ slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
+name = "traitlets"
+version = "5.10.0"
+description = "Traitlets Python configuration system"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "traitlets-5.10.0-py3-none-any.whl", hash = "sha256:417745a96681fbb358e723d5346a547521f36e9bd0d50ba7ab368fff5d67aa54"},
+ {file = "traitlets-5.10.0.tar.gz", hash = "sha256:f584ea209240466e66e91f3c81aa7d004ba4cf794990b0c775938a1544217cd1"},
+]
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
+test = ["argcomplete (>=3.0.3)", "mypy (>=1.5.1)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"]
+
+[[package]]
name = "triton"
version = "2.0.0"
description = "A language and compiler for custom Deep Learning operations"
@@ -1436,6 +2178,28 @@ files = [
]
[[package]]
+name = "tzdata"
+version = "2023.3"
+description = "Provider of IANA time zone data"
+optional = false
+python-versions = ">=2"
+files = [
+ {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
+ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
+]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.6"
+description = "Measures the displayed width of unicode strings in a terminal"
+optional = false
+python-versions = "*"
+files = [
+ {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"},
+ {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"},
+]
+
+[[package]]
name = "wheel"
version = "0.41.2"
description = "A built-package format for Python"
@@ -1536,4 +2300,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf"
+content-hash = "71f6812aa2d6e8c0aa48f0deef4f0636a4e5efa9c12d5c49cae021a3ba832fd5"
diff --git a/pyproject.toml b/pyproject.toml
index f6d19dd..6f8915f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,12 +18,14 @@ tokenizers = "^0.13.3"
click = "^8.1.7"
matplotlib = "^3.7.2"
pyyaml = "^6.0.1"
+pandas = "^2.1.0"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
mypy = "^1.5.1"
pylint = "^2.17.5"
ruff = "^0.0.285"
+ipykernel = "^6.25.2"
[tool.ruff]
select = ["E", "F", "B", "I"]
@@ -35,9 +37,8 @@ target-version = "py310"
line-length = 100
[tool.poetry.scripts]
-train = "swr2_asr.train:run_cli"
-train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
-train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer"
+train = "swr2_asr.train:main"
+recognize = "swr2_asr.inference:main"
[build-system]
requires = ["poetry-core"]
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 3c58af0..64a6eeb 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -6,25 +6,10 @@ import torchaudio
import yaml
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer
-def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.get_blank_token()
- decodes = []
- for args in arg_maxes:
- decode = []
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes
-
-
@click.command()
@click.option(
"--config_path",
@@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None:
model_config = config_dict.get("model", {})
tokenizer_config = config_dict.get("tokenizer", {})
inference_config = config_dict.get("inference", {})
+ decoder_config = config_dict.get("decoder", {})
- if inference_config["device"] == "cpu":
+ if inference_config.get("device", "") == "cpu":
device = "cpu"
- elif inference_config["device"] == "cuda":
+ elif inference_config.get("device", "") == "cuda":
device = "cuda" if torch.cuda.is_available() else "cpu"
+ elif inference_config.get("device", "") == "mps":
+ device = "mps"
+ else:
+ device = "cpu"
device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
@@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None:
spec = spec.unsqueeze(0)
spec = spec.transpose(1, 2)
spec = spec.unsqueeze(0)
+ spec = spec.to(device)
output = model(spec) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2) # (batch, time, n_class)
- decoded_preds = greedy_decoder(output, tokenizer)
- print(decoded_preds)
+ decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
+
+ preds = decoder(output)
+ preds = " ".join(preds[0][0].words).strip()
+
+ print(preds)
if __name__ == "__main__":
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9c7ede9..5277c16 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
-from swr2_asr.utils.decoder import greedy_decoder
-from swr2_asr.utils.tokenizer import CharTokenizer
-
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
+from swr2_asr.utils.tokenizer import CharTokenizer
class IterMeter:
@@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
# get values from test_args:
model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
- if decoder == "greedy":
- decoder = greedy_decoder
-
model.eval()
test_loss = 0
test_cer, test_wer = [], []
@@ -141,12 +137,15 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths, tokenizer
- )
+ decoded_targets = tokenizer.decode_batch(labels)
+ decoded_preds = decoder(output.transpose(0, 1))
for j, _ in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
- test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+ if j >= len(decoded_targets):
+ break
+ pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words
+ target = decoded_targets[j]
+ test_cer.append(cer(target, pred))
+ test_wer.append(wer(target, pred))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
@@ -187,6 +186,7 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
+ decoder_config = config_dict.get("decoder", {})
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
@@ -262,12 +262,19 @@ def main(config_path: str):
if checkpoints_config["model_load_path"] is not None:
checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
- model.load_state_dict(checkpoint["model_state_dict"])
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+
+ model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
+ decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)
+
for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
train_args: TrainArgs = {
"model": model,
@@ -283,14 +290,13 @@ def main(config_path: str):
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
-
test_args: TestArgs = {
"model": model,
"device": device,
"test_loader": valid_loader,
"criterion": criterion,
"tokenizer": tokenizer,
- "decoder": "greedy",
+ "decoder": decoder,
}
if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index d551c98..74cd572 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -6,7 +6,7 @@ import numpy as np
import torch
import torchaudio
from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
from torchaudio.datasets.utils import _extract_tar
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -343,3 +343,21 @@ class MLSDataset(Dataset):
dataset_lookup_entry["chapterid"],
idx,
) # type: ignore
+
+
+def create_lexicon(vocab_counts_path, lexicon_path):
+ """Create a lexicon from the vocab_counts.txt file"""
+ words_list = []
+ with open(vocab_counts_path, "r", encoding="utf-8") as file:
+ for line in file:
+ words = line.split()
+ if len(words) >= 1:
+ word = words[0]
+ words_list.append(word)
+
+ with open(lexicon_path, "w", encoding="utf-8") as file:
+ for word in words_list:
+ file.write(f"{word} ")
+ for char in word:
+ file.write(char + " ")
+ file.write("<SPACE>")
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index fcddb79..1fd002a 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,26 +1,155 @@
"""Decoder for CTC-based ASR.""" ""
+import os
+from dataclasses import dataclass
+
import torch
+from torchaudio.datasets.utils import _extract_tar
+from torchaudio.models.decoder import ctc_decoder
+from swr2_asr.utils.data import create_lexicon
from swr2_asr.utils.tokenizer import CharTokenizer
-# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- blank_label = tokenizer.get_blank_token()
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
- decode = []
- targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes, targets
-
-
-# TODO: add beam search decoder
+@dataclass
+class DecoderOutput:
+ """Decoder output."""
+
+ words: list[str]
+
+
+def decoder_factory(decoder_type: str = "greedy") -> callable:
+ """Decoder factory."""
+ if decoder_type == "greedy":
+ return get_greedy_decoder
+ if decoder_type == "lm":
+ return get_beam_search_decoder
+ raise NotImplementedError
+
+
+def get_greedy_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ *_,
+):
+ """Greedy decoder."""
+ return GreedyDecoder(tokenizer)
+
+
+def get_beam_search_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ hparams: dict, # pylint: disable=redefined-outer-name
+):
+ """Beam search decoder."""
+ hparams = hparams.get("lm", {})
+ language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+ hparams["language"],
+ hparams["language_model_path"],
+ hparams["n_gram"],
+ hparams["beam_size"],
+ hparams["beam_threshold"],
+ hparams["n_best"],
+ hparams["lm_weight"],
+ hparams["word_score"],
+ )
+
+ if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")):
+ url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz"
+ torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
+ _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True)
+
+ tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt")
+ if not os.path.isfile(tokens_path):
+ tokenizer.create_tokens_txt(tokens_path)
+
+ lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt")
+ if not os.path.isfile(lexicon_path):
+ occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt")
+ create_lexicon(occurences_path, lexicon_path)
+
+ lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa")
+
+ decoder = ctc_decoder(
+ lexicon=lexicon_path,
+ tokens=tokens_path,
+ lm=lm_path,
+ blank_token="_",
+ sil_token="<SPACE>",
+ unk_word="<UNK>",
+ nbest=n_best,
+ beam_size=beam_size,
+ beam_threshold=beam_threshold,
+ lm_weight=lm_weight,
+ word_score=word_score,
+ )
+ return decoder
+
+
+class GreedyDecoder:
+ """Greedy decoder."""
+
+ def __init__(self, tokenizer: CharTokenizer): # pylint: disable=redefined-outer-name
+ self.tokenizer = tokenizer
+
+ def __call__(
+ self, output, greedy_type: str = "inference", labels=None, label_lengths=None
+ ): # pylint: disable=redefined-outer-name
+ """Greedily decode a sequence."""
+ if greedy_type == "train":
+ res = self.train(output, labels, label_lengths)
+ if greedy_type == "inference":
+ res = self.inference(output)
+
+ res = [x.split(" ") for x in res]
+ res = [[DecoderOutput(x)] for x in res]
+ return res
+
+ def train(self, output, labels, label_lengths):
+ """Greedily decode a sequence with known labels."""
+ blank_label = tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+ return decodes, targets
+
+ def inference(self, output):
+ """Greedily decode a sequence."""
+ collapse_repeated = True
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = self.tokenizer.get_blank_token()
+ decodes = []
+ for args in arg_maxes:
+ decode = []
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+
+ return decodes
+
+
+if __name__ == "__main__":
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+ tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
+
+ hparams = {
+ "language": "german",
+ "lang_model_path": "data",
+ "n_gram": 3,
+ "beam_size": 100,
+ "beam_threshold": 100,
+ "n_best": 1,
+ "lm_weight": 0.5,
+ "word_score": 1.0,
+ }
+
+ get_beam_search_decoder(tokenizer, hparams)
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 1cc7b84..4e3fddd 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ class CharTokenizer:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
+ i = int(i)
string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Use a character map and convert integer labels to an text sequence"""
+ string = []
+ for label in labels:
+ string.append(self.decode(label))
+ return string
+
def get_vocab_size(self) -> int:
"""Get the number of unique characters in the dataset"""
return len(self.char_map)
@@ -120,3 +128,9 @@ class CharTokenizer:
load_tokenizer.char_map[char] = int(index)
load_tokenizer.index_map[int(index)] = char
return load_tokenizer
+
+ def create_tokens_txt(self, path: str):
+ """Create a txt file with all the characters"""
+ with open(path, "w", encoding="utf-8") as file:
+ for char, _ in self.char_map.items():
+ file.write(f"{char}\n")