diff options
-rw-r--r-- | .gitignore | 7 | ||||
-rw-r--r-- | .vscode/settings.json | 2 | ||||
-rw-r--r-- | Dockerfile | 13 | ||||
-rw-r--r-- | Makefile | 6 | ||||
-rw-r--r-- | mypy.ini | 3 | ||||
-rw-r--r-- | poetry.lock | 510 | ||||
-rw-r--r-- | pyproject.toml | 34 | ||||
-rw-r--r-- | readme.md | 6 | ||||
-rw-r--r-- | swr2_asr/inference_test.py | 74 | ||||
-rw-r--r-- | swr2_asr/loss_scores.py | 185 | ||||
-rw-r--r-- | swr2_asr/train.py | 482 | ||||
-rw-r--r-- | tests/__init__.py | 0 |
12 files changed, 1107 insertions, 215 deletions
@@ -118,6 +118,8 @@ ipython_config.py # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml +.pdm-python +.pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ @@ -162,14 +164,9 @@ dmypy.json # Cython debug symbols cython_debug/ -# linter -**/.ruff - # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ - -data/ diff --git a/.vscode/settings.json b/.vscode/settings.json index 6d5637c..bd8762b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,4 +7,6 @@ }, "black-formatter.importStrategy": "fromEnvironment", "python.analysis.typeCheckingMode": "basic", + "python.linting.pylintEnabled": true, + "python.linting.enabled": true, }
\ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ca7463f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.10 + +# install python poetry +RUN curl -sSL https://install.python-poetry.org | python3 - + +WORKDIR /app + +COPY readme.md mypy.ini poetry.lock pyproject.toml ./ +COPY swr2_asr ./swr2_asr +ENV POETRY_VIRTUALENVS_IN_PROJECT=true +RUN /root/.local/bin/poetry --no-interaction install --without dev + +ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ] @@ -1,6 +1,6 @@ format: - @black . + @poetry run black . lint: - @mypy --strict swr2_asr - @pylint swr2_asr
\ No newline at end of file + @poetry run mypy --strict swr2_asr + @poetry run pylint swr2_asr
\ No newline at end of file @@ -6,3 +6,6 @@ ignore_missing_imports = true [mypy-torch.*] ignore_missing_imports = true + +[mypy-click.*] +ignore_missing_imports = true diff --git a/poetry.lock b/poetry.lock index 9d91798..287e19a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -14,10 +14,7 @@ files = [ [package.dependencies] lazy-object-proxy = ">=1.4.0" typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -wrapt = [ - {version = ">=1.11,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, -] +wrapt = {version = ">=1.11,<2", markers = "python_version < \"3.11\""} [[package]] name = "AudioLoader" @@ -30,9 +27,9 @@ develop = false [package.source] type = "git" -url = "https://github.com/KinWaiCheuk/AudioLoader.git" +url = "https://github.com/marvinborner/AudioLoader.git" reference = "HEAD" -resolved_reference = "c79acea2db7323fab22a0041211208899d0371e2" +resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2" [[package]] name = "black" @@ -80,17 +77,6 @@ jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] [[package]] -name = "cfgv" -version = "3.4.0" -description = "Validate configuration and produce human readable error messages." -optional = false -python-versions = ">=3.8" -files = [ - {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, - {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, -] - -[[package]] name = "click" version = "8.1.7" description = "Composable command line interface toolkit" @@ -105,6 +91,35 @@ files = [ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] +name = "cmake" +version = "3.27.2" +description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" +optional = false +python-versions = "*" +files = [ + {file = "cmake-3.27.2-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:96ac856c4d6b2104408848f0005a8ab2229d4135b171ea9a03e8c33039ede420"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:11fe6129d07982721c5965fd804a4056b8c6e9c4f482ac9e0fe41bb3abc1ab5f"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:f0c64e89e2ea59592980c4fe3821d712fee0e74cf87c2aaec5b3ab9aa809a57c"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ca7650477dff2a1138776b28b79c0e99127be733d3978922e8f87b56a433eed6"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ab2e40fe09e76a7ef67da2bbbf7a4cd1f52db4f1c7b6ccdda2539f918830343a"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:980ee19f12c808cb8ddb56fdcee832501a9f9631799d8b4fc625c0a0b5fb4c55"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:115d30ca0760e3861d9ad6b3288cd11ee72a785b81227da0c1765d3b84e2c009"}, + {file = "cmake-3.27.2-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:efc338c939d6d435890a52458a260bf0942bd8392b648d7532a72c1ec0764e18"}, + {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:7f7438c60ccc01765b67abfb1797787c3b9459d500a804ed70a4cc181bc02204"}, + {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:294f008734267e0eee1574ad1b911bed137bc907ab19d60a618dab4615aa1fca"}, + {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:197a34dc62ee149ced343545fac67e5a30b93fda65250b065726f86ce92bdada"}, + {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:afb46ad883b174fb64347802ba5878423551dbd5847bb64669c39a5957c06eb7"}, + {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:83611ffd155e270a6b13bbf0cfd4e8688ebda634f448aa2e3734006c745bf33f"}, + {file = "cmake-3.27.2-py2.py3-none-win32.whl", hash = "sha256:53e12deb893da935e236f93accd47dbe2806620cd7654986234dc4487cc49652"}, + {file = "cmake-3.27.2-py2.py3-none-win_amd64.whl", hash = "sha256:611f9722c68c40352d38a6c01960ab038c3d0419e7aee3bf18f95b23031e0dfe"}, + {file = "cmake-3.27.2-py2.py3-none-win_arm64.whl", hash = "sha256:30620326b51ac2ce0d8f476747af6367a7ea21075c4d065fad9443904b07476a"}, + {file = "cmake-3.27.2.tar.gz", hash = "sha256:7cd6e2d7d5a1125f8c26c4f65214f8c942e3f276f98c16cb62ae382c35609f25"}, +] + +[package.extras] +test = ["coverage (>=4.2)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"] + +[[package]] name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." @@ -130,17 +145,6 @@ files = [ graph = ["objgraph (>=1.7.2)"] [[package]] -name = "distlib" -version = "0.3.7" -description = "Distribution utilities" -optional = false -python-versions = "*" -files = [ - {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, - {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, -] - -[[package]] name = "filelock" version = "3.12.2" description = "A platform independent file lock." @@ -156,20 +160,6 @@ docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] [[package]] -name = "identify" -version = "2.5.26" -description = "File identification library for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "identify-2.5.26-py2.py3-none-any.whl", hash = "sha256:c22a8ead0d4ca11f1edd6c9418c3220669b3b7533ada0a0ffa6cc0ef85cf9b54"}, - {file = "identify-2.5.26.tar.gz", hash = "sha256:7243800bce2f58404ed41b7c002e53d4d22bcf3ae1b7900c2d7aefd95394bf7f"}, -] - -[package.extras] -license = ["ukkonen"] - -[[package]] name = "isort" version = "5.12.0" description = "A Python utility / library to sort Python imports." @@ -249,6 +239,16 @@ files = [ ] [[package]] +name = "lit" +version = "16.0.6" +description = "A Software Testing Tool" +optional = false +python-versions = "*" +files = [ + {file = "lit-16.0.6.tar.gz", hash = "sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a"}, +] + +[[package]] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." @@ -319,6 +319,33 @@ files = [ ] [[package]] +name = "mido" +version = "1.3.0" +description = "MIDI Objects for Python" +optional = false +python-versions = "~=3.7" +files = [ + {file = "mido-1.3.0-py3-none-any.whl", hash = "sha256:a710a274c8a1a3fd481f526d174a16e117b5e58d719ad92937a67fb6167a9432"}, + {file = "mido-1.3.0.tar.gz", hash = "sha256:84282e3ace34bca3f984220db2dbcb98245cfeafb854260c02e000750dca86aa"}, +] + +[package.dependencies] +packaging = ">=23.1,<24.0" + +[package.extras] +build-docs = ["sphinx (>=4.3.2,<4.4.0)", "sphinx-rtd-theme (>=1.2.2,<1.3.0)"] +check-manifest = ["check-manifest (>=0.49)"] +dev = ["mido[build-docs]", "mido[check-manifest]", "mido[lint-code]", "mido[lint-reuse]", "mido[release]", "mido[test]"] +lint-code = ["flake8 (>=5.0.4,<5.1.0)"] +lint-reuse = ["reuse (>=1.1.2,<1.2.0)"] +ports-all = ["mido[ports-pygame]", "mido[ports-rtmidi-python]", "mido[ports-rtmidi]"] +ports-pygame = ["PyGame (>=2.5,<3.0)"] +ports-rtmidi = ["python-rtmidi (>=1.5.4,<1.6.0)"] +ports-rtmidi-python = ["rtmidi-python (>=0.2.2,<0.3.0)"] +release = ["twine (>=4.0.2,<4.1.0)"] +test-code = ["pytest (>=7.4.0,<7.5.0)"] + +[[package]] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" @@ -411,20 +438,6 @@ extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.1 test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] -name = "nodeenv" -version = "1.8.0" -description = "Node.js virtual environment builder" -optional = false -python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" -files = [ - {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, - {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, -] - -[package.dependencies] -setuptools = "*" - -[[package]] name = "numpy" version = "1.25.2" description = "Fundamental package for array computing in Python" @@ -459,6 +472,164 @@ files = [ ] [[package]] +name = "nvidia-cublas-cu11" +version = "11.10.3.66" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, + {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cuda-cupti-cu11" +version = "11.7.101" +description = "CUDA profiling tools runtime libs." +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:e0cfd9854e1f2edaa36ca20d21cd0bdd5dcfca4e3b9e130a082e05b33b6c5895"}, + {file = "nvidia_cuda_cupti_cu11-11.7.101-py3-none-win_amd64.whl", hash = "sha256:7cc5b8f91ae5e1389c3c0ad8866b3b016a175e827ea8f162a672990a402ab2b0"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cuda-nvrtc-cu11" +version = "11.7.99" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"}, + {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cuda-runtime-cu11" +version = "11.7.99" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, + {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cudnn-cu11" +version = "8.5.0.96" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, + {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cufft-cu11" +version = "10.9.0.58" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"}, + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"}, +] + +[[package]] +name = "nvidia-curand-cu11" +version = "10.2.10.91" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:eecb269c970fa599a2660c9232fa46aaccbf90d9170b96c462e13bcb4d129e2c"}, + {file = "nvidia_curand_cu11-10.2.10.91-py3-none-win_amd64.whl", hash = "sha256:f742052af0e1e75523bde18895a9ed016ecf1e5aa0ecddfcc3658fd11a1ff417"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cusolver-cu11" +version = "11.4.0.1" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:72fa7261d755ed55c0074960df5904b65e2326f7adce364cbe4945063c1be412"}, + {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:700b781bfefd57d161443aff9ace1878584b93e0b2cfef3d6e9296d96febbf99"}, + {file = "nvidia_cusolver_cu11-11.4.0.1-py3-none-win_amd64.whl", hash = "sha256:00f70b256add65f8c1eb3b6a65308795a93e7740f6df9e273eccbba770d370c4"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-cusparse-cu11" +version = "11.7.4.91" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:a3389de714db63321aa11fbec3919271f415ef19fda58aed7f2ede488c32733d"}, + {file = "nvidia_cusparse_cu11-11.7.4.91-py3-none-win_amd64.whl", hash = "sha256:304a01599534f5186a8ed1c3756879282c72c118bc77dd890dc1ff868cad25b9"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] +name = "nvidia-nccl-cu11" +version = "2.14.3" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:5e5534257d1284b8e825bc3a182c6f06acd6eb405e9f89d49340e98cd8f136eb"}, +] + +[[package]] +name = "nvidia-nvtx-cu11" +version = "11.7.91" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl", hash = "sha256:b22c64eee426a62fc00952b507d6d29cf62b4c9df7a480fcc417e540e05fd5ac"}, + {file = "nvidia_nvtx_cu11-11.7.91-py3-none-win_amd64.whl", hash = "sha256:dfd7fcb2a91742513027d63a26b757f38dd8b07fecac282c4d132a9d373ff064"}, +] + +[package.dependencies] +setuptools = "*" +wheel = "*" + +[[package]] name = "packaging" version = "23.1" description = "Core utilities for Python packages" @@ -496,24 +667,6 @@ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx- test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] [[package]] -name = "pre-commit" -version = "3.3.3" -description = "A framework for managing and maintaining multi-language pre-commit hooks." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pre_commit-3.3.3-py2.py3-none-any.whl", hash = "sha256:10badb65d6a38caff29703362271d7dca483d01da88f9d7e05d0b97171c136cb"}, - {file = "pre_commit-3.3.3.tar.gz", hash = "sha256:a2256f489cd913d575c145132ae196fe335da32d91a8294b7afe6622335dd023"}, -] - -[package.dependencies] -cfgv = ">=2.0.0" -identify = ">=1.0.0" -nodeenv = ">=0.11.1" -pyyaml = ">=5.1" -virtualenv = ">=20.10.0" - -[[package]] name = "pylint" version = "2.17.5" description = "python code static checker" @@ -527,10 +680,7 @@ files = [ [package.dependencies] astroid = ">=2.15.6,<=2.17.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, -] +dill = {version = ">=0.2", markers = "python_version < \"3.11\""} isort = ">=4.2.5,<6" mccabe = ">=0.6,<0.8" platformdirs = ">=2.2.0" @@ -542,67 +692,18 @@ spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] [[package]] -name = "pyyaml" -version = "6.0.1" -description = "YAML parser and emitter for Python" -optional = false -python-versions = ">=3.6" -files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, -] - -[[package]] name = "setuptools" -version = "68.1.0" +version = "68.1.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.1.0-py3-none-any.whl", hash = "sha256:e13e1b0bc760e9b0127eda042845999b2f913e12437046e663b833aa96d89715"}, - {file = "setuptools-68.1.0.tar.gz", hash = "sha256:d59c97e7b774979a5ccb96388efc9eb65518004537e85d52e81eaee89ab6dd91"}, + {file = "setuptools-68.1.2-py3-none-any.whl", hash = "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b"}, + {file = "setuptools-68.1.2.tar.gz", hash = "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5,<=7.1.2)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] @@ -644,60 +745,90 @@ files = [ [[package]] name = "torch" -version = "2.0.1+cpu" +version = "2.0.0" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:fec257249ba014c68629a1994b0c6e7356e20e1afc77a87b9941a40e5095285d"}, - {file = "torch-2.0.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ca88b499973c4c027e32c4960bf20911d7e984bd0c55cda181dc643559f3d93f"}, - {file = "torch-2.0.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:274d4acf486ef50ce1066ffe9d500beabb32bde69db93e3b71d0892dd148956c"}, - {file = "torch-2.0.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:e2603310bdff4b099c4c41ae132192fc0d6b00932ae2621d52d87218291864be"}, - {file = "torch-2.0.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8046f49deae5a3d219b9f6059a1f478ae321f232e660249355a8bf6dcaa810c1"}, - {file = "torch-2.0.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:2ac4382ff090035f9045b18afe5763e2865dd35f2d661c02e51f658d95c8065a"}, - {file = "torch-2.0.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:73482a223d577407c45685fde9d2a74ba42f0d8d9f6e1e95c08071dc55c47d7b"}, - {file = "torch-2.0.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:f263f8e908288427ae81441fef540377f61e339a27632b1bbe33cf78292fdaea"}, + {file = "torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c9090bda7d2eeeecd74f51b721420dbeb44f838d4536cc1b284e879417e3064a"}, + {file = "torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bd42db2a48a20574d2c33489e120e9f32789c4dc13c514b0c44272972d14a2d7"}, + {file = "torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:8969aa8375bcbc0c2993e7ede0a7f889df9515f18b9b548433f412affed478d9"}, + {file = "torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ab2da16567cb55b67ae39e32d520d68ec736191d88ac79526ca5874754c32203"}, + {file = "torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7a9319a67294ef02459a19738bbfa8727bb5307b822dadd708bc2ccf6c901aca"}, + {file = "torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9f01fe1f6263f31bd04e1757946fd63ad531ae37f28bb2dbf66f5c826ee089f4"}, + {file = "torch-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:527f4ae68df7b8301ee6b1158ca56350282ea633686537b30dbb5d7b4a52622a"}, + {file = "torch-2.0.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:ce9b5a49bd513dff7950a5a07d6e26594dd51989cee05ba388b03e8e366fd5d5"}, + {file = "torch-2.0.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:53e1c33c6896583cdb9a583693e22e99266444c4a43392dddc562640d39e542b"}, + {file = "torch-2.0.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:09651bff72e439d004c991f15add0c397c66f98ab36fe60d5514b44e4da722e8"}, + {file = "torch-2.0.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d439aec349c98f12819e8564b8c54008e4613dd4428582af0e6e14c24ca85870"}, + {file = "torch-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2802f84f021907deee7e9470ed10c0e78af7457ac9a08a6cd7d55adef835fede"}, + {file = "torch-2.0.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:01858620f25f25e7a9ec4b547ff38e5e27c92d38ec4ccba9cfbfb31d7071ed9c"}, + {file = "torch-2.0.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:9a2e53b5783ef5896a6af338b36d782f28e83c8ddfc2ac44b67b066d9d76f498"}, + {file = "torch-2.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ec5fff2447663e369682838ff0f82187b4d846057ef4d119a8dea7772a0b17dd"}, + {file = "torch-2.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:11b0384fe3c18c01b8fc5992e70fc519cde65e44c51cc87be1838c1803daf42f"}, + {file = "torch-2.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:e54846aa63855298cfb1195487f032e413e7ac9cbfa978fda32354cc39551475"}, + {file = "torch-2.0.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:cc788cbbbbc6eb4c90e52c550efd067586c2693092cf367c135b34893a64ae78"}, + {file = "torch-2.0.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:d292640f0fd72b7a31b2a6e3b635eb5065fcbedd4478f9cad1a1e7a9ec861d35"}, + {file = "torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6befaad784004b7af357e3d87fa0863c1f642866291f12a4c2af2de435e8ac5c"}, + {file = "torch-2.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a83b26bd6ae36fbf5fee3d56973d9816e2002e8a3b7d9205531167c28aaa38a7"}, + {file = "torch-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c7e67195e1c3e33da53954b026e89a8e1ff3bc1aeb9eb32b677172d4a9b5dcbf"}, + {file = "torch-2.0.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6e0b97beb037a165669c312591f242382e9109a240e20054d5a5782d9236cad0"}, + {file = "torch-2.0.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:297a4919aff1c0f98a58ebe969200f71350a1d4d4f986dbfd60c02ffce780e99"}, ] [package.dependencies] filelock = "*" jinja2 = "*" networkx = "*" +nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" +triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} typing-extensions = "*" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -[package.source] -type = "legacy" -url = "https://download.pytorch.org/whl/cpu" -reference = "pytorch-cpu" - [[package]] name = "torchaudio" -version = "2.0.2+cpu" +version = "2.0.1" description = "An audio package for PyTorch" optional = false python-versions = "*" files = [ - {file = "torchaudio-2.0.2+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:80eccef075f9e356f7a4ad27c57614c726906a30617424e2d86a08de179feeff"}, - {file = "torchaudio-2.0.2+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:514995f6418d193caf1a2e06c912f603959ec6d8557b48ef52a9ceb8e5d180d1"}, - {file = "torchaudio-2.0.2+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:eb589737c30139c946b4c4f3a73307a6286c5402b6e5273dd5cb8768d0e73b3f"}, - {file = "torchaudio-2.0.2+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:a93c8e9ff28959b2e8b9a1347b34c5f94f33f617bc96fe17011b0d8cb5fb377f"}, - {file = "torchaudio-2.0.2+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:5196d7ed34863deaa28434ad02a043b01cfbd0c42b58d61ca56825577aeca74e"}, - {file = "torchaudio-2.0.2+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:b29d34bb39b8e613c562d8c04bd4108be6ae20221ee043e29b81bbef95f3f5bf"}, - {file = "torchaudio-2.0.2+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:30fc6926b51892cda933e98616fcdbc1f53084680210809cbc03c301d585ed9c"}, - {file = "torchaudio-2.0.2+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:aa7425184157a0356fe08deb74a517f5a0e1f27722f362a5ac3089904bf17001"}, + {file = "torchaudio-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5d21ebbb55e7040d418d5062b0e882f9660d68b477b38fd436fa6c92ccbb52a"}, + {file = "torchaudio-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dbcd93b29d71a2f500f36a34ea5e467f510f773da85322098e6bdd8c9dc9948"}, + {file = "torchaudio-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5fdaba10ff06d098d603d9eb8d2ff541c3f3fe28ba178a78787190cec0d5187f"}, + {file = "torchaudio-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:6419199c773c5045c594ff950d5e5dbbfa6c830892ec09721d4ed8704b702bfd"}, + {file = "torchaudio-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:a5c81e480e5dcdcba065af1e3e31678ac29518991f00260094d37a39e63d76e5"}, + {file = "torchaudio-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e2a047675493c0aa258fec621ef40e8b01abe3d8dbc872152e4b5998418aa3c5"}, + {file = "torchaudio-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:91a28e587f708a03320eddbcc4a7dd1ad7150b3d4846b6c1557d85cc89a8d06c"}, + {file = "torchaudio-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ba7740d98f601218ff667598ab3d9dab5f326878374fcb52d656f4ff033b9e96"}, + {file = "torchaudio-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f401b192921c8b77cc5e478ede589b256dba463f1cee91172ecb376fea45a288"}, + {file = "torchaudio-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:0ef6754cf75ca5fd5117cb6243a6cf33552d67e9af0075aa6954b2c34bbf1036"}, + {file = "torchaudio-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:022ca1baa4bb819b78343bd47b57ff6dc6f9fc19fa4ef269946aadf7e62db3c0"}, + {file = "torchaudio-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a153ad5cdb62de8ec9fd1360a0d080bbaf39d578ae04e788db211571e675b7e0"}, + {file = "torchaudio-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:aa7897774ab4156d0b72f7078b823ebc1371ee24c50df965447782889552367a"}, + {file = "torchaudio-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:48d133593cddfe0424a350b566d54065bf6fe7469654de7add2f11b3ef03c5d9"}, + {file = "torchaudio-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ac65eb067feee435debba81adfe8337fa007a06de6508c0d80261c5562b6d098"}, + {file = "torchaudio-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e3c6c8f9ea9f0e2df7a0b9375b0dcf955906e38fc12fab542b72a861564af8e7"}, + {file = "torchaudio-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d0cf0779a334ec1861e9fa28bceb66a633c42e8f6b3322e2e37ff9f20d0ae81"}, + {file = "torchaudio-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ab7acd2b5d351a2c65e4d935bb90b9256382bed93df57ee177bdbbe31c3cc984"}, + {file = "torchaudio-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:77b953fd7278773269a9477315b8998ae7e5011cc4b2907e0df18162327482f1"}, + {file = "torchaudio-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:c01bcea9d4c4a6616452e6cbd44d55913d8e6dee58191b925f35d46a2bf6e71b"}, ] [package.dependencies] -torch = "2.0.1" - -[package.source] -type = "legacy" -url = "https://download.pytorch.org/whl/cpu" -reference = "pytorch-cpu" +torch = "2.0.0" [[package]] name = "tqdm" @@ -720,6 +851,43 @@ slack = ["slack-sdk"] telegram = ["requests"] [[package]] +name = "triton" +version = "2.0.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, + {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, + {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, + {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, + {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, + {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, + {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, + {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, + {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, + {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, + {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, + {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, + {file = "triton-2.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fedce6a381901b1547e0e7e1f2546e4f65dca6d91e2d8a7305a2d1f5551895be"}, + {file = "triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75834f27926eab6c7f00ce73aaf1ab5bfb9bec6eb57ab7c0bfc0a23fac803b4c"}, + {file = "triton-2.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0117722f8c2b579cd429e0bee80f7731ae05f63fe8e9414acd9a679885fcbf42"}, + {file = "triton-2.0.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcd9be5d0c2e45d2b7e6ddc6da20112b6862d69741576f9c3dbaf941d745ecae"}, + {file = "triton-2.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42a0d2c3fc2eab4ba71384f2e785fbfd47aa41ae05fa58bf12cb31dcbd0aeceb"}, + {file = "triton-2.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52c47b72c72693198163ece9d90a721299e4fb3b8e24fd13141e384ad952724f"}, +] + +[package.dependencies] +cmake = "*" +filelock = "*" +lit = "*" +torch = "*" + +[package.extras] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + +[[package]] name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" @@ -731,24 +899,18 @@ files = [ ] [[package]] -name = "virtualenv" -version = "20.24.3" -description = "Virtual Python Environment builder" +name = "wheel" +version = "0.41.1" +description = "A built-package format for Python" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.24.3-py3-none-any.whl", hash = "sha256:95a6e9398b4967fbcb5fef2acec5efaf9aa4972049d9ae41f95e0972a683fd02"}, - {file = "virtualenv-20.24.3.tar.gz", hash = "sha256:e5c3b4ce817b0b328af041506a2a299418c98747c4b1e68cb7527e74ced23efc"}, + {file = "wheel-0.41.1-py3-none-any.whl", hash = "sha256:473219bd4cbedc62cea0cb309089b593e47c15c4a2531015f94e4e3b9a0f6981"}, + {file = "wheel-0.41.1.tar.gz", hash = "sha256:12b911f083e876e10c595779709f8a88a59f45aacc646492a67fe9ef796c1b47"}, ] -[package.dependencies] -distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<4" - [package.extras] -docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["pytest (>=6.0.0)", "setuptools (>=65)"] [[package]] name = "wrapt" @@ -836,5 +998,5 @@ files = [ [metadata] lock-version = "2.0" -python-versions = "^3.10" -content-hash = "e7b0344d7d2f66cddf80ac9fdbc63b839297b7443c043f960805738d54a79d43" +python-versions = "~3.10" +content-hash = "98a9b21411812f0514cbac138756a635b4f9283835791153f467c7f3fe1f4fd1" diff --git a/pyproject.toml b/pyproject.toml index 606633b..8490aa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,40 +1,28 @@ [tool.poetry] -name = "swr2_asr" +name = "swr2-asr" version = "0.1.0" -description = "Automatic speech recognition with pytorch for SWR2" -authors = ["Philipp Merkel <philippmerkel@outlook.com>", "Marvin Borner <git@marvinborner.de>", "Valentin Schmidt <>", "Silja Kasper <>"] +description = "" +authors = ["Philipp Merkel <philippmerkel@outlook.com>"] license = "MIT" readme = "readme.md" packages = [{include = "swr2_asr"}] [tool.poetry.dependencies] -python = "^3.10" -numpy = "^1.25.2" -click = "^8.1.6" -audioloader = {git = "https://github.com/KinWaiCheuk/AudioLoader.git"} +python = "~3.10" +torch = "2.0.0" +torchaudio = "2.0.1" +audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"} tqdm = "^4.66.1" -torch = {version = "^2.0.1+cpu", source = "pytorch-cpu"} -torchaudio = {version = "^2.0.2+cpu", source = "pytorch-cpu"} +numpy = "^1.25.2" +mido = "^1.3.0" [tool.poetry.group.dev.dependencies] black = "^23.7.0" -pylint = "^2.17.5" mypy = "^1.5.1" -pre-commit = "^3.3.3" - -[[tool.poetry.source]] -name = "pytorch-cpu" -url = "https://download.pytorch.org/whl/cpu" -priority = "explicit" - - -[[tool.poetry.source]] -name = "pytorch-gpu" -url = "https://download.pytorch.org/whl/cu118" -priority = "explicit" +pylint = "^2.17.5" [tool.poetry.scripts] -train = "swr2_asr.train:main" +train = "swr2_asr.train:run_cli" [build-system] requires = ["poetry-core"] @@ -5,7 +5,7 @@ recogniton 2 (SWR2) in the summer term 2023. # Installation ``` -pip install -r requirements.txt +poetry install ``` # Usage @@ -14,13 +14,13 @@ pip install -r requirements.txt Train using the provided train script: - poetry run train --data PATH/TO/DATA --lr 0.01 + poetry run train ## Evaluation ## Inference - poetry run recognize --data PATH/TO/FILE + poetry run recognize ## CI diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py new file mode 100644 index 0000000..a6b0010 --- /dev/null +++ b/swr2_asr/inference_test.py @@ -0,0 +1,74 @@ +"""Training script for the ASR model.""" +from AudioLoader.speech.mls import MultilingualLibriSpeech +import torch +import torchaudio +import torchaudio.functional as F + + +class GreedyCTCDecoder(torch.nn.Module): + def __init__(self, labels, blank=0) -> None: + super().__init__() + self.labels = labels + self.blank = blank + + def forward(self, emission: torch.Tensor) -> str: + """Given a sequence emission over labels, get the best path string + Args: + emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`. + + Returns: + str: The resulting transcript + """ + indices = torch.argmax(emission, dim=-1) # [num_seq,] + indices = torch.unique_consecutive(indices, dim=-1) + indices = [i for i in indices if i != self.blank] + return "".join([self.labels[i] for i in indices]) + + +def main() -> None: + """Main function.""" + # choose between cuda, cpu and mps devices + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "mps" + device = torch.device(device) + + torch.random.manual_seed(42) + + bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H + + print(f"Sample rate (model): {bundle.sample_rate}") + print(f"Labels (model): {bundle.get_labels()}") + + model = bundle.get_model().to(device) + + print(model.__class__) + + # only do all things for one single sample + dataset = MultilingualLibriSpeech( + "data", "mls_german_opus", split="train", download=True + ) + + print(dataset[0]) + + # load waveforms and sample rate from dataset + waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"] + + if sample_rate != bundle.sample_rate: + waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate)) + + waveform.to(device) + + with torch.inference_mode(): + features, _ = model.extract_features(waveform) + + with torch.inference_mode(): + emission, _ = model(waveform) + + decoder = GreedyCTCDecoder(labels=bundle.get_labels()) + transcript = decoder(emission[0]) + + print(transcript) + + +if __name__ == "__main__": + main() diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py new file mode 100644 index 0000000..977462d --- /dev/null +++ b/swr2_asr/loss_scores.py @@ -0,0 +1,185 @@ +import numpy as np + + +def avg_wer(wer_scores, combined_ref_len): + return float(sum(wer_scores)) / float(combined_ref_len) + + +def _levenshtein_distance(ref, hyp): + """Levenshtein distance is a string metric for measuring the difference + between two sequences. Informally, the levenshtein disctance is defined as + the minimum number of single-character edits (substitutions, insertions or + deletions) required to change one word into the other. We can naturally + extend the edits to word level when calculate levenshtein disctance for + two sentences. + """ + m = len(ref) + n = len(hyp) + + # special case + if ref == hyp: + return 0 + if m == 0: + return n + if n == 0: + return m + + if m < n: + ref, hyp = hyp, ref + m, n = n, m + + # use O(min(m, n)) space + distance = np.zeros((2, n + 1), dtype=np.int32) + + # initialize distance matrix + for j in range(0, n + 1): + distance[0][j] = j + + # calculate levenshtein distance + for i in range(1, m + 1): + prev_row_idx = (i - 1) % 2 + cur_row_idx = i % 2 + distance[cur_row_idx][0] = i + for j in range(1, n + 1): + if ref[i - 1] == hyp[j - 1]: + distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] + else: + s_num = distance[prev_row_idx][j - 1] + 1 + i_num = distance[cur_row_idx][j - 1] + 1 + d_num = distance[prev_row_idx][j] + 1 + distance[cur_row_idx][j] = min(s_num, i_num, d_num) + + return distance[m % 2][n] + + +def word_errors( + reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " +): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in word-level. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Levenshtein distance and word number of reference sentence. + :rtype: list + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = reference.split(delimiter) + hyp_words = hypothesis.split(delimiter) + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + return float(edit_distance), len(ref_words) + + +def char_errors( + reference: str, + hypothesis: str, + ignore_case: bool = False, + remove_space: bool = False, +): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in char-level. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param remove_space: Whether remove internal space characters + :type remove_space: bool + :return: Levenshtein distance and length of reference sentence. + :rtype: list + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + join_char = " " + if remove_space: + join_char = "" + + reference = join_char.join(filter(None, reference.split(" "))) + hypothesis = join_char.join(filter(None, hypothesis.split(" "))) + + edit_distance = _levenshtein_distance(reference, hypothesis) + return float(edit_distance), len(reference) + + +def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "): + """Calculate word error rate (WER). WER compares reference text and + hypothesis text in word-level. WER is defined as: + .. math:: + WER = (Sw + Dw + Iw) / Nw + where + .. code-block:: text + Sw is the number of words subsituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + We can use levenshtein distance to calculate WER. Please draw an attention + that empty items will be removed when splitting sentences by delimiter. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Word error rate. + :rtype: float + :raises ValueError: If word number of reference is zero. + """ + edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter) + + if ref_len == 0: + raise ValueError("Reference's word number should be greater than 0.") + + wer = float(edit_distance) / ref_len + return wer + + +def cer(reference, hypothesis, ignore_case=False, remove_space=False): + """Calculate charactor error rate (CER). CER compares reference text and + hypothesis text in char-level. CER is defined as: + .. math:: + CER = (Sc + Dc + Ic) / Nc + where + .. code-block:: text + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted + Nc is the number of characters in the reference + We can use levenshtein distance to calculate CER. Chinese input should be + encoded to unicode. Please draw an attention that the leading and tailing + space characters will be truncated and multiple consecutive space + characters in a sentence will be replaced by one space character. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param remove_space: Whether remove internal space characters + :type remove_space: bool + :return: Character error rate. + :rtype: float + :raises ValueError: If the reference length is zero. + """ + edit_distance, ref_len = char_errors( + reference, hypothesis, ignore_case, remove_space + ) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + cer = float(edit_distance) / ref_len + return cer diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 7a9ffec..29f9372 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,16 +1,484 @@ """Training script for the ASR model.""" -import os from AudioLoader.speech.mls import MultilingualLibriSpeech +import click +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torchaudio +from .loss_scores import cer, wer -def main() -> None: - """Main function.""" - dataset = MultilingualLibriSpeech( - "data", "mls_polish_opus", split="train", download=(not os.path.isdir("data")) +class TextTransform: + """Maps characters to integers and vice versa""" + + def __init__(self): + char_map_str = """ + ' 0 + <SPACE> 1 + a 2 + b 3 + c 4 + d 5 + e 6 + f 7 + g 8 + h 9 + i 10 + j 11 + k 12 + l 13 + m 14 + n 15 + o 16 + p 17 + q 18 + r 19 + s 20 + t 21 + u 22 + v 23 + w 24 + x 25 + y 26 + z 27 + ä 28 + ö 29 + ü 30 + ß 31 + - 32 + é 33 + è 34 + à 35 + ù 36 + ç 37 + â 38 + ê 39 + î 40 + ô 41 + û 42 + ë 43 + ï 44 + ü 45 + """ + self.char_map = {} + self.index_map = {} + for line in char_map_str.strip().split("\n"): + char, index = line.split() + self.char_map[char] = int(index) + self.index_map[int(index)] = char + self.index_map[1] = " " + + def text_to_int(self, text): + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for char in text: + if char == " ": + mapped_char = self.char_map["<SPACE>"] + else: + mapped_char = self.char_map[char] + int_sequence.append(mapped_char) + return int_sequence + + def int_to_text(self, labels): + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("<SPACE>", " ") + + +train_audio_transforms = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) + +valid_audio_transforms = torchaudio.transforms.MelSpectrogram() + +text_transform = TextTransform() + + +def data_processing(data, data_type="train"): + """Return the spectrograms, labels, and their lengths.""" + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for sample in data: + if data_type == "train": + spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) + elif data_type == "valid": + spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) + else: + raise ValueError("data_type should be train or valid") + spectrograms.append(spec) + label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + .unsqueeze(1) + .transpose(2, 3) ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths + + +def greedy_decoder( + output, labels, label_lengths, blank_label=28, collapse_repeated=True +): + """Greedily decode a sequence.""" + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append( + text_transform.int_to_text(labels[i][: label_lengths[i]].tolist()) + ) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(text_transform.int_to_text(decode)) + return decodes, targets + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super(CNNLayerNorm, self).__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) + data = self.layer_norm(data) + return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super(ResidualCNN, self).__init__() + + self.cnn1 = nn.Conv2d( + in_channels, out_channels, kernel, stride, padding=kernel // 2 + ) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + residual = data # (batch, channel, feature, time) + data = self.layer_norm1(data) + data = F.gelu(data) + data = self.dropout1(data) + data = self.cnn1(data) + data = self.layer_norm2(data) + data = F.gelu(data) + data = self.dropout2(data) + data = self.cnn2(data) + data += residual + return data # (batch, channel, feature, time) + + +class BidirectionalGRU(nn.Module): + """BIdirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super(BidirectionalGRU, self).__init__() + + self.bi_gru = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + """data (batch, time, feature)""" + data = self.layer_norm(data) + data = F.gelu(data) + data = self.dropout(data) + data, _ = self.bi_gru(data) + return data + + +class SpeechRecognitionModel(nn.Module): + """Speech Recognition Model Inspired by DeepSpeech 2""" + + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super(SpeechRecognitionModel, self).__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN( + 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats + ) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, data): + """data (batch, channel, feature, time)""" + data = self.cnn(data) + data = self.rescnn_layers(data) + sizes = data.size() + data = data.view( + sizes[0], sizes[1] * sizes[2], sizes[3] + ) # (batch, feature, time) + data = data.transpose(1, 2) # (batch, time, feature) + data = self.fully_connected(data) + data = self.birnn_layers(data) + data = self.classifier(data) + return data + + +class IterMeter(object): + """keeps track of total iterations""" + + def __init__(self): + self.val = 0 + + def step(self): + """step""" + self.val += 1 + + def get(self): + """get""" + return self.val + + +def train( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, +): + """Train""" + model.train() + data_len = len(train_loader.dataset) + for batch_idx, _data in enumerate(train_loader): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + optimizer.zero_grad() + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + loss.backward() + + optimizer.step() + scheduler.step() + iter_meter.step() + if batch_idx % 100 == 0 or batch_idx == data_len: + print( + f"Train Epoch: \ + {epoch} \ + [{batch_idx * len(spectrograms)}/{data_len} \ + ({100.0 * batch_idx / len(train_loader)}%)]\t \ + Loss: {loss.item()}" + ) + + +def test(model, device, test_loader, criterion): + """Test""" + print("\nevaluating...") + model.eval() + test_loss = 0 + test_cer, test_wer = [], [] + with torch.no_grad(): + for _data in test_loader: + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + test_loss += loss.item() / len(test_loader) + + decoded_preds, decoded_targets = greedy_decoder( + output.transpose(0, 1), labels, label_lengths + ) + for j, pred in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], pred)) + test_wer.append(wer(decoded_targets[j], pred)) + + avg_cer = sum(test_cer) / len(test_cer) + avg_wer = sum(test_wer) / len(test_wer) + + print( + f"Test set: Average loss:\ + {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + ) + + +def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: + """Runs the training script.""" + hparams = { + "n_cnn_layers": 3, + "n_rnn_layers": 5, + "rnn_dim": 512, + "n_class": 46, + "n_feats": 128, + "stride": 2, + "dropout": 0.1, + "learning_rate": learning_rate, + "batch_size": batch_size, + "epochs": epochs, + } + + use_cuda = torch.cuda.is_available() + torch.manual_seed(42) + device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member + # device = torch.device("mps") + + train_dataset = MultilingualLibriSpeech( + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False + ) + test_dataset = MultilingualLibriSpeech( + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False + ) + + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + + train_loader = DataLoader( + train_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + test_loader = DataLoader( + test_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + model = SpeechRecognitionModel( + hparams["n_cnn_layers"], + hparams["n_rnn_layers"], + hparams["rnn_dim"], + hparams["n_class"], + hparams["n_feats"], + hparams["stride"], + hparams["dropout"], + ).to(device) + + print( + "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) + ) + + optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + criterion = nn.CTCLoss(blank=28).to(device) + + scheduler = optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=hparams["learning_rate"], + steps_per_epoch=int(len(train_loader)), + epochs=hparams["epochs"], + anneal_strategy="linear", + ) + + iter_meter = IterMeter() + for epoch in range(1, epochs + 1): + train( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) + test(model=model, device=device, test_loader=test_loader, criterion=criterion) + - print(dataset[1]) +@click.command() +@click.option("--learning-rate", default=1e-3, help="Learning rate") +@click.option("--batch_size", default=1, help="Batch size") +@click.option("--epochs", default=1, help="Number of epochs") +def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None: + """Runs the training script.""" + run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs) if __name__ == "__main__": - main() + run(learning_rate=5e-4, batch_size=16, epochs=1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py |