From 9ca17d8a83369257f4cc42c963e25baf35a28f8f Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 20:51:44 +0200 Subject: oopsie, messed up deps --- poetry.lock | 108 +++++++++++++++++++++++++++--------------------------------- 1 file changed, 49 insertions(+), 59 deletions(-) (limited to 'poetry.lock') diff --git a/poetry.lock b/poetry.lock index fcac817..3901b8c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "astroid" @@ -21,33 +21,33 @@ wrapt = [ [[package]] name = "black" -version = "23.7.0" +version = "23.9.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, - {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, - {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, - {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, - {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, - {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, - {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, - {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, - {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, - {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, - {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, - {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, - {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, - {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, - {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"}, + {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"}, + {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"}, + {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"}, + {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"}, + {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"}, + {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"}, + {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"}, + {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"}, + {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"}, + {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"}, ] [package.dependencies] @@ -57,6 +57,7 @@ packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] @@ -80,28 +81,28 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cmake" -version = "3.27.2" +version = "3.27.4.1" description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" optional = false python-versions = "*" files = [ - {file = "cmake-3.27.2-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:96ac856c4d6b2104408848f0005a8ab2229d4135b171ea9a03e8c33039ede420"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:11fe6129d07982721c5965fd804a4056b8c6e9c4f482ac9e0fe41bb3abc1ab5f"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:f0c64e89e2ea59592980c4fe3821d712fee0e74cf87c2aaec5b3ab9aa809a57c"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ca7650477dff2a1138776b28b79c0e99127be733d3978922e8f87b56a433eed6"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ab2e40fe09e76a7ef67da2bbbf7a4cd1f52db4f1c7b6ccdda2539f918830343a"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:980ee19f12c808cb8ddb56fdcee832501a9f9631799d8b4fc625c0a0b5fb4c55"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:115d30ca0760e3861d9ad6b3288cd11ee72a785b81227da0c1765d3b84e2c009"}, - {file = "cmake-3.27.2-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:efc338c939d6d435890a52458a260bf0942bd8392b648d7532a72c1ec0764e18"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:7f7438c60ccc01765b67abfb1797787c3b9459d500a804ed70a4cc181bc02204"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:294f008734267e0eee1574ad1b911bed137bc907ab19d60a618dab4615aa1fca"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:197a34dc62ee149ced343545fac67e5a30b93fda65250b065726f86ce92bdada"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:afb46ad883b174fb64347802ba5878423551dbd5847bb64669c39a5957c06eb7"}, - {file = "cmake-3.27.2-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:83611ffd155e270a6b13bbf0cfd4e8688ebda634f448aa2e3734006c745bf33f"}, - {file = "cmake-3.27.2-py2.py3-none-win32.whl", hash = "sha256:53e12deb893da935e236f93accd47dbe2806620cd7654986234dc4487cc49652"}, - {file = "cmake-3.27.2-py2.py3-none-win_amd64.whl", hash = "sha256:611f9722c68c40352d38a6c01960ab038c3d0419e7aee3bf18f95b23031e0dfe"}, - {file = "cmake-3.27.2-py2.py3-none-win_arm64.whl", hash = "sha256:30620326b51ac2ce0d8f476747af6367a7ea21075c4d065fad9443904b07476a"}, - {file = "cmake-3.27.2.tar.gz", hash = "sha256:7cd6e2d7d5a1125f8c26c4f65214f8c942e3f276f98c16cb62ae382c35609f25"}, + {file = "cmake-3.27.4.1-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:65b79f1e8b6fa254697ee0b411aa4dff0d2309c1405af3448adf06cbd7ef0ac5"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:4a1d22ee72dcdc32d0f8bbf5691d2e9367585db8bfeafe7cffa2c4274127a801"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:37e9cad75184fbefe837311528d026901278b606707e9d14b58e3767d49d0aa6"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2f217fb281b068696fdcc4b198de62e9ded8bc0a93877684afc59db3507ccb44"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:a3bd8b9d0e294bd2b3ce27a850c9d924aee7e4f4c0bb56d66641cc1544314f58"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:871c8b5eaac959f079c2389c7a7f198fa5f86a029e8726fcb1f3e13d030c33e9"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ec4a5bc2376dfc57065bfde6806183b331165f33457b7cc0fc0511260dde7c72"}, + {file = "cmake-3.27.4.1-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c36eb106dec60198264b25d4bd23cd9ea30b0af9200a143ec1db887c095306f7"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:2b0b53ec2e45cfe9982d0adf833b3519efc328c1e3cffae4d237841a1ed6edf4"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:a504815bcba0ece9aafb48a6b7770d6479756fda92f8b62f9ab7ff8a403a12d5"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:1aca07fccfa042a0379bb027e30f090a8239b18fd3f959391c5d77c22dd0a809"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:315eb37233e2d0b8fa01580e33439eaeaef65f1e41ad9ca269cbe68cc0a039a4"}, + {file = "cmake-3.27.4.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:94fec5a8bae1f3b62d8a8653ebcb7fa4007e2d0e713f94e4b2089f708c13548f"}, + {file = "cmake-3.27.4.1-py2.py3-none-win32.whl", hash = "sha256:8249a1ba7901b53661b44e59bdf6fd1e977e10843795788efe25d3374de6ed95"}, + {file = "cmake-3.27.4.1-py2.py3-none-win_amd64.whl", hash = "sha256:b72db11e13eafb46b9c53797d141e89886293db768feabef4b841accf666de54"}, + {file = "cmake-3.27.4.1-py2.py3-none-win_arm64.whl", hash = "sha256:0fb68660ce3954de99d1f41bedcf87063325c4cc891003f12de36472fa1efa28"}, + {file = "cmake-3.27.4.1.tar.gz", hash = "sha256:70526bbff5eeb7d4d6b921af1b80d2d29828302882f94a2cba93ad7d469b90f6"}, ] [package.extras] @@ -1110,19 +1111,19 @@ files = [ [[package]] name = "setuptools" -version = "68.1.2" +version = "68.2.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.1.2-py3-none-any.whl", hash = "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b"}, - {file = "setuptools-68.1.2.tar.gz", hash = "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d"}, + {file = "setuptools-68.2.1-py3-none-any.whl", hash = "sha256:eff96148eb336377ab11beee0c73ed84f1709a40c0b870298b0d058828761bae"}, + {file = "setuptools-68.2.1.tar.gz", hash = "sha256:56ee14884fd8d0cd015411f4a13f40b4356775a0aefd9ebc1d3bfb9a1acb32f1"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5,<=7.1.2)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -1369,17 +1370,6 @@ torch = "*" tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] -[[package]] -name = "types-tqdm" -version = "4.66.0.2" -description = "Typing stubs for tqdm" -optional = false -python-versions = "*" -files = [ - {file = "types-tqdm-4.66.0.2.tar.gz", hash = "sha256:9553a5e44c1d485fce19f505b8bd65c0c3e87e870678d1f2ed764ae59a55d45f"}, - {file = "types_tqdm-4.66.0.2-py3-none-any.whl", hash = "sha256:13dddd38908834abdf0acdc2b70cab7ac4bcc5ad7356ced450471662e58a0ffc"}, -] - [[package]] name = "typing-extensions" version = "4.7.1" @@ -1492,4 +1482,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "097dc1f31ef2717f69dd11bdf98584ff46ef5f563e97c76ffe169591105d4287" +content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5" -- cgit v1.2.3 From 58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 21:52:42 +0200 Subject: changed config to yaml! --- config.philipp.yaml | 29 ++++++ config.train.yaml | 28 ++++++ poetry.lock | 51 ++++++++++- pyproject.toml | 1 + requirements.txt | 1 + swr2_asr/__main__.py | 12 --- swr2_asr/inference.py | 16 ++-- swr2_asr/model_deep_speech.py | 17 ---- swr2_asr/train.py | 192 +++++++++++++++++++--------------------- swr2_asr/utils/data.py | 7 +- swr2_asr/utils/tokenizer.py | 8 +- swr2_asr/utils/visualization.py | 8 +- 12 files changed, 218 insertions(+), 152 deletions(-) create mode 100644 config.philipp.yaml create mode 100644 config.train.yaml delete mode 100644 swr2_asr/__main__.py (limited to 'poetry.lock') diff --git a/config.philipp.yaml b/config.philipp.yaml new file mode 100644 index 0000000..638b5ef --- /dev/null +++ b/config.philipp.yaml @@ -0,0 +1,29 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 1 # evaluate every n epochs + num_workers: 4 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: True # set to True if you want to use limited supervision + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: ~ # path to load model from + model_save_path: ~ # path to save model to \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml new file mode 100644 index 0000000..c82439d --- /dev/null +++ b/config.train.yaml @@ -0,0 +1,28 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 3901b8c..a1f916b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1083,6 +1083,55 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "ruff" version = "0.0.285" @@ -1482,4 +1531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5" +content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf" diff --git a/pyproject.toml b/pyproject.toml index 38cc51a..f6d19dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" matplotlib = "^3.7.2" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/requirements.txt b/requirements.txt index 3b39b56..040fed0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ platformdirs==3.10.0 pylint==2.17.5 pyparsing==3.0.9 python-dateutil==2.8.2 +PyYAML==6.0.1 ruff==0.0.285 six==1.16.0 sympy==1.12 diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py deleted file mode 100644 index be294fb..0000000 --- a/swr2_asr/__main__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Main entrypoint for swr2-asr.""" -import torch -import torchaudio - -if __name__ == "__main__": - # test if GPU is available - print("GPU available: ", torch.cuda.is_available()) - - # test if torchaudio is installed correctly - print("torchaudio version: ", torchaudio.__version__) - print("torchaudio backend: ", torchaudio.get_audio_backend()) - print("torchaudio info: ", torchaudio.get_audio_backend()) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..f8342f7 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" +from typing import TypedDict + import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer class HParams(TypedDict): @@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member blank_label = tokenizer.encode(" ").ids[0] decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for _i, args in enumerate(arg_maxes): decode = [] for j, index in enumerate(args): if index != blank_label: @@ -44,7 +44,7 @@ def main() -> None: """inference function.""" device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) + device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") @@ -90,7 +90,7 @@ def main() -> None: model.load_state_dict(state_dict) # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") + waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member if sample_rate != spectrogram_hparams["sample_rate"]: resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) @@ -103,7 +103,7 @@ def main() -> None: specs = [spec] specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) - output = model(specs) + output = model(specs) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) decodes = greedy_decoder(output, tokenizer) diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index 8ddbd99..77f4c8a 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -3,27 +3,10 @@ Following definition by Assembly AI (https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) """ -from typing import TypedDict - import torch.nn.functional as F from torch import nn -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ac7666b..eb79ee2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,11 +5,12 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split from swr2_asr.utils.decoder import greedy_decoder from swr2_asr.utils.tokenizer import CharTokenizer @@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer from .utils.loss_scores import cer, wer -class IterMeter(object): +class IterMeter: """keeps track of total iterations""" def __init__(self): @@ -116,6 +117,7 @@ class TestArgs(TypedDict): def test(test_args: TestArgs) -> tuple[float, float, float]: + """Test""" print("\nevaluating...") # get values from test_args: @@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + for _data in tqdm(test_loader, desc="Validation Batches"): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) @@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths, tokenizer ) - if i == 1: - print(f"decoding first sample: {decoded_preds}") for j, _ in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], decoded_preds[j])) test_wer.append(wer(decoded_targets[j], decoded_preds[j])) @@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: return test_loss, avg_cer, avg_wer -def main( - learning_rate: float, - batch_size: int, - epochs: int, - dataset_path: str, - language: str, - limited_supervision: bool, - model_load_path: str, - model_save_path: str, - dataset_percentage: float, - eval_every: int, - num_workers: int, -): +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +def main(config_path: str): """Main function for training the model. - Args: - learning_rate: learning rate for the optimizer - batch_size: batch size - epochs: number of epochs to train - dataset_path: path for the dataset - language: language of the dataset - limited_supervision: whether to use only limited supervision - model_load_path: path to load a model from - model_save_path: path to save the model to - dataset_percentage: percentage of the dataset to use - eval_every: evaluate every n epochs - num_workers: number of workers for the dataloader + Gets all configuration arguments from yaml config file. """ use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) prev_epoch = 0 - if model_load_path is not None: - checkpoint = torch.load(model_load_path) + if checkpoints_config["model_load_path"] is not None: + checkpoint = torch.load(checkpoints_config["model_load_path"]) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) - for epoch in range(prev_epoch + 1, epochs + 1): - train_args: TrainArgs = dict( - model=model, - device=device, - train_loader=train_loader, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - iter_meter=iter_meter, - ) + + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = dict( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - decoder="greedy", - ) + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } - if epoch % eval_every == 0: + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: test_loss, test_cer, test_wer = test(test_args) - if model_save_path is None: + if checkpoints_config["model_save_path"] is None: continue - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, @@ -322,7 +314,7 @@ def main( "avg_cer": test_cer, "avg_wer": test_wer, }, - model_save_path + str(epoch), + checkpoints_config["model_save_path"] + str(epoch), ) diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index e939e1d..0e06eec 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -1,13 +1,12 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from typing import TypedDict import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -125,7 +124,7 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): + if limited and split in (Split.TRAIN, Split.VALID): self.initialize_limited() else: self.initialize() @@ -351,8 +350,6 @@ class MLSDataset(Dataset): if __name__ == "__main__": - from torch.utils.data import DataLoader - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" split = Split.DEV diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 5482bbe..22569eb 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,8 +1,6 @@ """Tokenizer for Multilingual Librispeech datasets""" - - -from datetime import datetime import os +from datetime import datetime from tqdm.autonotebook import tqdm @@ -119,8 +117,8 @@ class CharTokenizer: line = line.strip() if line: char, index = line.split() - tokenizer.char_map[char] = int(index) - tokenizer.index_map[int(index)] = char + load_tokenizer.char_map[char] = int(index) + load_tokenizer.index_map[int(index)] = char return load_tokenizer diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index 80f942a..a55d0d5 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -6,10 +6,10 @@ import torch def plot(epochs, path): """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() + losses = [] + test_losses = [] + cers = [] + wers = [] for epoch in range(1, epochs + 1): current_state = torch.load(path + str(epoch)) losses.append(current_state["loss"]) -- cgit v1.2.3