aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore7
-rw-r--r--.vscode/settings.json2
-rw-r--r--Dockerfile13
-rw-r--r--Makefile6
-rw-r--r--mypy.ini3
-rw-r--r--poetry.lock510
-rw-r--r--pyproject.toml34
-rw-r--r--readme.md6
-rw-r--r--swr2_asr/inference_test.py74
-rw-r--r--swr2_asr/loss_scores.py185
-rw-r--r--swr2_asr/train.py482
-rw-r--r--tests/__init__.py0
12 files changed, 1107 insertions, 215 deletions
diff --git a/.gitignore b/.gitignore
index 33600ee..8e64e4b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -118,6 +118,8 @@ ipython_config.py
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
+.pdm-python
+.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
@@ -162,14 +164,9 @@ dmypy.json
# Cython debug symbols
cython_debug/
-# linter
-**/.ruff
-
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
-
-data/
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 6d5637c..bd8762b 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,4 +7,6 @@
},
"black-formatter.importStrategy": "fromEnvironment",
"python.analysis.typeCheckingMode": "basic",
+ "python.linting.pylintEnabled": true,
+ "python.linting.enabled": true,
} \ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..ca7463f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,13 @@
+FROM python:3.10
+
+# install python poetry
+RUN curl -sSL https://install.python-poetry.org | python3 -
+
+WORKDIR /app
+
+COPY readme.md mypy.ini poetry.lock pyproject.toml ./
+COPY swr2_asr ./swr2_asr
+ENV POETRY_VIRTUALENVS_IN_PROJECT=true
+RUN /root/.local/bin/poetry --no-interaction install --without dev
+
+ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ]
diff --git a/Makefile b/Makefile
index 703405f..4f0ea9c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
format:
- @black .
+ @poetry run black .
lint:
- @mypy --strict swr2_asr
- @pylint swr2_asr \ No newline at end of file
+ @poetry run mypy --strict swr2_asr
+ @poetry run pylint swr2_asr \ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index 03675ee..f7cfc59 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,3 +6,6 @@ ignore_missing_imports = true
[mypy-torch.*]
ignore_missing_imports = true
+
+[mypy-click.*]
+ignore_missing_imports = true
diff --git a/poetry.lock b/poetry.lock
index 9d91798..287e19a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -14,10 +14,7 @@ files = [
[package.dependencies]
lazy-object-proxy = ">=1.4.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
-wrapt = [
- {version = ">=1.11,<2", markers = "python_version < \"3.11\""},
- {version = ">=1.14,<2", markers = "python_version >= \"3.11\""},
-]
+wrapt = {version = ">=1.11,<2", markers = "python_version < \"3.11\""}
[[package]]
name = "AudioLoader"
@@ -30,9 +27,9 @@ develop = false
[package.source]
type = "git"
-url = "https://github.com/KinWaiCheuk/AudioLoader.git"
+url = "https://github.com/marvinborner/AudioLoader.git"
reference = "HEAD"
-resolved_reference = "c79acea2db7323fab22a0041211208899d0371e2"
+resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2"
[[package]]
name = "black"
@@ -80,17 +77,6 @@ jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
-name = "cfgv"
-version = "3.4.0"
-description = "Validate configuration and produce human readable error messages."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
- {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
-]
-
-[[package]]
name = "click"
version = "8.1.7"
description = "Composable command line interface toolkit"
@@ -105,6 +91,35 @@ files = [
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
+name = "cmake"
+version = "3.27.2"
+description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
+optional = false
+python-versions = "*"
+files = [
+ {file = "cmake-3.27.2-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:96ac856c4d6b2104408848f0005a8ab2229d4135b171ea9a03e8c33039ede420"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:11fe6129d07982721c5965fd804a4056b8c6e9c4f482ac9e0fe41bb3abc1ab5f"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:f0c64e89e2ea59592980c4fe3821d712fee0e74cf87c2aaec5b3ab9aa809a57c"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ca7650477dff2a1138776b28b79c0e99127be733d3978922e8f87b56a433eed6"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ab2e40fe09e76a7ef67da2bbbf7a4cd1f52db4f1c7b6ccdda2539f918830343a"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:980ee19f12c808cb8ddb56fdcee832501a9f9631799d8b4fc625c0a0b5fb4c55"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:115d30ca0760e3861d9ad6b3288cd11ee72a785b81227da0c1765d3b84e2c009"},
+ {file = "cmake-3.27.2-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:efc338c939d6d435890a52458a260bf0942bd8392b648d7532a72c1ec0764e18"},
+ {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:7f7438c60ccc01765b67abfb1797787c3b9459d500a804ed70a4cc181bc02204"},
+ {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:294f008734267e0eee1574ad1b911bed137bc907ab19d60a618dab4615aa1fca"},
+ {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:197a34dc62ee149ced343545fac67e5a30b93fda65250b065726f86ce92bdada"},
+ {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:afb46ad883b174fb64347802ba5878423551dbd5847bb64669c39a5957c06eb7"},
+ {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:83611ffd155e270a6b13bbf0cfd4e8688ebda634f448aa2e3734006c745bf33f"},
+ {file = "cmake-3.27.2-py2.py3-none-win32.whl", hash = "sha256:53e12deb893da935e236f93accd47dbe2806620cd7654986234dc4487cc49652"},
+ {file = "cmake-3.27.2-py2.py3-none-win_amd64.whl", hash = "sha256:611f9722c68c40352d38a6c01960ab038c3d0419e7aee3bf18f95b23031e0dfe"},
+ {file = "cmake-3.27.2-py2.py3-none-win_arm64.whl", hash = "sha256:30620326b51ac2ce0d8f476747af6367a7ea21075c4d065fad9443904b07476a"},
+ {file = "cmake-3.27.2.tar.gz", hash = "sha256:7cd6e2d7d5a1125f8c26c4f65214f8c942e3f276f98c16cb62ae382c35609f25"},
+]
+
+[package.extras]
+test = ["coverage (>=4.2)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"]
+
+[[package]]
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
@@ -130,17 +145,6 @@ files = [
graph = ["objgraph (>=1.7.2)"]
[[package]]
-name = "distlib"
-version = "0.3.7"
-description = "Distribution utilities"
-optional = false
-python-versions = "*"
-files = [
- {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"},
- {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"},
-]
-
-[[package]]
name = "filelock"
version = "3.12.2"
description = "A platform independent file lock."
@@ -156,20 +160,6 @@ docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
[[package]]
-name = "identify"
-version = "2.5.26"
-description = "File identification library for Python"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "identify-2.5.26-py2.py3-none-any.whl", hash = "sha256:c22a8ead0d4ca11f1edd6c9418c3220669b3b7533ada0a0ffa6cc0ef85cf9b54"},
- {file = "identify-2.5.26.tar.gz", hash = "sha256:7243800bce2f58404ed41b7c002e53d4d22bcf3ae1b7900c2d7aefd95394bf7f"},
-]
-
-[package.extras]
-license = ["ukkonen"]
-
-[[package]]
name = "isort"
version = "5.12.0"
description = "A Python utility / library to sort Python imports."
@@ -249,6 +239,16 @@ files = [
]
[[package]]
+name = "lit"
+version = "16.0.6"
+description = "A Software Testing Tool"
+optional = false
+python-versions = "*"
+files = [
+ {file = "lit-16.0.6.tar.gz", hash = "sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a"},
+]
+
+[[package]]
name = "markupsafe"
version = "2.1.3"
description = "Safely add untrusted strings to HTML/XML markup."
@@ -319,6 +319,33 @@ files = [
]
[[package]]
+name = "mido"
+version = "1.3.0"
+description = "MIDI Objects for Python"
+optional = false
+python-versions = "~=3.7"
+files = [
+ {file = "mido-1.3.0-py3-none-any.whl", hash = "sha256:a710a274c8a1a3fd481f526d174a16e117b5e58d719ad92937a67fb6167a9432"},
+ {file = "mido-1.3.0.tar.gz", hash = "sha256:84282e3ace34bca3f984220db2dbcb98245cfeafb854260c02e000750dca86aa"},
+]
+
+[package.dependencies]
+packaging = ">=23.1,<24.0"
+
+[package.extras]
+build-docs = ["sphinx (>=4.3.2,<4.4.0)", "sphinx-rtd-theme (>=1.2.2,<1.3.0)"]
+check-manifest = ["check-manifest (>=0.49)"]
+dev = ["mido[build-docs]", "mido[check-manifest]", "mido[lint-code]", "mido[lint-reuse]", "mido[release]", "mido[test]"]
+lint-code = ["flake8 (>=5.0.4,<5.1.0)"]
+lint-reuse = ["reuse (>=1.1.2,<1.2.0)"]
+ports-all = ["mido[ports-pygame]", "mido[ports-rtmidi-python]", "mido[ports-rtmidi]"]
+ports-pygame = ["PyGame (>=2.5,<3.0)"]
+ports-rtmidi = ["python-rtmidi (>=1.5.4,<1.6.0)"]
+ports-rtmidi-python = ["rtmidi-python (>=0.2.2,<0.3.0)"]
+release = ["twine (>=4.0.2,<4.1.0)"]
+test-code = ["pytest (>=7.4.0,<7.5.0)"]
+
+[[package]]
name = "mpmath"
version = "1.3.0"
description = "Python library for arbitrary-precision floating-point arithmetic"
@@ -411,20 +438,6 @@ extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.1
test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
-name = "nodeenv"
-version = "1.8.0"
-description = "Node.js virtual environment builder"
-optional = false
-python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
-files = [
- {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
- {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
-]
-
-[package.dependencies]
-setuptools = "*"
-
-[[package]]
name = "numpy"
version = "1.25.2"
description = "Fundamental package for array computing in Python"
@@ -459,6 +472,164 @@ files = [
]
[[package]]
+name = "nvidia-cublas-cu11"
+version = "11.10.3.66"
+description = "CUBLAS native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"},
+ {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cuda-cupti-cu11"
+version = "11.7.101"
+description = "CUDA profiling tools runtime libs."
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:e0cfd9854e1f2edaa36ca20d21cd0bdd5dcfca4e3b9e130a082e05b33b6c5895"},
+ {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-win_amd64.whl", hash = "sha256:7cc5b8f91ae5e1389c3c0ad8866b3b016a175e827ea8f162a672990a402ab2b0"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu11"
+version = "11.7.99"
+description = "NVRTC native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"},
+ {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"},
+ {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cuda-runtime-cu11"
+version = "11.7.99"
+description = "CUDA Runtime native Libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"},
+ {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cudnn-cu11"
+version = "8.5.0.96"
+description = "cuDNN runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"},
+ {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cufft-cu11"
+version = "10.9.0.58"
+description = "CUFFT native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"},
+ {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"},
+]
+
+[[package]]
+name = "nvidia-curand-cu11"
+version = "10.2.10.91"
+description = "CURAND native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:eecb269c970fa599a2660c9232fa46aaccbf90d9170b96c462e13bcb4d129e2c"},
+ {file = "nvidia_curand_cu11-10.2.10.91-py3-none-win_amd64.whl", hash = "sha256:f742052af0e1e75523bde18895a9ed016ecf1e5aa0ecddfcc3658fd11a1ff417"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cusolver-cu11"
+version = "11.4.0.1"
+description = "CUDA solver native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:72fa7261d755ed55c0074960df5904b65e2326f7adce364cbe4945063c1be412"},
+ {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:700b781bfefd57d161443aff9ace1878584b93e0b2cfef3d6e9296d96febbf99"},
+ {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-win_amd64.whl", hash = "sha256:00f70b256add65f8c1eb3b6a65308795a93e7740f6df9e273eccbba770d370c4"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-cusparse-cu11"
+version = "11.7.4.91"
+description = "CUSPARSE native runtime libraries"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:a3389de714db63321aa11fbec3919271f415ef19fda58aed7f2ede488c32733d"},
+ {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-win_amd64.whl", hash = "sha256:304a01599534f5186a8ed1c3756879282c72c118bc77dd890dc1ff868cad25b9"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
+name = "nvidia-nccl-cu11"
+version = "2.14.3"
+description = "NVIDIA Collective Communication Library (NCCL) Runtime"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:5e5534257d1284b8e825bc3a182c6f06acd6eb405e9f89d49340e98cd8f136eb"},
+]
+
+[[package]]
+name = "nvidia-nvtx-cu11"
+version = "11.7.91"
+description = "NVIDIA Tools Extension"
+optional = false
+python-versions = ">=3"
+files = [
+ {file = "nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:b22c64eee426a62fc00952b507d6d29cf62b4c9df7a480fcc417e540e05fd5ac"},
+ {file = "nvidia_nvtx_cu11-11.7.91-py3-none-win_amd64.whl", hash = "sha256:dfd7fcb2a91742513027d63a26b757f38dd8b07fecac282c4d132a9d373ff064"},
+]
+
+[package.dependencies]
+setuptools = "*"
+wheel = "*"
+
+[[package]]
name = "packaging"
version = "23.1"
description = "Core utilities for Python packages"
@@ -496,24 +667,6 @@ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"]
[[package]]
-name = "pre-commit"
-version = "3.3.3"
-description = "A framework for managing and maintaining multi-language pre-commit hooks."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "pre_commit-3.3.3-py2.py3-none-any.whl", hash = "sha256:10badb65d6a38caff29703362271d7dca483d01da88f9d7e05d0b97171c136cb"},
- {file = "pre_commit-3.3.3.tar.gz", hash = "sha256:a2256f489cd913d575c145132ae196fe335da32d91a8294b7afe6622335dd023"},
-]
-
-[package.dependencies]
-cfgv = ">=2.0.0"
-identify = ">=1.0.0"
-nodeenv = ">=0.11.1"
-pyyaml = ">=5.1"
-virtualenv = ">=20.10.0"
-
-[[package]]
name = "pylint"
version = "2.17.5"
description = "python code static checker"
@@ -527,10 +680,7 @@ files = [
[package.dependencies]
astroid = ">=2.15.6,<=2.17.0-dev0"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
-dill = [
- {version = ">=0.2", markers = "python_version < \"3.11\""},
- {version = ">=0.3.6", markers = "python_version >= \"3.11\""},
-]
+dill = {version = ">=0.2", markers = "python_version < \"3.11\""}
isort = ">=4.2.5,<6"
mccabe = ">=0.6,<0.8"
platformdirs = ">=2.2.0"
@@ -542,67 +692,18 @@ spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
[[package]]
-name = "pyyaml"
-version = "6.0.1"
-description = "YAML parser and emitter for Python"
-optional = false
-python-versions = ">=3.6"
-files = [
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
- {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
- {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
- {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
- {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
- {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
- {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
- {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
- {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
- {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
-]
-
-[[package]]
name = "setuptools"
-version = "68.1.0"
+version = "68.1.2"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "setuptools-68.1.0-py3-none-any.whl", hash = "sha256:e13e1b0bc760e9b0127eda042845999b2f913e12437046e663b833aa96d89715"},
- {file = "setuptools-68.1.0.tar.gz", hash = "sha256:d59c97e7b774979a5ccb96388efc9eb65518004537e85d52e81eaee89ab6dd91"},
+ {file = "setuptools-68.1.2-py3-none-any.whl", hash = "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b"},
+ {file = "setuptools-68.1.2.tar.gz", hash = "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d"},
]
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5,<=7.1.2)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
@@ -644,60 +745,90 @@ files = [
[[package]]
name = "torch"
-version = "2.0.1+cpu"
+version = "2.0.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:fec257249ba014c68629a1994b0c6e7356e20e1afc77a87b9941a40e5095285d"},
- {file = "torch-2.0.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ca88b499973c4c027e32c4960bf20911d7e984bd0c55cda181dc643559f3d93f"},
- {file = "torch-2.0.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:274d4acf486ef50ce1066ffe9d500beabb32bde69db93e3b71d0892dd148956c"},
- {file = "torch-2.0.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:e2603310bdff4b099c4c41ae132192fc0d6b00932ae2621d52d87218291864be"},
- {file = "torch-2.0.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8046f49deae5a3d219b9f6059a1f478ae321f232e660249355a8bf6dcaa810c1"},
- {file = "torch-2.0.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:2ac4382ff090035f9045b18afe5763e2865dd35f2d661c02e51f658d95c8065a"},
- {file = "torch-2.0.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:73482a223d577407c45685fde9d2a74ba42f0d8d9f6e1e95c08071dc55c47d7b"},
- {file = "torch-2.0.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:f263f8e908288427ae81441fef540377f61e339a27632b1bbe33cf78292fdaea"},
+ {file = "torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c9090bda7d2eeeecd74f51b721420dbeb44f838d4536cc1b284e879417e3064a"},
+ {file = "torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bd42db2a48a20574d2c33489e120e9f32789c4dc13c514b0c44272972d14a2d7"},
+ {file = "torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8969aa8375bcbc0c2993e7ede0a7f889df9515f18b9b548433f412affed478d9"},
+ {file = "torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ab2da16567cb55b67ae39e32d520d68ec736191d88ac79526ca5874754c32203"},
+ {file = "torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7a9319a67294ef02459a19738bbfa8727bb5307b822dadd708bc2ccf6c901aca"},
+ {file = "torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9f01fe1f6263f31bd04e1757946fd63ad531ae37f28bb2dbf66f5c826ee089f4"},
+ {file = "torch-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:527f4ae68df7b8301ee6b1158ca56350282ea633686537b30dbb5d7b4a52622a"},
+ {file = "torch-2.0.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:ce9b5a49bd513dff7950a5a07d6e26594dd51989cee05ba388b03e8e366fd5d5"},
+ {file = "torch-2.0.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:53e1c33c6896583cdb9a583693e22e99266444c4a43392dddc562640d39e542b"},
+ {file = "torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:09651bff72e439d004c991f15add0c397c66f98ab36fe60d5514b44e4da722e8"},
+ {file = "torch-2.0.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d439aec349c98f12819e8564b8c54008e4613dd4428582af0e6e14c24ca85870"},
+ {file = "torch-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2802f84f021907deee7e9470ed10c0e78af7457ac9a08a6cd7d55adef835fede"},
+ {file = "torch-2.0.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:01858620f25f25e7a9ec4b547ff38e5e27c92d38ec4ccba9cfbfb31d7071ed9c"},
+ {file = "torch-2.0.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:9a2e53b5783ef5896a6af338b36d782f28e83c8ddfc2ac44b67b066d9d76f498"},
+ {file = "torch-2.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ec5fff2447663e369682838ff0f82187b4d846057ef4d119a8dea7772a0b17dd"},
+ {file = "torch-2.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:11b0384fe3c18c01b8fc5992e70fc519cde65e44c51cc87be1838c1803daf42f"},
+ {file = "torch-2.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:e54846aa63855298cfb1195487f032e413e7ac9cbfa978fda32354cc39551475"},
+ {file = "torch-2.0.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:cc788cbbbbc6eb4c90e52c550efd067586c2693092cf367c135b34893a64ae78"},
+ {file = "torch-2.0.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:d292640f0fd72b7a31b2a6e3b635eb5065fcbedd4478f9cad1a1e7a9ec861d35"},
+ {file = "torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6befaad784004b7af357e3d87fa0863c1f642866291f12a4c2af2de435e8ac5c"},
+ {file = "torch-2.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a83b26bd6ae36fbf5fee3d56973d9816e2002e8a3b7d9205531167c28aaa38a7"},
+ {file = "torch-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c7e67195e1c3e33da53954b026e89a8e1ff3bc1aeb9eb32b677172d4a9b5dcbf"},
+ {file = "torch-2.0.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6e0b97beb037a165669c312591f242382e9109a240e20054d5a5782d9236cad0"},
+ {file = "torch-2.0.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:297a4919aff1c0f98a58ebe969200f71350a1d4d4f986dbfd60c02ffce780e99"},
]
[package.dependencies]
filelock = "*"
jinja2 = "*"
networkx = "*"
+nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
sympy = "*"
+triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
typing-extensions = "*"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
-[package.source]
-type = "legacy"
-url = "https://download.pytorch.org/whl/cpu"
-reference = "pytorch-cpu"
-
[[package]]
name = "torchaudio"
-version = "2.0.2+cpu"
+version = "2.0.1"
description = "An audio package for PyTorch"
optional = false
python-versions = "*"
files = [
- {file = "torchaudio-2.0.2+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:80eccef075f9e356f7a4ad27c57614c726906a30617424e2d86a08de179feeff"},
- {file = "torchaudio-2.0.2+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:514995f6418d193caf1a2e06c912f603959ec6d8557b48ef52a9ceb8e5d180d1"},
- {file = "torchaudio-2.0.2+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:eb589737c30139c946b4c4f3a73307a6286c5402b6e5273dd5cb8768d0e73b3f"},
- {file = "torchaudio-2.0.2+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:a93c8e9ff28959b2e8b9a1347b34c5f94f33f617bc96fe17011b0d8cb5fb377f"},
- {file = "torchaudio-2.0.2+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:5196d7ed34863deaa28434ad02a043b01cfbd0c42b58d61ca56825577aeca74e"},
- {file = "torchaudio-2.0.2+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:b29d34bb39b8e613c562d8c04bd4108be6ae20221ee043e29b81bbef95f3f5bf"},
- {file = "torchaudio-2.0.2+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:30fc6926b51892cda933e98616fcdbc1f53084680210809cbc03c301d585ed9c"},
- {file = "torchaudio-2.0.2+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:aa7425184157a0356fe08deb74a517f5a0e1f27722f362a5ac3089904bf17001"},
+ {file = "torchaudio-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5d21ebbb55e7040d418d5062b0e882f9660d68b477b38fd436fa6c92ccbb52a"},
+ {file = "torchaudio-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dbcd93b29d71a2f500f36a34ea5e467f510f773da85322098e6bdd8c9dc9948"},
+ {file = "torchaudio-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5fdaba10ff06d098d603d9eb8d2ff541c3f3fe28ba178a78787190cec0d5187f"},
+ {file = "torchaudio-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:6419199c773c5045c594ff950d5e5dbbfa6c830892ec09721d4ed8704b702bfd"},
+ {file = "torchaudio-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:a5c81e480e5dcdcba065af1e3e31678ac29518991f00260094d37a39e63d76e5"},
+ {file = "torchaudio-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e2a047675493c0aa258fec621ef40e8b01abe3d8dbc872152e4b5998418aa3c5"},
+ {file = "torchaudio-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:91a28e587f708a03320eddbcc4a7dd1ad7150b3d4846b6c1557d85cc89a8d06c"},
+ {file = "torchaudio-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ba7740d98f601218ff667598ab3d9dab5f326878374fcb52d656f4ff033b9e96"},
+ {file = "torchaudio-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f401b192921c8b77cc5e478ede589b256dba463f1cee91172ecb376fea45a288"},
+ {file = "torchaudio-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:0ef6754cf75ca5fd5117cb6243a6cf33552d67e9af0075aa6954b2c34bbf1036"},
+ {file = "torchaudio-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:022ca1baa4bb819b78343bd47b57ff6dc6f9fc19fa4ef269946aadf7e62db3c0"},
+ {file = "torchaudio-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a153ad5cdb62de8ec9fd1360a0d080bbaf39d578ae04e788db211571e675b7e0"},
+ {file = "torchaudio-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:aa7897774ab4156d0b72f7078b823ebc1371ee24c50df965447782889552367a"},
+ {file = "torchaudio-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:48d133593cddfe0424a350b566d54065bf6fe7469654de7add2f11b3ef03c5d9"},
+ {file = "torchaudio-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ac65eb067feee435debba81adfe8337fa007a06de6508c0d80261c5562b6d098"},
+ {file = "torchaudio-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e3c6c8f9ea9f0e2df7a0b9375b0dcf955906e38fc12fab542b72a861564af8e7"},
+ {file = "torchaudio-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d0cf0779a334ec1861e9fa28bceb66a633c42e8f6b3322e2e37ff9f20d0ae81"},
+ {file = "torchaudio-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ab7acd2b5d351a2c65e4d935bb90b9256382bed93df57ee177bdbbe31c3cc984"},
+ {file = "torchaudio-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:77b953fd7278773269a9477315b8998ae7e5011cc4b2907e0df18162327482f1"},
+ {file = "torchaudio-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:c01bcea9d4c4a6616452e6cbd44d55913d8e6dee58191b925f35d46a2bf6e71b"},
]
[package.dependencies]
-torch = "2.0.1"
-
-[package.source]
-type = "legacy"
-url = "https://download.pytorch.org/whl/cpu"
-reference = "pytorch-cpu"
+torch = "2.0.0"
[[package]]
name = "tqdm"
@@ -720,6 +851,43 @@ slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
+name = "triton"
+version = "2.0.0"
+description = "A language and compiler for custom Deep Learning operations"
+optional = false
+python-versions = "*"
+files = [
+ {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"},
+ {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"},
+ {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"},
+ {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"},
+ {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"},
+ {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"},
+ {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"},
+ {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"},
+ {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"},
+ {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"},
+ {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"},
+ {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"},
+ {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"},
+ {file = "triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"},
+ {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"},
+ {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"},
+ {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"},
+ {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"},
+]
+
+[package.dependencies]
+cmake = "*"
+filelock = "*"
+lit = "*"
+torch = "*"
+
+[package.extras]
+tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
+tutorials = ["matplotlib", "pandas", "tabulate"]
+
+[[package]]
name = "typing-extensions"
version = "4.7.1"
description = "Backported and Experimental Type Hints for Python 3.7+"
@@ -731,24 +899,18 @@ files = [
]
[[package]]
-name = "virtualenv"
-version = "20.24.3"
-description = "Virtual Python Environment builder"
+name = "wheel"
+version = "0.41.1"
+description = "A built-package format for Python"
optional = false
python-versions = ">=3.7"
files = [
- {file = "virtualenv-20.24.3-py3-none-any.whl", hash = "sha256:95a6e9398b4967fbcb5fef2acec5efaf9aa4972049d9ae41f95e0972a683fd02"},
- {file = "virtualenv-20.24.3.tar.gz", hash = "sha256:e5c3b4ce817b0b328af041506a2a299418c98747c4b1e68cb7527e74ced23efc"},
+ {file = "wheel-0.41.1-py3-none-any.whl", hash = "sha256:473219bd4cbedc62cea0cb309089b593e47c15c4a2531015f94e4e3b9a0f6981"},
+ {file = "wheel-0.41.1.tar.gz", hash = "sha256:12b911f083e876e10c595779709f8a88a59f45aacc646492a67fe9ef796c1b47"},
]
-[package.dependencies]
-distlib = ">=0.3.7,<1"
-filelock = ">=3.12.2,<4"
-platformdirs = ">=3.9.1,<4"
-
[package.extras]
-docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
-test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
+test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[[package]]
name = "wrapt"
@@ -836,5 +998,5 @@ files = [
[metadata]
lock-version = "2.0"
-python-versions = "^3.10"
-content-hash = "e7b0344d7d2f66cddf80ac9fdbc63b839297b7443c043f960805738d54a79d43"
+python-versions = "~3.10"
+content-hash = "98a9b21411812f0514cbac138756a635b4f9283835791153f467c7f3fe1f4fd1"
diff --git a/pyproject.toml b/pyproject.toml
index 606633b..8490aa5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,40 +1,28 @@
[tool.poetry]
-name = "swr2_asr"
+name = "swr2-asr"
version = "0.1.0"
-description = "Automatic speech recognition with pytorch for SWR2"
-authors = ["Philipp Merkel <philippmerkel@outlook.com>", "Marvin Borner <git@marvinborner.de>", "Valentin Schmidt <>", "Silja Kasper <>"]
+description = ""
+authors = ["Philipp Merkel <philippmerkel@outlook.com>"]
license = "MIT"
readme = "readme.md"
packages = [{include = "swr2_asr"}]
[tool.poetry.dependencies]
-python = "^3.10"
-numpy = "^1.25.2"
-click = "^8.1.6"
-audioloader = {git = "https://github.com/KinWaiCheuk/AudioLoader.git"}
+python = "~3.10"
+torch = "2.0.0"
+torchaudio = "2.0.1"
+audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
tqdm = "^4.66.1"
-torch = {version = "^2.0.1+cpu", source = "pytorch-cpu"}
-torchaudio = {version = "^2.0.2+cpu", source = "pytorch-cpu"}
+numpy = "^1.25.2"
+mido = "^1.3.0"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
-pylint = "^2.17.5"
mypy = "^1.5.1"
-pre-commit = "^3.3.3"
-
-[[tool.poetry.source]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-priority = "explicit"
-
-
-[[tool.poetry.source]]
-name = "pytorch-gpu"
-url = "https://download.pytorch.org/whl/cu118"
-priority = "explicit"
+pylint = "^2.17.5"
[tool.poetry.scripts]
-train = "swr2_asr.train:main"
+train = "swr2_asr.train:run_cli"
[build-system]
requires = ["poetry-core"]
diff --git a/readme.md b/readme.md
index 99d741c..47d9a31 100644
--- a/readme.md
+++ b/readme.md
recognition 2 (SWR2) in the summer term 2023.
# Installation
```
-pip install -r requirements.txt
+poetry install
```
# Usage
@@ -14,13 +14,13 @@ pip install -r requirements.txt
Train using the provided train script:
- poetry run train --data PATH/TO/DATA --lr 0.01
+ poetry run train
## Evaluation
## Inference
- poetry run recognize --data PATH/TO/FILE
+ poetry run recognize
## CI
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
new file mode 100644
index 0000000..a6b0010
--- /dev/null
+++ b/swr2_asr/inference_test.py
@@ -0,0 +1,74 @@
+"""Training script for the ASR model."""
+from AudioLoader.speech.mls import MultilingualLibriSpeech
+import torch
+import torchaudio
+import torchaudio.functional as F
+
+
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder.

    Takes the arg-max label at every frame, collapses consecutive repeats,
    and drops the blank token to produce the transcript.
    """

    def __init__(self, labels, blank=0) -> None:
        """Store the label alphabet and the index of the CTC blank token."""
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(emission, dim=-1)  # [num_seq,]
        best_path = torch.unique_consecutive(best_path, dim=-1)
        kept = [idx for idx in best_path if idx != self.blank]
        return "".join(self.labels[idx] for idx in kept)
+
+
def main() -> None:
    """Run a one-sample inference smoke test with a pretrained wav2vec2 model.

    Downloads (if needed) the German MLS split, resamples the first sample
    to the model's rate, and prints the greedily decoded transcript.
    """
    # choose between cuda and cpu devices (mps left out, see original comment)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    torch.random.manual_seed(42)

    # NOTE(review): this bundle is trained on English (LibriSpeech); decoding
    # German MLS audio with it will give poor transcripts — acceptable for a
    # smoke test, but confirm this is intentional.
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

    print(f"Sample rate (model): {bundle.sample_rate}")
    print(f"Labels (model): {bundle.get_labels()}")

    model = bundle.get_model().to(device)

    print(model.__class__)

    # only do all things for one single sample
    dataset = MultilingualLibriSpeech(
        "data", "mls_german_opus", split="train", download=True
    )

    print(dataset[0])

    # load waveforms and sample rate from dataset
    waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]

    if sample_rate != bundle.sample_rate:
        waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))

    # BUG FIX: Tensor.to() is not in-place; the result must be reassigned,
    # otherwise the waveform stays on the CPU while the model is on `device`.
    waveform = waveform.to(device)

    # (removed an unused `model.extract_features(waveform)` pass — its result
    # was never read and it doubled the inference cost)
    with torch.inference_mode():
        emission, _ = model(waveform)

    decoder = GreedyCTCDecoder(labels=bundle.get_labels())
    transcript = decoder(emission[0])

    print(transcript)


if __name__ == "__main__":
    main()
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
new file mode 100644
index 0000000..977462d
--- /dev/null
+++ b/swr2_asr/loss_scores.py
@@ -0,0 +1,185 @@
+import numpy as np
+
+
def avg_wer(wer_scores, combined_ref_len):
    """Corpus-level WER: summed per-utterance error counts over the total
    reference length."""
    total_errors = float(sum(wer_scores))
    return total_errors / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp):
+ """Levenshtein distance is a string metric for measuring the difference
+ between two sequences. Informally, the levenshtein disctance is defined as
+ the minimum number of single-character edits (substitutions, insertions or
+ deletions) required to change one word into the other. We can naturally
+ extend the edits to word level when calculate levenshtein disctance for
+ two sentences.
+ """
+ m = len(ref)
+ n = len(hyp)
+
+ # special case
+ if ref == hyp:
+ return 0
+ if m == 0:
+ return n
+ if n == 0:
+ return m
+
+ if m < n:
+ ref, hyp = hyp, ref
+ m, n = n, m
+
+ # use O(min(m, n)) space
+ distance = np.zeros((2, n + 1), dtype=np.int32)
+
+ # initialize distance matrix
+ for j in range(0, n + 1):
+ distance[0][j] = j
+
+ # calculate levenshtein distance
+ for i in range(1, m + 1):
+ prev_row_idx = (i - 1) % 2
+ cur_row_idx = i % 2
+ distance[cur_row_idx][0] = i
+ for j in range(1, n + 1):
+ if ref[i - 1] == hyp[j - 1]:
+ distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+ else:
+ s_num = distance[prev_row_idx][j - 1] + 1
+ i_num = distance[cur_row_idx][j - 1] + 1
+ d_num = distance[prev_row_idx][j] + 1
+ distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+ return distance[m % 2][n]
+
+
def word_errors(
    reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
):
    """Word-level Levenshtein distance between a reference and a hypothesis.

    Args:
        reference: the ground-truth sentence.
        hypothesis: the recognized sentence.
        ignore_case: lower-case both sentences before comparing.
        delimiter: token separator used to split the sentences.

    Returns:
        ``(edit_distance, ref_word_count)`` — the distance as a float and the
        number of words in the reference.
    """
    if ignore_case:
        reference, hypothesis = reference.lower(), hypothesis.lower()

    ref_tokens = reference.split(delimiter)
    hyp_tokens = hypothesis.split(delimiter)

    distance = _levenshtein_distance(ref_tokens, hyp_tokens)
    return float(distance), len(ref_tokens)
+
+
def char_errors(
    reference: str,
    hypothesis: str,
    ignore_case: bool = False,
    remove_space: bool = False,
):
    """Character-level Levenshtein distance between reference and hypothesis.

    Runs of space characters are collapsed to a single space (or removed
    entirely when ``remove_space`` is set) before comparing.

    Args:
        reference: the ground-truth sentence.
        hypothesis: the recognized sentence.
        ignore_case: lower-case both sentences before comparing.
        remove_space: drop internal space characters entirely.

    Returns:
        ``(edit_distance, ref_char_count)`` — the distance as a float and the
        length of the normalized reference.
    """
    if ignore_case:
        reference, hypothesis = reference.lower(), hypothesis.lower()

    join_char = "" if remove_space else " "

    # split only on literal spaces (not all whitespace), drop empty tokens
    reference = join_char.join(tok for tok in reference.split(" ") if tok)
    hypothesis = join_char.join(tok for tok in hypothesis.split(" ") if tok)

    distance = _levenshtein_distance(reference, hypothesis)
    return float(distance), len(reference)
+
+
def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
    """Calculate the word error rate (WER).

    WER = (Sw + Dw + Iw) / Nw, where Sw/Dw/Iw are the numbers of substituted,
    deleted, and inserted words and Nw is the reference word count; computed
    via the word-level Levenshtein distance.

    Args:
        reference: the ground-truth sentence.
        hypothesis: the recognized sentence.
        ignore_case: lower-case both sentences before comparing.
        delimiter: token separator used to split the sentences.

    Returns:
        The word error rate as a float.

    Raises:
        ValueError: if the reference contains zero words.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)

    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    return float(edit_distance) / ref_len
+
+
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate the character error rate (CER).

    CER = (Sc + Dc + Ic) / Nc, where Sc/Dc/Ic are the numbers of substituted,
    deleted, and inserted characters and Nc is the reference length; computed
    via the char-level Levenshtein distance. Leading/trailing spaces are
    truncated and runs of internal spaces collapse to one space.

    Args:
        reference: the ground-truth sentence.
        hypothesis: the recognized sentence.
        ignore_case: lower-case both sentences before comparing.
        remove_space: drop internal space characters entirely.

    Returns:
        The character error rate as a float.

    Raises:
        ValueError: if the normalized reference has zero length.
    """
    edit_distance, ref_len = char_errors(
        reference, hypothesis, ignore_case, remove_space
    )

    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    return float(edit_distance) / ref_len
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 7a9ffec..29f9372 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,16 +1,484 @@
"""Training script for the ASR model."""
-import os
from AudioLoader.speech.mls import MultilingualLibriSpeech
+import click
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torchaudio
+from .loss_scores import cer, wer
-def main() -> None:
- """Main function."""
- dataset = MultilingualLibriSpeech(
- "data", "mls_polish_opus", split="train", download=(not os.path.isdir("data"))
class TextTransform:
    """Maps characters to integers and vice versa.

    Index 0 is the apostrophe, index 1 is the space token; letters and
    accented characters follow. The character map is the model's label
    alphabet, so its size must match the network's ``n_class``.
    """

    def __init__(self):
        # BUG FIX: the original map listed "ü" twice (indices 30 and 45).
        # The duplicate overwrote char_map["ü"] with 45 and left index 30
        # unreachable from text_to_int; the stray "ü 45" line is removed.
        char_map_str = """
        ' 0
        <SPACE> 1
        a 2
        b 3
        c 4
        d 5
        e 6
        f 7
        g 8
        h 9
        i 10
        j 11
        k 12
        l 13
        m 14
        n 15
        o 16
        p 17
        q 18
        r 19
        s 20
        t 21
        u 22
        v 23
        w 24
        x 25
        y 26
        z 27
        ä 28
        ö 29
        ü 30
        ß 31
        - 32
        é 33
        è 34
        à 35
        ù 36
        ç 37
        â 38
        ê 39
        î 40
        ô 41
        û 42
        ë 43
        ï 44
        """
        # char -> index and index -> char lookup tables
        self.char_map = {}
        self.index_map = {}
        for line in char_map_str.strip().split("\n"):
            char, index = line.split()
            self.char_map[char] = int(index)
            self.index_map[int(index)] = char
        # index 1 decodes to a real space rather than the "<SPACE>" token
        self.index_map[1] = " "

    def text_to_int(self, text):
        """Use a character map and convert text to an integer sequence.

        Raises KeyError for characters missing from the map.
        """
        int_sequence = []
        for char in text:
            if char == " ":
                mapped_char = self.char_map["<SPACE>"]
            else:
                mapped_char = self.char_map[char]
            int_sequence.append(mapped_char)
        return int_sequence

    def int_to_text(self, labels):
        """Use a character map and convert integer labels to a text sequence."""
        string = []
        for i in labels:
            string.append(self.index_map[i])
        return "".join(string).replace("<SPACE>", " ")
+
+
# Training-time feature pipeline: mel spectrogram plus SpecAugment-style
# frequency and time masking for regularization.
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Validation uses a plain mel spectrogram with library defaults — no masking.
# NOTE(review): the defaults differ from the training transform
# (sample_rate / n_mels); confirm this asymmetry is intended.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()

# Shared character <-> index codec used by the collate and decode functions.
text_transform = TextTransform()
+
+
def data_processing(data, data_type="train"):
    """Collate a batch of dataset samples for the CTC model.

    Returns ``(spectrograms, labels, input_lengths, label_lengths)`` where
    spectrograms are padded to the longest sample and labels are padded
    integer sequences.

    Raises:
        ValueError: if ``data_type`` is neither "train" nor "valid".
    """
    spectrograms, labels = [], []
    input_lengths, label_lengths = [], []

    for sample in data:
        if data_type == "train":
            transform = train_audio_transforms
        elif data_type == "valid":
            transform = valid_audio_transforms
        else:
            raise ValueError("data_type should be train or valid")

        spec = transform(sample["waveform"]).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)

        label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
        labels.append(label)

        # the stride-2 CNN front end later halves the time dimension
        input_lengths.append(spec.shape[0] // 2)
        label_lengths.append(len(label))

    spectrograms = (
        nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
        .unsqueeze(1)
        .transpose(2, 3)
    )
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths
+
+
def greedy_decoder(
    output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
    """Arg-max (best-path) CTC decoding of a batch of network outputs.

    Returns ``(decodes, targets)``: the decoded strings and the reference
    strings recovered from the padded label tensors.

    NOTE(review): blank_label defaults to 28, which is also the index of "ä"
    in TextTransform's map — confirm the intended blank index.
    """
    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
    decodes, targets = [], []
    for batch_idx, best_path in enumerate(arg_maxes):
        targets.append(
            text_transform.int_to_text(
                labels[batch_idx][: label_lengths[batch_idx]].tolist()
            )
        )
        decoded = []
        for step, index in enumerate(best_path):
            if index != blank_label:
                # collapse runs of the same label into a single character
                if collapse_repeated and step != 0 and index == best_path[step - 1]:
                    continue
                decoded.append(index.item())
        decodes.append(text_transform.int_to_text(decoded))
    return decodes, targets
+
+
class CNNLayerNorm(nn.Module):
    """Layer normalization over the feature axis of CNN activations.

    Input and output are (batch, channel, feature, time); normalization is
    applied per time step across the ``n_feats`` features.
    """

    def __init__(self, n_feats: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, data):
        """data (batch, channel, feature, time); shape is preserved."""
        # move features to the last axis so nn.LayerNorm normalizes them
        transposed = data.transpose(2, 3).contiguous()  # (b, c, time, feature)
        normed = self.layer_norm(transposed)
        return normed.transpose(2, 3).contiguous()  # (b, c, feature, time)
+
+
class ResidualCNN(nn.Module):
    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf

    Pre-activation residual block: LayerNorm -> GELU -> Dropout -> Conv,
    applied twice, with an identity skip connection. Padding keeps the
    spatial shape, so input and output shapes match.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int,
        dropout: float,
        n_feats: int,
    ):
        super().__init__()

        # "same" padding so the residual addition is shape-compatible
        self.cnn1 = nn.Conv2d(
            in_channels, out_channels, kernel, stride, padding=kernel // 2
        )
        self.cnn2 = nn.Conv2d(
            out_channels, out_channels, kernel, stride, padding=kernel // 2
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, data):
        """data (batch, channel, feature, time); shape is preserved."""
        skip = data
        out = self.cnn1(self.dropout1(F.gelu(self.layer_norm1(data))))
        out = self.cnn2(self.dropout2(F.gelu(self.layer_norm2(out))))
        return out + skip
+
+
class BidirectionalGRU(nn.Module):
    """Single bidirectional GRU layer with pre-LayerNorm, GELU, and dropout.

    ``rnn_dim`` is the input feature size; the bidirectional output has
    ``hidden_size * 2`` features.
    """

    def __init__(
        self,
        rnn_dim: int,
        hidden_size: int,
        dropout: float,
        batch_first: bool,
    ):
        super().__init__()

        self.bi_gru = nn.GRU(
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """data (batch, time, feature) -> (batch, time, hidden_size * 2)"""
        activated = self.dropout(F.gelu(self.layer_norm(data)))
        output, _hidden = self.bi_gru(activated)
        return output
+
+
class SpeechRecognitionModel(nn.Module):
    """Speech Recognition Model Inspired by DeepSpeech 2"""

    def __init__(
        self,
        n_cnn_layers: int,
        n_rnn_layers: int,
        rnn_dim: int,
        n_class: int,
        n_feats: int,
        stride: int = 2,
        dropout: float = 0.1,
    ):
        super(SpeechRecognitionModel, self).__init__()
        n_feats //= 2  # the stride-2 entry conv halves the feature axis
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
        # n_cnn_layers residual blocks, each keeping 32 channels / n_feats features
        res_blocks = [
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ]
        self.rescnn_layers = nn.Sequential(*res_blocks)
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        # first GRU consumes rnn_dim; later ones take the 2*rnn_dim bidirectional output
        gru_stack = [
            BidirectionalGRU(
                rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                hidden_size=rnn_dim,
                dropout=dropout,
                batch_first=i == 0,
            )
            for i in range(n_rnn_layers)
        ]
        self.birnn_layers = nn.Sequential(*gru_stack)
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class),
        )

    def forward(self, data):
        """data (batch, channel, feature, time) -> (batch, time, n_class)"""
        data = self.rescnn_layers(self.cnn(data))
        batch, channels, features, time = data.size()
        data = data.view(batch, channels * features, time)  # (batch, feature, time)
        data = data.transpose(1, 2)  # (batch, time, feature)
        data = self.fully_connected(data)
        return self.classifier(self.birnn_layers(data))
+
+
class IterMeter:
    """keeps track of total iterations

    Minimal counter used to number optimizer steps across epochs.
    (Explicit `object` base dropped — redundant in Python 3.)
    """

    def __init__(self):
        # running count of iterations seen so far
        self.val = 0

    def step(self) -> None:
        """Advance the counter by one iteration."""
        self.val += 1

    def get(self) -> int:
        """Return the current iteration count."""
        return self.val
+
+
def train(
    model,
    device,
    train_loader,
    criterion,
    optimizer,
    scheduler,
    epoch,
    iter_meter,
):
    """Run one training epoch.

    Args:
        model: network mapping spectrograms to per-frame class scores.
        device: torch device (or device string) batches are moved to.
        train_loader: yields (spectrograms, labels, input_lengths, label_lengths)
            tuples; must expose `.dataset` and `len()` for progress reporting.
        criterion: CTC-style loss taking (log_probs, labels, input_lengths, label_lengths).
        optimizer: stepped once per batch.
        scheduler: per-batch LR scheduler (e.g. OneCycleLR), stepped once per batch.
        epoch: current epoch number (logging only).
        iter_meter: object whose .step() is called once per batch.
    """
    model.train()
    data_len = len(train_loader.dataset)
    for batch_idx, _data in enumerate(train_loader):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class), as CTC loss expects

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        optimizer.step()
        scheduler.step()
        iter_meter.step()
        # NOTE(review): batch_idx counts batches while data_len counts samples,
        # so the second condition virtually never fires — presumably the last
        # batch was meant; confirm before changing.
        if batch_idx % 100 == 0 or batch_idx == data_len:
            # BUG FIX: the original f-string used backslash line continuations,
            # which embedded the source indentation into the printed message.
            print(
                f"Train Epoch: {epoch} "
                f"[{batch_idx * len(spectrograms)}/{data_len} "
                f"({100.0 * batch_idx / len(train_loader)}%)]\t"
                f"Loss: {loss.item()}"
            )
+
+
def test(model, device, test_loader, criterion):
    """Evaluate `model` on `test_loader`, printing average loss, CER and WER.

    Relies on the module-level `greedy_decoder`, `cer` and `wer` helpers.
    """
    print("\nevaluating...")
    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with torch.no_grad():
        for _data in test_loader:
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            # accumulate the per-batch mean so the total is the average loss
            test_loss += loss.item() / len(test_loader)

            decoded_preds, decoded_targets = greedy_decoder(
                output.transpose(0, 1), labels, label_lengths
            )
            for j, pred in enumerate(decoded_preds):
                test_cer.append(cer(decoded_targets[j], pred))
                test_wer.append(wer(decoded_targets[j], pred))

    # BUG FIX: guard against an empty loader / no decoded samples, which
    # previously raised ZeroDivisionError.
    avg_cer = sum(test_cer) / len(test_cer) if test_cer else float("nan")
    avg_wer = sum(test_wer) / len(test_wer) if test_wer else float("nan")

    # BUG FIX: backslash continuations inside the f-string leaked source
    # indentation into the printed summary.
    print(
        f"Test set: Average loss: {test_loss}, "
        f"Average CER: {avg_cer} Average WER: {avg_wer}\n"
    )
+
+
def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
    """Runs the training script.

    Args:
        learning_rate: peak learning rate for AdamW / OneCycleLR.
        batch_size: samples per batch for both loaders.
        epochs: number of passes over the training split.
    """
    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        # NOTE(review): n_class=46 but the CTC blank index below is 28 —
        # confirm both match the text transform's alphabet.
        "n_class": 46,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(42)  # fixed seed for reproducibility
    device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
    # device = torch.device("mps")

    # TODO(review): dataset root is a hard-coded local volume; make it configurable.
    # NOTE: training currently uses the "dev" split — presumably a smoke-test
    # setting; verify before a real run.
    train_dataset = MultilingualLibriSpeech(
        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
    )
    test_dataset = MultilingualLibriSpeech(
        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
    )

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=lambda x: data_processing(x, "train"),
        **kwargs,
    )

    # BUG FIX: the evaluation loader previously used the "train" processing
    # mode, applying training-time augmentation to test data, and shuffled
    # the test set for no benefit.
    test_loader = DataLoader(
        test_dataset,
        batch_size=hparams["batch_size"],
        shuffle=False,
        collate_fn=lambda x: data_processing(x, "valid"),
        **kwargs,
    )

    model = SpeechRecognitionModel(
        hparams["n_cnn_layers"],
        hparams["n_rnn_layers"],
        hparams["rnn_dim"],
        hparams["n_class"],
        hparams["n_feats"],
        hparams["stride"],
        hparams["dropout"],
    ).to(device)

    print(
        "Num Model Parameters", sum(param.nelement() for param in model.parameters())
    )

    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
    criterion = nn.CTCLoss(blank=28).to(device)

    # one-cycle LR: step the scheduler every batch, for every epoch
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=hparams["learning_rate"],
        steps_per_epoch=int(len(train_loader)),
        epochs=hparams["epochs"],
        anneal_strategy="linear",
    )

    iter_meter = IterMeter()
    for epoch in range(1, epochs + 1):
        train(
            model,
            device,
            train_loader,
            criterion,
            optimizer,
            scheduler,
            epoch,
            iter_meter,
        )
        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
- print(dataset[1])
# Click entry point; the docstring below doubles as the --help text.
@click.command()
@click.option("--learning-rate", default=1e-3, help="Learning rate")
@click.option("--batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
    """Runs the training script."""
    # NOTE(review): CLI defaults (1e-3 / 1 / 1) differ from run()'s own
    # defaults (5e-4 / 8 / 3) — confirm which set is intended.
    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
if __name__ == "__main__":
    # Direct execution bypasses the click CLI and uses fixed hyperparameters.
    run(learning_rate=5e-4, batch_size=16, epochs=1)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py