diff options
-rw-r--r-- | poetry.lock | 187 | ||||
-rw-r--r-- | pyproject.toml | 1 | ||||
-rw-r--r-- | swr2_asr/utils.py | 108 |
3 files changed, 267 insertions, 29 deletions
diff --git a/poetry.lock b/poetry.lock index 1f3609a..77643f5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,6 +20,50 @@ wrapt = [ ] [[package]] +name = "attrs" +version = "19.3.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"}, + {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"}, +] + +[package.extras] +azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-azurepipelines", "six", "zope.interface"] +dev = ["coverage", "hypothesis", "pre-commit", "pympler", "pytest (>=4.3.0)", "six", "sphinx", "zope.interface"] +docs = ["sphinx", "zope.interface"] +tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] + +[[package]] +name = "audio-metadata" +version = "0.11.1" +description = "A library for reading and, in the future, writing metadata from audio files." +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "audio-metadata-0.11.1.tar.gz", hash = "sha256:9e7ba79d49cf048a911d5f7d55bb2715c10be5c127fe5db0987c5fe1aa7335eb"}, + {file = "audio_metadata-0.11.1-py3-none-any.whl", hash = "sha256:f5b85ad087324c255f8d1223574c3e7d3c27b649e411d1dd54aa3bf342fe93fb"}, +] + +[package.dependencies] +attrs = ">=18.2,<19.4" +bidict = "<1.0.0" +bitstruct = ">=6.0,<9.0" +more-itertools = ">=4.0,<9.0" +pendulum = ">=2.0,<2.0.5 || >2.0.5,<2.1.0 || >2.1.0,<=3.0" +pprintpp = "<1.0.0" +tbm-utils = ">=2.3,<3.0" +wrapt = ">=1.0,<2.0" + +[package.extras] +dev = ["coverage[toml] (>=5.0,<6.0)", "flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)", "nox (>=2019,<2020)", "sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)", "ward (>=0.42.0-beta.0)"] +doc = ["sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +lint = ["flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)"] +test = ["coverage[toml] (>=5.0,<6.0)", "nox (>=2019,<2020)", "ward (>=0.42.0-beta.0)"] + +[[package]] name = "AudioLoader" version = "0.1.4" description = "A collection of PyTorch audio datasets for speech and music applications" @@ -35,6 +79,32 @@ reference = "HEAD" resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2" [[package]] +name = "bidict" +version = "0.22.1" +description = "The bidirectional mapping library for Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "bidict-0.22.1-py3-none-any.whl", hash = "sha256:6ef212238eb884b664f28da76f33f1d28b260f665fc737b413b287d5487d1e7b"}, + {file = "bidict-0.22.1.tar.gz", hash = "sha256:1e0f7f74e4860e6d0943a05d4134c63a2fad86f3d4732fb265bd79e4e856d81d"}, +] + +[package.extras] +docs = ["furo", "sphinx", "sphinx-copybutton"] +lint = ["pre-commit"] +test = ["hypothesis", "pytest", "pytest-benchmark[histogram]", "pytest-cov", "pytest-xdist", "sortedcollections", "sortedcontainers", "sphinx"] + +[[package]] +name = "bitstruct" +version = "8.17.0" +description = "This module performs conversions between Python values and C bit field structs represented as Python byte strings." +optional = false +python-versions = "*" +files = [ + {file = "bitstruct-8.17.0.tar.gz", hash = "sha256:eb94b40e4218a23aa8f90406b836a9e6ed83e48b8d112ce3f96408463bd1b874"}, +] + +[[package]] name = "black" version = "23.7.0" description = "The uncompromising code formatter." @@ -349,6 +419,17 @@ release = ["twine (>=4.0.2,<4.1.0)"] test-code = ["pytest (>=7.4.0,<7.5.0)"] [[package]] +name = "more-itertools" +version = "8.14.0" +description = "More routines for operating on iterables, beyond itertools" +optional = false +python-versions = ">=3.5" +files = [ + {file = "more-itertools-8.14.0.tar.gz", hash = "sha256:c09443cd3d5438b8dafccd867a6bc1cb0894389e90cb53d227456b0b0bccb750"}, + {file = "more_itertools-8.14.0-py3-none-any.whl", hash = "sha256:1bc4f91ee5b1b31ac7ceacc17c09befe6a40a503907baf9c839c229b5095cfd2"}, +] + +[[package]] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" @@ -655,6 +736,40 @@ files = [ ] [[package]] +name = "pendulum" +version = "2.1.2" +description = "Python datetimes made easy" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "pendulum-2.1.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:b6c352f4bd32dff1ea7066bd31ad0f71f8d8100b9ff709fb343f3b86cee43efe"}, + {file = "pendulum-2.1.2-cp27-cp27m-win_amd64.whl", hash = "sha256:318f72f62e8e23cd6660dbafe1e346950281a9aed144b5c596b2ddabc1d19739"}, + {file = "pendulum-2.1.2-cp35-cp35m-macosx_10_15_x86_64.whl", hash = "sha256:0731f0c661a3cb779d398803655494893c9f581f6488048b3fb629c2342b5394"}, + {file = "pendulum-2.1.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:3481fad1dc3f6f6738bd575a951d3c15d4b4ce7c82dce37cf8ac1483fde6e8b0"}, + {file = "pendulum-2.1.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:9702069c694306297ed362ce7e3c1ef8404ac8ede39f9b28b7c1a7ad8c3959e3"}, + {file = "pendulum-2.1.2-cp35-cp35m-win_amd64.whl", hash = "sha256:fb53ffa0085002ddd43b6ca61a7b34f2d4d7c3ed66f931fe599e1a531b42af9b"}, + {file = "pendulum-2.1.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:c501749fdd3d6f9e726086bf0cd4437281ed47e7bca132ddb522f86a1645d360"}, + {file = "pendulum-2.1.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:c807a578a532eeb226150d5006f156632df2cc8c5693d778324b43ff8c515dd0"}, + {file = "pendulum-2.1.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2d1619a721df661e506eff8db8614016f0720ac171fe80dda1333ee44e684087"}, + {file = "pendulum-2.1.2-cp36-cp36m-win_amd64.whl", hash = "sha256:f888f2d2909a414680a29ae74d0592758f2b9fcdee3549887779cd4055e975db"}, + {file = "pendulum-2.1.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:e95d329384717c7bf627bf27e204bc3b15c8238fa8d9d9781d93712776c14002"}, + {file = "pendulum-2.1.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:4c9c689747f39d0d02a9f94fcee737b34a5773803a64a5fdb046ee9cac7442c5"}, + {file = "pendulum-2.1.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1245cd0075a3c6d889f581f6325dd8404aca5884dea7223a5566c38aab94642b"}, + {file = "pendulum-2.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:db0a40d8bcd27b4fb46676e8eb3c732c67a5a5e6bfab8927028224fbced0b40b"}, + {file = "pendulum-2.1.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:f5e236e7730cab1644e1b87aca3d2ff3e375a608542e90fe25685dae46310116"}, + {file = "pendulum-2.1.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:de42ea3e2943171a9e95141f2eecf972480636e8e484ccffaf1e833929e9e052"}, + {file = "pendulum-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7c5ec650cb4bec4c63a89a0242cc8c3cebcec92fcfe937c417ba18277d8560be"}, + {file = "pendulum-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33fb61601083f3eb1d15edeb45274f73c63b3c44a8524703dc143f4212bf3269"}, + {file = "pendulum-2.1.2-cp39-cp39-manylinux1_i686.whl", hash = "sha256:29c40a6f2942376185728c9a0347d7c0f07905638c83007e1d262781f1e6953a"}, + {file = "pendulum-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:94b1fc947bfe38579b28e1cccb36f7e28a15e841f30384b5ad6c5e31055c85d7"}, + {file = "pendulum-2.1.2.tar.gz", hash = "sha256:b06a0ca1bfe41c990bbf0c029f0b6501a7f2ec4e38bfec730712015e8860f207"}, +] + +[package.dependencies] +python-dateutil = ">=2.6,<3.0" +pytzdata = ">=2020.1" + +[[package]] name = "platformdirs" version = "3.10.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." @@ -670,6 +785,17 @@ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx- test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] [[package]] +name = "pprintpp" +version = "0.4.0" +description = "A drop-in replacement for pprint that's actually pretty" +optional = false +python-versions = "*" +files = [ + {file = "pprintpp-0.4.0-py2.py3-none-any.whl", hash = "sha256:b6b4dcdd0c0c0d75e4d7b2f21a9e933e5b2ce62b26e1a54537f9651ae5a5c01d"}, + {file = "pprintpp-0.4.0.tar.gz", hash = "sha256:ea826108e2c7f49dc6d66c752973c3fc9749142a798d6b254e1e301cfdbc6403"}, +] + +[[package]] name = "pylint" version = "2.17.5" description = "python code static checker" @@ -698,6 +824,31 @@ spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] [[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytzdata" +version = "2020.1" +description = "The Olson timezone database for Python." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pytzdata-2020.1-py2.py3-none-any.whl", hash = "sha256:e1e14750bcf95016381e4d472bad004eef710f2d6417240904070b3d6654485f"}, + {file = "pytzdata-2020.1.tar.gz", hash = "sha256:3efa13b335a00a8de1d345ae41ec78dd11c9f8807f522d39850f2dd828681540"}, +] + +[[package]] name = "ruff" version = "0.0.285" description = "An extremely fast Python linter, written in Rust." @@ -740,6 +891,17 @@ testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[l testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" @@ -754,6 +916,29 @@ files = [ mpmath = ">=0.19" [[package]] +name = "tbm-utils" +version = "2.6.0" +description = "A commonly-used set of utilities used by me (thebigmunch)." +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "tbm-utils-2.6.0.tar.gz", hash = "sha256:235748cceeb22c042e32d2fdfd4d710021bac9b938c4f2c35e1fce1cfd58f7ec"}, + {file = "tbm_utils-2.6.0-py3-none-any.whl", hash = "sha256:692b5cde2b810bb84e55ca0f5f6c055ca6ad321c4d5acc0cade98af97c5998e2"}, +] + +[package.dependencies] +attrs = ">=18.2,<19.4" +pendulum = ">=2.0,<2.0.5 || >2.0.5,<2.1.0 || >2.1.0,<=3.0" +pprintpp = "<1.0.0" +wrapt = ">=1.0,<2.0" + +[package.extras] +dev = ["coverage[toml] (>=4.5,<6.0)", "flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)", "nox (>=2019,<2020)", "pytest (>=4.0,<6.0)", "sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +doc = ["sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +lint = ["flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)"] +test = ["coverage[toml] (>=4.5,<6.0)", "nox (>=2019,<2020)", "pytest (>=4.0,<6.0)"] + +[[package]] name = "tokenizers" version = "0.13.3" description = "Fast and Customizable Tokenizers" @@ -1096,4 +1281,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2" +content-hash = "992e0eb975f6fb490726bc9db470a682986c1a287078bec407a4a6dc96e45b3c" diff --git a/pyproject.toml b/pyproject.toml index 94f7553..dc136f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ numpy = "^1.25.2" mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" +audio-metadata = "^0.11.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 786fbcf..404661d 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,6 +1,7 @@ """Class containing utils for the ASR system.""" import os from enum import Enum +from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -8,6 +9,8 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from tqdm import tqdm +import audio_metadata from swr2_asr.tokenizer import CharTokenizer, TokenizerType @@ -133,11 +136,13 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] + self.max_spec_length = 0 + self.max_utterance_length = 0 + def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - - self.calc_paddings() + # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -158,27 +163,73 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def calc_paddings(self): - """Sets the maximum length of the spectrogram""" + def _calculate_max_length(self, chunk): + """Calculates the maximum length of the spectrogram and the utterance + + to be called in a multiprocessing pool + """ + max_spec_length = 0 + max_utterance_length = 0 + + for sample in chunk: + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + sample["speakerid"], + sample["bookid"], + "_".join( + [ + sample["speakerid"], + sample["bookid"], + sample["chapterid"], + ] + ) + + self.file_ext, + ) + metadata = audio_metadata.load(audio_path) + audio_duration = metadata.streaminfo.duration + sample_rate = metadata.streaminfo.sample_rate + + max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) + max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) + + return max_spec_length, max_utterance_length + + def calc_paddings(self) -> None: + """Sets the maximum length of the spectrogram and the utterance""" # check if dataset has been loaded and tokenizer has been set if not self.dataset_lookup: raise ValueError("Dataset not loaded") if not self.tokenizer: raise ValueError("Tokenizer not set") - - max_spec_length = 0 - max_uterance_length = 0 - for sample in self.dataset_lookup: - spec_length = sample["spectrogram"].shape[0] - if spec_length > max_spec_length: - max_spec_length = spec_length - - utterance_length = sample["utterance"].shape[0] - if utterance_length > max_uterance_length: - max_uterance_length = utterance_length - - self.max_spec_length = max_spec_length - self.max_utterance_length = max_uterance_length + # check if paddings have been calculated already + if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): + print("Paddings already calculated") + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: + self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] + return + else: + print("Calculating paddings...") + + thread_count = os.cpu_count() + if thread_count is None: + thread_count = 4 + chunk_size = len(self.dataset_lookup) // thread_count + chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] + + with Pool(thread_count) as p: + results = list(p.imap(self._calculate_max_length, chunks)) + + for spec, utterance in results: + self.max_spec_length = max(self.max_spec_length, spec) + self.max_utterance_length = max(self.max_utterance_length, utterance) + + # write to file + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: + f.write(f"{self.max_spec_length}\n") + f.write(f"{self.max_utterance_length}") def __len__(self): """Returns the length of the dataset""" @@ -208,18 +259,18 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore - # TODO: figure out if we have to resample or not - # TODO: pad correctly (manually) + + # resample if necessary + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + sample_rate = 16000 + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) - print(f"spec.shape: {spec.shape}") - input_length = spec.shape[0] // 2 - spec = ( - torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) - .unsqueeze(1) - .transpose(2, 3) - ) + input_length = spec.shape[0] // 2 utterance_length = len(utterance) + self.tokenizer.enable_padding() utterance = self.tokenizer.encode( utterance, @@ -260,4 +311,5 @@ if __name__ == "__main__": dataset.set_tokenizer(tok) dataset.calc_paddings() - print(dataset[41]["spectrogram"].shape) + print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") + print(f"Utterance shape: {dataset[41]['utterance'].shape}") |