From 4c31ecc1bb748242d4740ab5b42514598006d10b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 21 Aug 2023 18:29:09 +0200 Subject: added simple custom dataloader --- swr2_asr/utils.py | 284 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 swr2_asr/utils.py (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py new file mode 100644 index 0000000..c4aeb0b --- /dev/null +++ b/swr2_asr/utils.py @@ -0,0 +1,284 @@ +"""Class containing utils for the ASR system.""" +from dataclasses import dataclass +import os +from AudioLoader.speech import MultilingualLibriSpeech +import numpy as np +import torch +import torchaudio +from torch import nn +from torch.utils.data import Dataset, DataLoader +from enum import Enum + +from tokenizers import Tokenizer +from swr2_asr.tokenizer import CharTokenizer + + +# create enum specifiying dataset splits +class MLSSplit(str, Enum): + """Enum specifying dataset as they are defined in the + Multilingual LibriSpeech dataset""" + + train = "train" + test = "test" + dev = "dev" + + +class Split(str, Enum): + """Extending the MLSSplit class to allow for a custom validatio split""" + + train = "train" + valid = "valid" + test = "test" + dev = "dev" + + +def split_to_mls_split(split: Split) -> MLSSplit: + """Converts the custom split to a MLSSplit""" + if split == Split.valid: + return MLSSplit.train + else: + return split # type: ignore + + +@dataclass +class Sample: + """Dataclass for a sample in the dataset""" + + waveform: torch.Tensor + spectrogram: torch.Tensor + utterance: str + sample_rate: int + speaker_id: str + book_id: str + chapter_id: str + + +def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"): + """Factory for Tokenizer class + + Args: + tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE". 
+ + Returns: + nn.Module: Tokenizer class + """ + if tokenizer_type == "BPE": + return Tokenizer.from_file(tokenizer_path) + elif tokenizer_type == "char": + return CharTokenizer.from_file(tokenizer_path) + + +class MLSDataset(Dataset): + """Custom Dataset for reading Multilingual LibriSpeech + + Attributes: + dataset_path (str): + path to the dataset + language (str): + language of the dataset + split (Split): + split of the dataset + mls_split (MLSSplit): + split of the dataset as defined in the Multilingual LibriSpeech dataset + dataset_lookup (list): + list of dicts containing the speakerid, bookid, chapterid and utterance + + directory structure: + + ├── + │ ├── train + │ │ ├── transcripts.txt + │ │ └── audio + │ │ └── + │ │ └── + │ │ └── __.opus / .flac + + each line in transcripts.txt has the following format: + __ + """ + + def __init__(self, dataset_path: str, language: str, split: Split, download: bool): + """Initializes the dataset""" + self.dataset_path = dataset_path + self.language = language + self.file_ext = ".opus" if "opus" in language else ".flac" + self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk + self.split: Split = split # split used internally + self.dataset_lookup = [] + + self._handle_download_dataset(download) + self._validate_local_directory() + + transcripts_path = os.path.join( + dataset_path, language, self.mls_split, "transcripts.txt" + ) + + with open(transcripts_path, "r", encoding="utf-8") as script_file: + # read all lines in transcripts.txt + transcripts = script_file.readlines() + # split each line into (__, ) + transcripts = [line.strip().split("\t", 1) for line in transcripts] + utterances = [utterance.strip() for _, utterance in transcripts] + identifier = [identifier.strip() for identifier, _ in transcripts] + identifier = [path.split("_") for path in identifier] + + self.dataset_lookup = [ + { + "speakerid": path[0], + "bookid": path[1], + "chapterid": path[2], + "utterance": utterance, + } + for path, utterance in zip(identifier, utterances) + ] + + # save dataset_lookup as list of dicts, where each dict contains + # the speakerid, bookid and chapterid, as well as the utterance + # we can then use this to map the utterance to the audio file + + def _handle_download_dataset(self, download: bool): + """Download the dataset""" + if ( + not os.path.exists(os.path.join(self.dataset_path, self.language)) + and download + ): + os.makedirs(self.dataset_path) + url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" + + torch.hub.download_url_to_file(url, self.dataset_path) + elif ( + not os.path.exists(os.path.join(self.dataset_path, self.language)) + and not download + ): + raise ValueError("Dataset not found. 
Set download to True to download it") + + def _validate_local_directory(self): + # check if dataset_path exists + if not os.path.exists(self.dataset_path): + raise ValueError("Dataset path does not exist") + if not os.path.exists(os.path.join(self.dataset_path, self.language)): + raise ValueError("Language not found in dataset") + if not os.path.exists( + os.path.join(self.dataset_path, self.language, self.mls_split) + ): + raise ValueError("Split not found in dataset") + + # checks if the transcripts.txt file exists + if not os.path.exists( + os.path.join(dataset_path, language, split, "transcripts.txt") + ): + raise ValueError("transcripts.txt not found in dataset") + + def __get_len__(self): + """Returns the length of the dataset""" + return len(self.dataset_lookup) + + def __getitem__(self, idx: int) -> Sample: + """One sample""" + # get the utterance + utterance = self.dataset_lookup[idx]["utterance"] + + # get the audio file + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + "_".join( + [ + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + self.dataset_lookup[idx]["chapterid"], + ] + ) + + self.file_ext, + ) + + waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + + return Sample( + waveform=waveform, + spectrogram=torchaudio.transforms.MelSpectrogram( + sample_rate=16000, n_mels=128 + )(waveform), + utterance=utterance, + sample_rate=sample_rate, + speaker_id=self.dataset_lookup[idx]["speakerid"], + book_id=self.dataset_lookup[idx]["bookid"], + chapter_id=self.dataset_lookup[idx]["chapterid"], + ) + + def download(self, dataset_path: str, language: str): + """Download the dataset""" + os.makedirs(dataset_path) + url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz" + + torch.hub.download_url_to_file(url, dataset_path) + + +class DataProcessor: + """Factory for DataProcessingclass + + Transforms the dataset into spectrograms and labels, as well as a tokenizer + """ + + def __init__( + self, + dataset: MultilingualLibriSpeech, + tokenizer_path: str, + data_type: str = "train", + tokenizer_type: str = "BPE", + ): + self.dataset = dataset + self.data_type = data_type + + self.train_audio_transforms = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), + ) + + self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram() + self.tokenizer = tokenizer_factory( + tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type + ) + + def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]: + """Returns spectrograms, labels and their lenghts""" + for sample in self.dataset: + if self.data_type == "train": + spec = ( + self.train_audio_transforms(sample["waveform"]) + .squeeze(0) + .transpose(0, 1) + ) + elif self.data_type == "valid": + spec = ( + self.valid_audio_transforms(sample["waveform"]) + .squeeze(0) + .transpose(0, 1) + ) + else: + raise ValueError("data_type should be train or valid") + label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + .unsqueeze(1) + .transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + yield spec, label, spec.shape[0] // 2, len(labels) + + +if __name__ == "__main__": + dataset_path = 
"/Volumes/pherkel/SWR2-ASR" + language = "mls_german_opus" + split = Split.train + download = False + + dataset = MLSDataset(dataset_path, language, split, download) + print(dataset[0]) -- cgit v1.2.3 From 403472ca4e65e8ed404e8a73fb9b3fbafe3f2a53 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Thu, 24 Aug 2023 00:03:56 +0200 Subject: wip: commit before going on vacation :) --- pyproject.toml | 9 ++ swr2_asr/model_deep_speech.py | 150 +++++++++++++++++++++ swr2_asr/tokenizer.py | 72 ++++++---- swr2_asr/train.py | 298 +++++++++--------------------------------- swr2_asr/utils.py | 207 +++++++++++++---------------- 5 files changed, 365 insertions(+), 371 deletions(-) create mode 100644 swr2_asr/model_deep_speech.py (limited to 'swr2_asr/utils.py') diff --git a/pyproject.toml b/pyproject.toml index fabe364..94f7553 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,15 @@ pylint = "^2.17.5" ruff = "^0.0.285" types-tqdm = "^4.66.0.1" +[tool.ruff] +select = ["E", "F", "B", "I"] +fixable = ["ALL"] +line-length = 120 +target-version = "py310" + +[tool.black] +line-length = 120 + [tool.poetry.scripts] train = "swr2_asr.train:run_cli" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py new file mode 100644 index 0000000..ea0b667 --- /dev/null +++ b/swr2_asr/model_deep_speech.py @@ -0,0 +1,150 @@ +from torch import nn +import torch.nn.functional as F + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super().__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) + data = self.layer_norm(data) + return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super().__init__() + + self.cnn1 = nn.Conv2d( + in_channels, out_channels, kernel, stride, padding=kernel // 2 + ) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + residual = data # (batch, channel, feature, time) + data = self.layer_norm1(data) + data = F.gelu(data) + data = self.dropout1(data) + data = self.cnn1(data) + data = self.layer_norm2(data) + data = F.gelu(data) + data = self.dropout2(data) + data = self.cnn2(data) + data += residual + return data # (batch, channel, feature, time) + + +class BidirectionalGRU(nn.Module): + """BIdirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super().__init__() + + self.bi_gru = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + """data (batch, time, feature)""" + data = self.layer_norm(data) + data = F.gelu(data) + data = self.dropout(data) + data, _ = self.bi_gru(data) 
+ return data + + +class SpeechRecognitionModel(nn.Module): + """Speech Recognition Model Inspired by DeepSpeech 2""" + + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super().__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN( + 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats + ) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, data): + """data (batch, channel, feature, time)""" + data = self.cnn(data) + data = self.rescnn_layers(data) + sizes = data.size() + data = data.view( + sizes[0], sizes[1] * sizes[2], sizes[3] + ) # (batch, feature, time) + data = data.transpose(1, 2) # (batch, time, feature) + data = self.fully_connected(data) + data = self.birnn_layers(data) + data = self.classifier(data) + return data diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 4dbb386..5758da7 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,16 +1,50 @@ """Tokenizer for use with Multilingual Librispeech""" -from dataclasses import dataclass import json import os -import click -from tqdm import tqdm +from dataclasses import dataclass +from typing import Type +import click from AudioLoader.speech import MultilingualLibriSpeech - from tokenizers import Tokenizer, normalizers from tokenizers.models import BPE -from tokenizers.trainers import BpeTrainer from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import BpeTrainer +from tqdm import tqdm + + +class TokenizerType: + def encode(self, sequence: str) -> list[int]: + raise NotImplementedError + + def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + raise NotImplementedError + + def decode_batch(self, labels: list[list[int]]) -> list[str]: + raise NotImplementedError + + def get_vocab_size(self) -> int: + raise NotImplementedError + + def enable_padding( + self, + length: int = -1, + direction: str = "right", + pad_id: int = 0, + pad_type_id: int = 0, + pad_token: str = "[PAD]", + ) -> None: + raise NotImplementedError + + def save(self, path: str) -> None: + raise NotImplementedError + + @staticmethod + def from_file(path: str) -> "TokenizerType": + raise NotImplementedError + + +tokenizer_type = Type[TokenizerType] @dataclass @@ -20,7 +54,7 @@ class Encoding: ids: list[int] -class CharTokenizer: +class CharTokenizer(TokenizerType): """Very simple tokenizer for use with Multilingual Librispeech Simply checks what characters are in the dataset and uses them as tokens. 
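
A minimal usage sketch of the character tokenizer in this diff, assuming the CharTokenizer API shown here (train on the MLS transcripts, encode returning an Encoding with .ids, decode and save per the TokenizerType interface); the paths and the sample sentence are made up for illustration:

    from swr2_asr.tokenizer import CharTokenizer

    char_tokenizer = CharTokenizer()
    # builds the character vocabulary from the chosen split's transcripts.txt
    char_tokenizer.train("data", "mls_german_opus", split="train", download=False)

    encoding = char_tokenizer.encode("hallo welt")  # lowercases input; spaces map to the space token
    ids = encoding.ids                              # list[int], as consumed by the training loop
    text = char_tokenizer.decode(ids, remove_special_tokens=True)  # signature per TokenizerType
    char_tokenizer.save("data/tokenizers/char_tokenizer_german.json")  # path referenced in train.py
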
@@ -45,9 +79,7 @@ class CharTokenizer: self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train( - self, dataset_path: str, language: str, split: str, download: bool = True - ): + def train(self, dataset_path: str, language: str, split: str, download: bool = True): """Train the tokenizer on the given dataset Args: @@ -65,9 +97,7 @@ class CharTokenizer: chars: set = set() for s_plit in splits: - transcript_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") # check if dataset is downloaded, download if not if download and not os.path.exists(transcript_path): @@ -90,7 +120,7 @@ class CharTokenizer: self.char_map[char] = i self.index_map[i] = char - def encode(self, text: str): + def encode(self, sequence: str): """Use a character map and convert text to an integer sequence automatically maps spaces to and makes everything lowercase @@ -98,8 +128,8 @@ class CharTokenizer: """ int_sequence = [] - text = text.lower() - for char in text: + sequence = sequence.lower() + for char in sequence: if char == " ": mapped_char = self.char_map[""] elif char not in self.char_map: @@ -174,9 +204,7 @@ class CharTokenizer: @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") @click.option("--download", default=True, help="Whether to download the dataset") -@click.option( - "--out_path", default="tokenizer.json", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") def train_bpe_tokenizer( dataset_path: str, @@ -210,9 +238,7 @@ def train_bpe_tokenizer( lines = [] for s_plit in splits: - transcripts_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") if download and not os.path.exists(transcripts_path): MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) @@ -296,9 +322,7 @@ def train_bpe_tokenizer( @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") -@click.option( - "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") @click.option("--download", default=True, help="Whether to download the dataset") def train_char_tokenizer( dataset_path: str, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6af1e80..53cdac1 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,74 +1,44 @@ """Training script for the ASR model.""" import os +from typing import TypedDict + import click import torch import torch.nn.functional as F -import torchaudio -from AudioLoader.speech import MultilingualLibriSpeech +from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader -from tokenizers import Tokenizer -from .tokenizer import CharTokenizer + +from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.tokenizer import train_bpe_tokenizer +from swr2_asr.utils import MLSDataset, Split from .loss_scores import cer, wer -train_audio_transforms = 
nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") -text_transform = CharTokenizer() -text_transform.from_file("data/tokenizers/char_tokenizer_german.json") - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for sample in data: - if data_type == "train": - spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" - return spectrograms, labels, input_lengths, label_lengths + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int -def greedy_decoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - # TODO: adopt to support both tokenizers +def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.decode( - [int(x) for x in labels[i][: label_lengths[i]].tolist()] - ) - ) + targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: @@ -78,155 +48,6 @@ def greedy_decoder( return decodes, targets -# TODO: restructure into own file / class -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super().__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) - data = self.layer_norm(data) - return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) 
- - def forward(self, data): - """x (batch, channel, feature, time)""" - residual = data # (batch, channel, feature, time) - data = self.layer_norm1(data) - data = F.gelu(data) - data = self.dropout1(data) - data = self.cnn1(data) - data = self.layer_norm2(data) - data = F.gelu(data) - data = self.dropout2(data) - data = self.cnn2(data) - data += residual - return data # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() - - self.bi_gru = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, data): - """data (batch, time, feature)""" - data = self.layer_norm(data) - data = F.gelu(data) - data = self.dropout(data) - data, _ = self.bi_gru(data) - return data - - -class SpeechRecognitionModel(nn.Module): - """Speech Recognition Model Inspired by DeepSpeech 2""" - - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, - ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, data): - """data (batch, channel, feature, time)""" - data = self.cnn(data) - data = self.rescnn_layers(data) - sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) - data = data.transpose(1, 2) # (batch, time, feature) - data = self.fully_connected(data) - data = self.birnn_layers(data) - data = self.classifier(data) - return data - - class IterMeter: """keeps track of total iterations""" @@ -256,9 +77,8 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data + _, spectrograms, input_lengths, labels, label_lengths, *_ = _data spectrograms, labels = spectrograms.to(device), labels.to(device) - optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) @@ -282,7 +102,6 @@ def train( return loss.item() -# TODO: check how dataloader can be made more efficient def test(model, device, test_loader, criterion): """Test""" print("\nevaluating...") @@ -301,9 +120,7 @@ def test(model, device, test_loader, criterion): loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths - ) + decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) for j, pred in 
enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -324,46 +141,62 @@ def run( load: bool, path: str, dataset_path: str, + language: str, ) -> None: """Runs the training script.""" - hparams = { - "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 36, # TODO: dynamically determine this from vocab size - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": learning_rate, - "batch_size": batch_size, - "epochs": epochs, - } - use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") - download_dataset = not os.path.isdir(path) - train_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="dev", download=download_dataset - ) - test_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="test", download=False + # load dataset + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + + # TODO: add flag to choose tokenizer + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): + print("There is no tokenizer available. Do you want to train it on the dataset?") + input("Press Enter to continue...") + train_bpe_tokenizer( + dataset_path=dataset_path, + language=language, + split="all", + download=False, + out_path="data/tokenizers/bpe_tokenizer_german_3000.json", + vocab_size=3000, + ) + + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + + train_dataset.set_tokenizer(tokenizer) + valid_dataset.set_tokenizer(tokenizer) + test_dataset.set_tokenizer(tokenizer) + + hparams = HParams( + n_cnn_layers=3, + n_rnn_layers=5, + rnn_dim=512, + n_class=tokenizer.get_vocab_size(), + n_feats=128, + stride=2, + dropout=0.1, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) - test_loader = DataLoader( - test_dataset, + valid_loader = DataLoader( + valid_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) # enable flag to find the most compatible algorithms in advance @@ -380,9 +213,7 @@ def run( hparams["dropout"], ).to(device) - print( - "Num Model Parameters", sum((param.nelement() for param in model.parameters())) - ) + print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) if load: @@ -412,7 +243,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -452,4 +283,5 @@ def run_cli( load=load, path=path, dataset_path=dataset_path, + language="mls_german_opus", ) diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index c4aeb0b..786fbcf 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,16 +1,21 @@ """Class containing utils for the ASR system.""" -from 
dataclasses import dataclass import os -from AudioLoader.speech import MultilingualLibriSpeech +from enum import Enum +from typing import TypedDict + import numpy as np import torch import torchaudio -from torch import nn -from torch.utils.data import Dataset, DataLoader -from enum import Enum - from tokenizers import Tokenizer -from swr2_asr.tokenizer import CharTokenizer +from torch.utils.data import Dataset + +from swr2_asr.tokenizer import CharTokenizer, TokenizerType + +train_audio_transforms = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) # create enum specifiying dataset splits @@ -40,34 +45,20 @@ def split_to_mls_split(split: Split) -> MLSSplit: return split # type: ignore -@dataclass -class Sample: - """Dataclass for a sample in the dataset""" +class Sample(TypedDict): + """Type for a sample in the dataset""" waveform: torch.Tensor spectrogram: torch.Tensor - utterance: str + input_length: int + utterance: torch.Tensor + utterance_length: int sample_rate: int speaker_id: str book_id: str chapter_id: str -def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"): - """Factory for Tokenizer class - - Args: - tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE". - - Returns: - nn.Module: Tokenizer class - """ - if tokenizer_type == "BPE": - return Tokenizer.from_file(tokenizer_path) - elif tokenizer_type == "char": - return CharTokenizer.from_file(tokenizer_path) - - class MLSDataset(Dataset): """Custom Dataset for reading Multilingual LibriSpeech @@ -105,23 +96,33 @@ class MLSDataset(Dataset): self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally self.dataset_lookup = [] + self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() - transcripts_path = os.path.join( - dataset_path, language, self.mls_split, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt transcripts = script_file.readlines() # split each line into (__, ) - transcripts = [line.strip().split("\t", 1) for line in transcripts] - utterances = [utterance.strip() for _, utterance in transcripts] - identifier = [identifier.strip() for identifier, _ in transcripts] + transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore + utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore + identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore identifier = [path.split("_") for path in identifier] + if self.split == Split.valid: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + elif self.split == Split.train: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + self.dataset_lookup = [ { "speakerid": path[0], @@ -129,27 +130,23 @@ class MLSDataset(Dataset): "chapterid": path[2], "utterance": utterance, } - for path, utterance in zip(identifier, utterances) + for path, utterance in zip(identifier, 
utterances, strict=False) ] - # save dataset_lookup as list of dicts, where each dict contains - # the speakerid, bookid and chapterid, as well as the utterance - # we can then use this to map the utterance to the audio file + def set_tokenizer(self, tokenizer: type[TokenizerType]): + """Sets the tokenizer""" + self.tokenizer = tokenizer + + self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" - if ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and download - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and not download - ): + elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: raise ValueError("Dataset not found. Set download to True to download it") def _validate_local_directory(self): @@ -158,18 +155,32 @@ class MLSDataset(Dataset): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): raise ValueError("Language not found in dataset") - if not os.path.exists( - os.path.join(self.dataset_path, self.language, self.mls_split) - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - # checks if the transcripts.txt file exists - if not os.path.exists( - os.path.join(dataset_path, language, split, "transcripts.txt") - ): - raise ValueError("transcripts.txt not found in dataset") - - def __get_len__(self): + def calc_paddings(self): + """Sets the maximum length of the spectrogram""" + # check if dataset has been loaded and tokenizer has been set + if not self.dataset_lookup: + raise ValueError("Dataset not loaded") + if not self.tokenizer: + raise ValueError("Tokenizer not set") + + max_spec_length = 0 + max_uterance_length = 0 + for sample in self.dataset_lookup: + spec_length = sample["spectrogram"].shape[0] + if spec_length > max_spec_length: + max_spec_length = spec_length + + utterance_length = sample["utterance"].shape[0] + if utterance_length > max_uterance_length: + max_uterance_length = utterance_length + + self.max_spec_length = max_spec_length + self.max_utterance_length = max_uterance_length + + def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) @@ -197,13 +208,32 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + # TODO: figure out if we have to resample or not + # TODO: pad correctly (manually) + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + print(f"spec.shape: {spec.shape}") + input_length = spec.shape[0] // 2 + spec = ( + torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) + .unsqueeze(1) + .transpose(2, 3) + ) + + utterance_length = len(utterance) + self.tokenizer.enable_padding() + utterance = self.tokenizer.encode( + utterance, + ).ids + + utterance = torch.Tensor(utterance) return Sample( + # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, - spectrogram=torchaudio.transforms.MelSpectrogram( - sample_rate=16000, n_mels=128 - )(waveform), + spectrogram=spec, + input_length=input_length, utterance=utterance, + 
utterance_length=utterance_length, sample_rate=sample_rate, speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], @@ -218,62 +248,6 @@ class MLSDataset(Dataset): torch.hub.download_url_to_file(url, dataset_path) -class DataProcessor: - """Factory for DataProcessingclass - - Transforms the dataset into spectrograms and labels, as well as a tokenizer - """ - - def __init__( - self, - dataset: MultilingualLibriSpeech, - tokenizer_path: str, - data_type: str = "train", - tokenizer_type: str = "BPE", - ): - self.dataset = dataset - self.data_type = data_type - - self.train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), - ) - - self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - self.tokenizer = tokenizer_factory( - tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type - ) - - def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]: - """Returns spectrograms, labels and their lenghts""" - for sample in self.dataset: - if self.data_type == "train": - spec = ( - self.train_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - elif self.data_type == "valid": - spec = ( - self.valid_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - else: - raise ValueError("data_type should be train or valid") - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - - yield spec, label, spec.shape[0] // 2, len(labels) - - if __name__ == "__main__": dataset_path = "/Volumes/pherkel/SWR2-ASR" language = "mls_german_opus" @@ -281,4 +255,9 @@ if __name__ == "__main__": download = False dataset = MLSDataset(dataset_path, language, split, download) - print(dataset[0]) + + tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + dataset.set_tokenizer(tok) + dataset.calc_paddings() + + print(dataset[41]["spectrogram"].shape) -- cgit v1.2.3 From 9b450c685e9ec4e7e74688de3d6dbb719b19fcf6 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Tue, 29 Aug 2023 11:55:08 +0200 Subject: added padding calculation for audiofiles --- poetry.lock | 187 +++++++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + swr2_asr/utils.py | 108 +++++++++++++++++++++++-------- 3 files changed, 267 insertions(+), 29 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/poetry.lock b/poetry.lock index 1f3609a..77643f5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -19,6 +19,50 @@ wrapt = [ {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, ] +[[package]] +name = "attrs" +version = "19.3.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"}, + {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"}, +] + +[package.extras] +azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-azurepipelines", "six", "zope.interface"] +dev = ["coverage", "hypothesis", "pre-commit", "pympler", "pytest (>=4.3.0)", "six", "sphinx", 
"zope.interface"] +docs = ["sphinx", "zope.interface"] +tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] + +[[package]] +name = "audio-metadata" +version = "0.11.1" +description = "A library for reading and, in the future, writing metadata from audio files." +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "audio-metadata-0.11.1.tar.gz", hash = "sha256:9e7ba79d49cf048a911d5f7d55bb2715c10be5c127fe5db0987c5fe1aa7335eb"}, + {file = "audio_metadata-0.11.1-py3-none-any.whl", hash = "sha256:f5b85ad087324c255f8d1223574c3e7d3c27b649e411d1dd54aa3bf342fe93fb"}, +] + +[package.dependencies] +attrs = ">=18.2,<19.4" +bidict = "<1.0.0" +bitstruct = ">=6.0,<9.0" +more-itertools = ">=4.0,<9.0" +pendulum = ">=2.0,<2.0.5 || >2.0.5,<2.1.0 || >2.1.0,<=3.0" +pprintpp = "<1.0.0" +tbm-utils = ">=2.3,<3.0" +wrapt = ">=1.0,<2.0" + +[package.extras] +dev = ["coverage[toml] (>=5.0,<6.0)", "flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)", "nox (>=2019,<2020)", "sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)", "ward (>=0.42.0-beta.0)"] +doc = ["sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +lint = ["flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)"] +test = ["coverage[toml] (>=5.0,<6.0)", "nox (>=2019,<2020)", "ward (>=0.42.0-beta.0)"] + [[package]] name = "AudioLoader" version = "0.1.4" @@ -34,6 +78,32 @@ url = "https://github.com/marvinborner/AudioLoader.git" reference = "HEAD" resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2" +[[package]] +name = "bidict" +version = "0.22.1" +description = "The bidirectional mapping library for Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "bidict-0.22.1-py3-none-any.whl", hash = "sha256:6ef212238eb884b664f28da76f33f1d28b260f665fc737b413b287d5487d1e7b"}, + {file = "bidict-0.22.1.tar.gz", hash = "sha256:1e0f7f74e4860e6d0943a05d4134c63a2fad86f3d4732fb265bd79e4e856d81d"}, +] + +[package.extras] +docs = ["furo", "sphinx", "sphinx-copybutton"] +lint = ["pre-commit"] +test = ["hypothesis", "pytest", "pytest-benchmark[histogram]", "pytest-cov", "pytest-xdist", "sortedcollections", "sortedcontainers", "sphinx"] + +[[package]] +name = "bitstruct" +version = "8.17.0" +description = "This module performs conversions between Python values and C bit field structs represented as Python byte strings." 
+optional = false +python-versions = "*" +files = [ + {file = "bitstruct-8.17.0.tar.gz", hash = "sha256:eb94b40e4218a23aa8f90406b836a9e6ed83e48b8d112ce3f96408463bd1b874"}, +] + [[package]] name = "black" version = "23.7.0" @@ -348,6 +418,17 @@ ports-rtmidi-python = ["rtmidi-python (>=0.2.2,<0.3.0)"] release = ["twine (>=4.0.2,<4.1.0)"] test-code = ["pytest (>=7.4.0,<7.5.0)"] +[[package]] +name = "more-itertools" +version = "8.14.0" +description = "More routines for operating on iterables, beyond itertools" +optional = false +python-versions = ">=3.5" +files = [ + {file = "more-itertools-8.14.0.tar.gz", hash = "sha256:c09443cd3d5438b8dafccd867a6bc1cb0894389e90cb53d227456b0b0bccb750"}, + {file = "more_itertools-8.14.0-py3-none-any.whl", hash = "sha256:1bc4f91ee5b1b31ac7ceacc17c09befe6a40a503907baf9c839c229b5095cfd2"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -654,6 +735,40 @@ files = [ {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, ] +[[package]] +name = "pendulum" +version = "2.1.2" +description = "Python datetimes made easy" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "pendulum-2.1.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:b6c352f4bd32dff1ea7066bd31ad0f71f8d8100b9ff709fb343f3b86cee43efe"}, + {file = "pendulum-2.1.2-cp27-cp27m-win_amd64.whl", hash = "sha256:318f72f62e8e23cd6660dbafe1e346950281a9aed144b5c596b2ddabc1d19739"}, + {file = "pendulum-2.1.2-cp35-cp35m-macosx_10_15_x86_64.whl", hash = "sha256:0731f0c661a3cb779d398803655494893c9f581f6488048b3fb629c2342b5394"}, + {file = "pendulum-2.1.2-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:3481fad1dc3f6f6738bd575a951d3c15d4b4ce7c82dce37cf8ac1483fde6e8b0"}, + {file = "pendulum-2.1.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:9702069c694306297ed362ce7e3c1ef8404ac8ede39f9b28b7c1a7ad8c3959e3"}, + {file = "pendulum-2.1.2-cp35-cp35m-win_amd64.whl", hash = "sha256:fb53ffa0085002ddd43b6ca61a7b34f2d4d7c3ed66f931fe599e1a531b42af9b"}, + {file = "pendulum-2.1.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:c501749fdd3d6f9e726086bf0cd4437281ed47e7bca132ddb522f86a1645d360"}, + {file = "pendulum-2.1.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:c807a578a532eeb226150d5006f156632df2cc8c5693d778324b43ff8c515dd0"}, + {file = "pendulum-2.1.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2d1619a721df661e506eff8db8614016f0720ac171fe80dda1333ee44e684087"}, + {file = "pendulum-2.1.2-cp36-cp36m-win_amd64.whl", hash = "sha256:f888f2d2909a414680a29ae74d0592758f2b9fcdee3549887779cd4055e975db"}, + {file = "pendulum-2.1.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:e95d329384717c7bf627bf27e204bc3b15c8238fa8d9d9781d93712776c14002"}, + {file = "pendulum-2.1.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:4c9c689747f39d0d02a9f94fcee737b34a5773803a64a5fdb046ee9cac7442c5"}, + {file = "pendulum-2.1.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1245cd0075a3c6d889f581f6325dd8404aca5884dea7223a5566c38aab94642b"}, + {file = "pendulum-2.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:db0a40d8bcd27b4fb46676e8eb3c732c67a5a5e6bfab8927028224fbced0b40b"}, + {file = "pendulum-2.1.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:f5e236e7730cab1644e1b87aca3d2ff3e375a608542e90fe25685dae46310116"}, + {file = "pendulum-2.1.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:de42ea3e2943171a9e95141f2eecf972480636e8e484ccffaf1e833929e9e052"}, + {file = 
"pendulum-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7c5ec650cb4bec4c63a89a0242cc8c3cebcec92fcfe937c417ba18277d8560be"}, + {file = "pendulum-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33fb61601083f3eb1d15edeb45274f73c63b3c44a8524703dc143f4212bf3269"}, + {file = "pendulum-2.1.2-cp39-cp39-manylinux1_i686.whl", hash = "sha256:29c40a6f2942376185728c9a0347d7c0f07905638c83007e1d262781f1e6953a"}, + {file = "pendulum-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:94b1fc947bfe38579b28e1cccb36f7e28a15e841f30384b5ad6c5e31055c85d7"}, + {file = "pendulum-2.1.2.tar.gz", hash = "sha256:b06a0ca1bfe41c990bbf0c029f0b6501a7f2ec4e38bfec730712015e8860f207"}, +] + +[package.dependencies] +python-dateutil = ">=2.6,<3.0" +pytzdata = ">=2020.1" + [[package]] name = "platformdirs" version = "3.10.0" @@ -669,6 +784,17 @@ files = [ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +[[package]] +name = "pprintpp" +version = "0.4.0" +description = "A drop-in replacement for pprint that's actually pretty" +optional = false +python-versions = "*" +files = [ + {file = "pprintpp-0.4.0-py2.py3-none-any.whl", hash = "sha256:b6b4dcdd0c0c0d75e4d7b2f21a9e933e5b2ce62b26e1a54537f9651ae5a5c01d"}, + {file = "pprintpp-0.4.0.tar.gz", hash = "sha256:ea826108e2c7f49dc6d66c752973c3fc9749142a798d6b254e1e301cfdbc6403"}, +] + [[package]] name = "pylint" version = "2.17.5" @@ -697,6 +823,31 @@ tomlkit = ">=0.10.1" spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytzdata" +version = "2020.1" +description = "The Olson timezone database for Python." 
+optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pytzdata-2020.1-py2.py3-none-any.whl", hash = "sha256:e1e14750bcf95016381e4d472bad004eef710f2d6417240904070b3d6654485f"}, + {file = "pytzdata-2020.1.tar.gz", hash = "sha256:3efa13b335a00a8de1d345ae41ec78dd11c9f8807f522d39850f2dd828681540"}, +] + [[package]] name = "ruff" version = "0.0.285" @@ -739,6 +890,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sympy" version = "1.12" @@ -753,6 +915,29 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tbm-utils" +version = "2.6.0" +description = "A commonly-used set of utilities used by me (thebigmunch)." 
+optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "tbm-utils-2.6.0.tar.gz", hash = "sha256:235748cceeb22c042e32d2fdfd4d710021bac9b938c4f2c35e1fce1cfd58f7ec"}, + {file = "tbm_utils-2.6.0-py3-none-any.whl", hash = "sha256:692b5cde2b810bb84e55ca0f5f6c055ca6ad321c4d5acc0cade98af97c5998e2"}, +] + +[package.dependencies] +attrs = ">=18.2,<19.4" +pendulum = ">=2.0,<2.0.5 || >2.0.5,<2.1.0 || >2.1.0,<=3.0" +pprintpp = "<1.0.0" +wrapt = ">=1.0,<2.0" + +[package.extras] +dev = ["coverage[toml] (>=4.5,<6.0)", "flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)", "nox (>=2019,<2020)", "pytest (>=4.0,<6.0)", "sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +doc = ["sphinx (>=2.0,<3.0)", "sphinx-material (<1.0.0)"] +lint = ["flake8 (>=3.5,<4.0)", "flake8-builtins (>=1.0,<2.0)", "flake8-comprehensions (>=2.0,<=4.0)", "flake8-import-order (>=0.18,<0.19)", "flake8-import-order-tbm (>=1.0,<2.0)"] +test = ["coverage[toml] (>=4.5,<6.0)", "nox (>=2019,<2020)", "pytest (>=4.0,<6.0)"] + [[package]] name = "tokenizers" version = "0.13.3" @@ -1096,4 +1281,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2" +content-hash = "992e0eb975f6fb490726bc9db470a682986c1a287078bec407a4a6dc96e45b3c" diff --git a/pyproject.toml b/pyproject.toml index 94f7553..dc136f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ numpy = "^1.25.2" mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" +audio-metadata = "^0.11.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 786fbcf..404661d 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,6 +1,7 @@ """Class containing utils for the ASR system.""" import os from enum import Enum +from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -8,6 +9,8 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from tqdm import tqdm +import audio_metadata from swr2_asr.tokenizer import CharTokenizer, TokenizerType @@ -133,11 +136,13 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] + self.max_spec_length = 0 + self.max_utterance_length = 0 + def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - - self.calc_paddings() + # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -158,27 +163,73 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def calc_paddings(self): - """Sets the maximum length of the spectrogram""" + def _calculate_max_length(self, chunk): + """Calculates the maximum length of the spectrogram and the utterance + + to be called in a multiprocessing pool + """ + max_spec_length = 0 + max_utterance_length = 0 + + for sample in chunk: + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + sample["speakerid"], + sample["bookid"], + "_".join( + [ + sample["speakerid"], + sample["bookid"], + sample["chapterid"], + ] + ) + + self.file_ext, + ) + metadata = audio_metadata.load(audio_path) + audio_duration = metadata.streaminfo.duration + sample_rate = metadata.streaminfo.sample_rate 
+ + max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) + max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) + + return max_spec_length, max_utterance_length + + def calc_paddings(self) -> None: + """Sets the maximum length of the spectrogram and the utterance""" # check if dataset has been loaded and tokenizer has been set if not self.dataset_lookup: raise ValueError("Dataset not loaded") if not self.tokenizer: raise ValueError("Tokenizer not set") - - max_spec_length = 0 - max_uterance_length = 0 - for sample in self.dataset_lookup: - spec_length = sample["spectrogram"].shape[0] - if spec_length > max_spec_length: - max_spec_length = spec_length - - utterance_length = sample["utterance"].shape[0] - if utterance_length > max_uterance_length: - max_uterance_length = utterance_length - - self.max_spec_length = max_spec_length - self.max_utterance_length = max_uterance_length + # check if paddings have been calculated already + if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): + print("Paddings already calculated") + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: + self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] + return + else: + print("Calculating paddings...") + + thread_count = os.cpu_count() + if thread_count is None: + thread_count = 4 + chunk_size = len(self.dataset_lookup) // thread_count + chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] + + with Pool(thread_count) as p: + results = list(p.imap(self._calculate_max_length, chunks)) + + for spec, utterance in results: + self.max_spec_length = max(self.max_spec_length, spec) + self.max_utterance_length = max(self.max_utterance_length, utterance) + + # write to file + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: + f.write(f"{self.max_spec_length}\n") + f.write(f"{self.max_utterance_length}") def __len__(self): """Returns the length of the dataset""" @@ -208,18 +259,18 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore - # TODO: figure out if we have to resample or not - # TODO: pad correctly (manually) + + # resample if necessary + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + sample_rate = 16000 + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) - print(f"spec.shape: {spec.shape}") - input_length = spec.shape[0] // 2 - spec = ( - torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) - .unsqueeze(1) - .transpose(2, 3) - ) + input_length = spec.shape[0] // 2 utterance_length = len(utterance) + self.tokenizer.enable_padding() utterance = self.tokenizer.encode( utterance, @@ -260,4 +311,5 @@ if __name__ == "__main__": dataset.set_tokenizer(tok) dataset.calc_paddings() - print(dataset[41]["spectrogram"].shape) + print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") + print(f"Utterance shape: {dataset[41]['utterance'].shape}") -- cgit v1.2.3 From 335b8a32f8bba5d37c00af6b4ecd1b9fc520f964 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Wed, 30 Aug 2023 17:11:51 +0200 Subject: wörks now!°!! 
--- swr2_asr/model_deep_speech.py | 2 +- swr2_asr/train.py | 49 +++++++------ swr2_asr/utils.py | 164 +++++++++++++++++------------------------- 3 files changed, 95 insertions(+), 120 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index ea0b667..f00ebd4 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,5 +1,5 @@ -from torch import nn import torch.nn.functional as F +from torch import nn class CNNLayerNorm(nn.Module): diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 53cdac1..bae8c7c 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,13 +5,14 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +from AudioLoader.speech import MultilingualLibriSpeech from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.tokenizer import train_bpe_tokenizer -from swr2_asr.utils import MLSDataset, Split +from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer @@ -31,20 +32,20 @@ class HParams(TypedDict): epochs: int -def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): +def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) + targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.decode(decode)) + decodes.append(tokenizer.decode(decode)) return decodes, targets @@ -77,15 +78,14 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - _, spectrograms, input_lengths, labels, label_lengths, *_ = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) loss.backward() optimizer.step() @@ -102,7 +102,7 @@ def train( return loss.item() -def test(model, device, test_loader, criterion): +def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") model.eval() @@ -110,17 +110,20 @@ def test(model, device, test_loader, criterion): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, 
label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) + decoded_preds, decoded_targets = greedy_decoder( + output = output.transpose(0, 1), + labels = labels, + label_lengths= _data["utterance_length"], + tokenizer=tokenizer) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -150,12 +153,11 @@ def run( # device = torch.device("mps") # load dataset - train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) - valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) - test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) - # TODO: add flag to choose tokenizer - # load tokenizer (bpe by default): + # load tokenizer (bpe by default): if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): print("There is no tokenizer available. Do you want to train it on the dataset?") input("Press Enter to continue...") @@ -167,12 +169,14 @@ def run( out_path="data/tokenizers/bpe_tokenizer_german_3000.json", vocab_size=3000, ) - + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - + train_dataset.set_tokenizer(tokenizer) valid_dataset.set_tokenizer(tokenizer) test_dataset.set_tokenizer(tokenizer) + + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") hparams = HParams( n_cnn_layers=3, @@ -191,12 +195,14 @@ def run( train_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) valid_loader = DataLoader( valid_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) # enable flag to find the most compatible algorithms in advance @@ -243,7 +249,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=valid_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -285,3 +291,6 @@ def run_cli( dataset_path=dataset_path, language="mls_german_opus", ) + +if __name__ == "__main__": + run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 404661d..4c751d5 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,7 +1,6 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -10,9 +9,8 @@ import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset from tqdm import tqdm -import audio_metadata -from swr2_asr.tokenizer import CharTokenizer, TokenizerType +from swr2_asr.tokenizer import TokenizerType train_audio_transforms = torch.nn.Sequential( torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), @@ -91,20 +89,42 @@ class MLSDataset(Dataset): __ """ - def 
__init__(self, dataset_path: str, language: str, split: Split, download: bool): + def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None): """Initializes the dataset""" self.dataset_path = dataset_path self.language = language self.file_ext = ".opus" if "opus" in language else ".flac" self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally + + if spectrogram_hparams is None: + self.spectrogram_hparams = { + "sample_rate": 16000, + "n_fft": 400, + "win_length": 400, + "hop_length": 160, + "n_mels": 128, + "f_min": 0, + "f_max": 8000, + "power": 2.0, + } + else: + self.spectrogram_hparams = spectrogram_hparams + self.dataset_lookup = [] self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() + self.initialize() - transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") + + def initialize(self) -> None: + """Initializes the dataset + + Reads the transcripts.txt file and creates a lookup table + """ + transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt @@ -136,13 +156,9 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] - self.max_spec_length = 0 - self.max_utterance_length = 0 - def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -163,80 +179,14 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def _calculate_max_length(self, chunk): - """Calculates the maximum length of the spectrogram and the utterance - - to be called in a multiprocessing pool - """ - max_spec_length = 0 - max_utterance_length = 0 - - for sample in chunk: - audio_path = os.path.join( - self.dataset_path, - self.language, - self.mls_split, - "audio", - sample["speakerid"], - sample["bookid"], - "_".join( - [ - sample["speakerid"], - sample["bookid"], - sample["chapterid"], - ] - ) - + self.file_ext, - ) - metadata = audio_metadata.load(audio_path) - audio_duration = metadata.streaminfo.duration - sample_rate = metadata.streaminfo.sample_rate - - max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) - max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) - - return max_spec_length, max_utterance_length - - def calc_paddings(self) -> None: - """Sets the maximum length of the spectrogram and the utterance""" - # check if dataset has been loaded and tokenizer has been set - if not self.dataset_lookup: - raise ValueError("Dataset not loaded") - if not self.tokenizer: - raise ValueError("Tokenizer not set") - # check if paddings have been calculated already - if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): - print("Paddings already calculated") - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: - self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] - return - else: - print("Calculating paddings...") - - thread_count = 
os.cpu_count() - if thread_count is None: - thread_count = 4 - chunk_size = len(self.dataset_lookup) // thread_count - chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] - - with Pool(thread_count) as p: - results = list(p.imap(self._calculate_max_length, chunks)) - - for spec, utterance in results: - self.max_spec_length = max(self.max_spec_length, spec) - self.max_utterance_length = max(self.max_utterance_length, utterance) - - # write to file - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: - f.write(f"{self.max_spec_length}\n") - f.write(f"{self.max_utterance_length}") - def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) def __getitem__(self, idx: int) -> Sample: """One sample""" + if self.tokenizer is None: + raise ValueError("No tokenizer set") # get the utterance utterance = self.dataset_lookup[idx]["utterance"] @@ -261,42 +211,62 @@ class MLSDataset(Dataset): waveform, sample_rate = torchaudio.load(audio_path) # type: ignore # resample if necessary - if sample_rate != 16000: - resampler = torchaudio.transforms.Resample(sample_rate, 16000) + if sample_rate != self.spectrogram_hparams["sample_rate"]: + resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) - sample_rate = 16000 - spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1) input_length = spec.shape[0] // 2 + utterance_length = len(utterance) - self.tokenizer.enable_padding() - utterance = self.tokenizer.encode( - utterance, - ).ids + utterance = self.tokenizer.encode(utterance) - utterance = torch.Tensor(utterance) + utterance = torch.LongTensor(utterance.ids) return Sample( - # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, spectrogram=spec, input_length=input_length, utterance=utterance, utterance_length=utterance_length, - sample_rate=sample_rate, + sample_rate=self.spectrogram_hparams["sample_rate"], speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) - def download(self, dataset_path: str, language: str): - """Download the dataset""" - os.makedirs(dataset_path) - url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz" - - torch.hub.download_url_to_file(url, dataset_path) +def collate_fn(samples: list[Sample]) -> dict: + """Collate function for the dataloader + + pads all tensors within a batch to the same dimensions + """ + waveforms = [] + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + + for sample in samples: + waveforms.append(sample["waveform"].transpose(0, 1)) + spectrograms.append(sample["spectrogram"]) + labels.append(sample["utterance"]) + input_lengths.append(sample["spectrogram"].shape[0] // 2) + label_lengths.append(len(sample["utterance"])) + + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return { + "waveform": waveforms, + "spectrogram": spectrograms, + "input_length": input_lengths, + "utterance": labels, + "utterance_length": label_lengths, + } + if 
__name__ == "__main__": @@ -305,11 +275,7 @@ if __name__ == "__main__": split = Split.train download = False - dataset = MLSDataset(dataset_path, language, split, download) + dataset = MLSDataset(dataset_path, language, split, download, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) - dataset.calc_paddings() - - print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") - print(f"Utterance shape: {dataset[41]['utterance'].shape}") -- cgit v1.2.3 From f3d2ea9a16944434a08e662c5ecfd6ba50e5ea89 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Fri, 1 Sep 2023 22:40:29 +0200 Subject: many todos --- swr2_asr/tokenizer.py | 1 + swr2_asr/train.py | 35 +++++++++++++++++++---------------- swr2_asr/utils.py | 4 ++-- 3 files changed, 22 insertions(+), 18 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 5758da7..8e3bf09 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -302,6 +302,7 @@ def train_bpe_tokenizer( "ü", ] + # TODO: add padding token / whitespace token / special tokens trainer = BpeTrainer( special_tokens=["[UNK]"], vocab_size=vocab_size, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index bae8c7c..f3efd69 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,10 +5,10 @@ from typing import TypedDict import click import torch import torch.nn.functional as F -from AudioLoader.speech import MultilingualLibriSpeech from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader +from tqdm import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.tokenizer import train_bpe_tokenizer @@ -16,6 +16,7 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer +# TODO: improve naming of functions class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -32,6 +33,7 @@ class HParams(TypedDict): epochs: int +# TODO: get blank label from tokenizer def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -76,8 +78,9 @@ def train( ): """Train""" model.train() - data_len = len(train_loader.dataset) - for batch_idx, _data in enumerate(train_loader): + print(f"Epoch: {epoch}") + losses = [] + for _data in tqdm(train_loader, desc="batches"): spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) optimizer.zero_grad() @@ -91,17 +94,15 @@ def train( optimizer.step() scheduler.step() iter_meter.step() - if batch_idx % 100 == 0 or batch_idx == data_len: - print( - f"Train Epoch: \ - {epoch} \ - [{batch_idx * len(spectrograms)}/{data_len} \ - ({100.0 * batch_idx / len(train_loader)}%)]\t \ - Loss: {loss.item()}" - ) - return loss.item() + + losses.append(loss.item()) + print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") + return sum(losses) / len(losses) +# TODO: profile this function call +# TODO: only calculate wer and cer at the end, or less often +# TODO: change this to only be a sanity check and calculate measures after training def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") @@ -116,6 +117,7 @@ def test(model, device, test_loader, criterion, tokenizer): output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) + # TODO: get rid of this loss = criterion(output, labels, 
_data['input_length'], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) @@ -150,11 +152,12 @@ def run( use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - # device = torch.device("mps") + device = torch.device("mps") # load dataset - train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None) - valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None) + # TODO: change this from dev split to train split again (was faster for development) + train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) + valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) # load tokenizer (bpe by default): @@ -237,7 +240,7 @@ def run( ) iter_meter = IterMeter() - for epoch in range(1, epochs + 1): + for epoch in range(1, epochs + 1): loss = train( model, device, diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 4c751d5..3b9b3ca 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -236,6 +236,7 @@ class MLSDataset(Dataset): book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) + def collate_fn(samples: list[Sample]) -> dict: """Collate function for the dataloader @@ -266,8 +267,7 @@ def collate_fn(samples: list[Sample]) -> dict: "utterance": labels, "utterance_length": label_lengths, } - - + if __name__ == "__main__": dataset_path = "/Volumes/pherkel/SWR2-ASR" -- cgit v1.2.3 From 33f09080aee10bddb4797a557d676ee1f7b8de31 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Sun, 3 Sep 2023 19:30:33 +0200 Subject: idk, hopefully this works --- poetry.lock | 685 +++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 3 +- swr2_asr/inference_test.py | 4 +- swr2_asr/loss_scores.py | 8 +- swr2_asr/model_deep_speech.py | 13 +- swr2_asr/tokenizer.py | 17 +- swr2_asr/train.py | 90 +++--- swr2_asr/utils.py | 85 +++--- 8 files changed, 807 insertions(+), 98 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/poetry.lock b/poetry.lock index 77643f5..38cba87 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,15 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
+ +[[package]] +name = "appnope" +version = "0.1.3" +description = "Disable App Nap on macOS >= 10.9" +optional = false +python-versions = "*" +files = [ + {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, + {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, +] [[package]] name = "astroid" @@ -19,6 +30,23 @@ wrapt = [ {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, ] +[[package]] +name = "asttokens" +version = "2.3.0" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.3.0-py2.py3-none-any.whl", hash = "sha256:bef1a51bc256d349e9f94e7e40e44b705ed1162f55294220dd561d24583d9877"}, + {file = "asttokens-2.3.0.tar.gz", hash = "sha256:2552a88626aaa7f0f299f871479fc755bd4e7c11e89078965e928fb7bb9a6afe"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +test = ["astroid", "pytest"] + [[package]] name = "attrs" version = "19.3.0" @@ -78,6 +106,17 @@ url = "https://github.com/marvinborner/AudioLoader.git" reference = "HEAD" resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2" +[[package]] +name = "backcall" +version = "0.2.0" +description = "Specifications for callback functions passed in to an API" +optional = false +python-versions = "*" +files = [ + {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, + {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, +] + [[package]] name = "bidict" version = "0.22.1" @@ -149,6 +188,82 @@ d = ["aiohttp (>=3.7.4)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cffi" +version = "1.15.1" +description = "Foreign Function Interface for Python calling C code." 
+optional = false +python-versions = "*" +files = [ + {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"}, + {file = "cffi-1.15.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2"}, + {file = "cffi-1.15.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914"}, + {file = "cffi-1.15.1-cp27-cp27m-win32.whl", hash = "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3"}, + {file = "cffi-1.15.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e"}, + {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162"}, + {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b"}, + {file = "cffi-1.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21"}, + {file = "cffi-1.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e"}, + {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4"}, + {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01"}, + {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e"}, + {file = "cffi-1.15.1-cp310-cp310-win32.whl", hash = "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2"}, + {file = "cffi-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d"}, + {file = "cffi-1.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac"}, + {file = "cffi-1.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c"}, + {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325"}, + 
{file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c"}, + {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef"}, + {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8"}, + {file = "cffi-1.15.1-cp311-cp311-win32.whl", hash = "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d"}, + {file = "cffi-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104"}, + {file = "cffi-1.15.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405"}, + {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e"}, + {file = "cffi-1.15.1-cp36-cp36m-win32.whl", hash = "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf"}, + {file = "cffi-1.15.1-cp36-cp36m-win_amd64.whl", hash = "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497"}, + {file = "cffi-1.15.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c"}, + {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426"}, + {file = "cffi-1.15.1-cp37-cp37m-win32.whl", hash = "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9"}, + {file = "cffi-1.15.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045"}, + {file = "cffi-1.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a"}, + 
{file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02"}, + {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192"}, + {file = "cffi-1.15.1-cp38-cp38-win32.whl", hash = "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314"}, + {file = "cffi-1.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5"}, + {file = "cffi-1.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585"}, + {file = "cffi-1.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35"}, + {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27"}, + {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76"}, + {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3"}, + {file = "cffi-1.15.1-cp39-cp39-win32.whl", hash = "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee"}, + {file = "cffi-1.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c"}, + {file = "cffi-1.15.1.tar.gz", hash = "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9"}, +] + +[package.dependencies] +pycparser = "*" + [[package]] name = "click" version = "8.1.7" @@ -203,6 +318,63 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "comm" +version = "0.1.4" +description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "comm-0.1.4-py3-none-any.whl", hash = "sha256:6d52794cba11b36ed9860999cd10fd02d6b2eac177068fdd585e1e2f8a96e67a"}, + {file = "comm-0.1.4.tar.gz", hash = "sha256:354e40a59c9dd6db50c5cc6b4acc887d82e9603787f83b68c01a80a923984d15"}, +] + +[package.dependencies] +traitlets = ">=4" + +[package.extras] +lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff (>=0.0.156)"] +test = ["pytest"] +typing = ["mypy (>=0.990)"] + +[[package]] +name = "debugpy" +version = "1.6.7.post1" +description = "An implementation of the Debug Adapter Protocol for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "debugpy-1.6.7.post1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:903bd61d5eb433b6c25b48eae5e23821d4c1a19e25c9610205f5aeaccae64e32"}, + {file = "debugpy-1.6.7.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16882030860081e7dd5aa619f30dec3c2f9a421e69861125f83cc372c94e57d"}, + {file = "debugpy-1.6.7.post1-cp310-cp310-win32.whl", hash = "sha256:eea8d8cfb9965ac41b99a61f8e755a8f50e9a20330938ad8271530210f54e09c"}, + {file = "debugpy-1.6.7.post1-cp310-cp310-win_amd64.whl", hash = "sha256:85969d864c45f70c3996067cfa76a319bae749b04171f2cdeceebe4add316155"}, + {file = "debugpy-1.6.7.post1-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:890f7ab9a683886a0f185786ffbda3b46495c4b929dab083b8c79d6825832a52"}, + {file = "debugpy-1.6.7.post1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4ac7a4dba28801d184b7fc0e024da2635ca87d8b0a825c6087bb5168e3c0d28"}, + {file = "debugpy-1.6.7.post1-cp37-cp37m-win32.whl", hash = "sha256:3370ef1b9951d15799ef7af41f8174194f3482ee689988379763ef61a5456426"}, + {file = "debugpy-1.6.7.post1-cp37-cp37m-win_amd64.whl", hash = "sha256:65b28435a17cba4c09e739621173ff90c515f7b9e8ea469b92e3c28ef8e5cdfb"}, + {file = "debugpy-1.6.7.post1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:92b6dae8bfbd497c90596bbb69089acf7954164aea3228a99d7e43e5267f5b36"}, + {file = "debugpy-1.6.7.post1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72f5d2ecead8125cf669e62784ef1e6300f4067b0f14d9f95ee00ae06fc7c4f7"}, + {file = "debugpy-1.6.7.post1-cp38-cp38-win32.whl", hash = "sha256:f0851403030f3975d6e2eaa4abf73232ab90b98f041e3c09ba33be2beda43fcf"}, + {file = "debugpy-1.6.7.post1-cp38-cp38-win_amd64.whl", hash = "sha256:3de5d0f97c425dc49bce4293df6a04494309eedadd2b52c22e58d95107e178d9"}, + {file = "debugpy-1.6.7.post1-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:38651c3639a4e8bbf0ca7e52d799f6abd07d622a193c406be375da4d510d968d"}, + {file = "debugpy-1.6.7.post1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038c51268367c9c935905a90b1c2d2dbfe304037c27ba9d19fe7409f8cdc710c"}, + {file = "debugpy-1.6.7.post1-cp39-cp39-win32.whl", hash = "sha256:4b9eba71c290852f959d2cf8a03af28afd3ca639ad374d393d53d367f7f685b2"}, + {file = "debugpy-1.6.7.post1-cp39-cp39-win_amd64.whl", hash = "sha256:973a97ed3b434eab0f792719a484566c35328196540676685c975651266fccf9"}, + {file = "debugpy-1.6.7.post1-py2.py3-none-any.whl", hash = "sha256:1093a5c541af079c13ac8c70ab8b24d1d35c8cacb676306cf11e57f699c02926"}, + {file = "debugpy-1.6.7.post1.zip", hash = "sha256:fe87ec0182ef624855d05e6ed7e0b7cb1359d2ffa2a925f8ec2d22e98b75d0ca"}, +] + +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", 
hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "dill" version = "0.3.7" @@ -217,6 +389,34 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "executing" +version = "1.2.0" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = "*" +files = [ + {file = "executing-1.2.0-py2.py3-none-any.whl", hash = "sha256:0314a69e37426e3608aada02473b4161d4caf5a4b244d1d0c48072b8fee7bacc"}, + {file = "executing-1.2.0.tar.gz", hash = "sha256:19da64c18d2d851112f09c287f8d3dbbdf725ab0e569077efb6cdcbd3497c107"}, +] + +[package.extras] +tests = ["asttokens", "littleutils", "pytest", "rich"] + [[package]] name = "filelock" version = "3.12.2" @@ -232,6 +432,78 @@ files = [ docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +[[package]] +name = "ipykernel" +version = "6.25.1" +description = "IPython Kernel for Jupyter" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ipykernel-6.25.1-py3-none-any.whl", hash = "sha256:c8a2430b357073b37c76c21c52184db42f6b4b0e438e1eb7df3c4440d120497c"}, + {file = "ipykernel-6.25.1.tar.gz", hash = "sha256:050391364c0977e768e354bdb60cbbfbee7cbb943b1af1618382021136ffd42f"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "platform_system == \"Darwin\""} +comm = ">=0.1.1" +debugpy = ">=1.6.5" +ipython = ">=7.23.1" +jupyter-client = ">=6.1.12" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +matplotlib-inline = ">=0.1" +nest-asyncio = "*" +packaging = "*" +psutil = "*" +pyzmq = ">=20" +tornado = ">=6.1" +traitlets = ">=5.4.0" + +[package.extras] +cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"] +pyqt5 = ["pyqt5"] +pyside6 = ["pyside6"] +test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "ipython" +version = "8.15.0" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ipython-8.15.0-py3-none-any.whl", hash = "sha256:45a2c3a529296870a97b7de34eda4a31bee16bc7bf954e07d39abe49caf8f887"}, + {file = "ipython-8.15.0.tar.gz", hash = "sha256:2baeb5be6949eeebf532150f81746f8333e2ccce02de1c7eedde3f23ed5e9f1e"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "sys_platform == \"darwin\""} +backcall = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +jedi = ">=0.16" 
+matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +pickleshare = "*" +prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" + +[package.extras] +all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] + [[package]] name = "isort" version = "5.12.0" @@ -249,6 +521,25 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib" plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] +[[package]] +name = "jedi" +version = "0.19.0" +description = "An autocompletion tool for Python that can be used for text editors." +optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.19.0-py2.py3-none-any.whl", hash = "sha256:cb8ce23fbccff0025e9386b5cf85e892f94c9b822378f8da49970471335ac64e"}, + {file = "jedi-0.19.0.tar.gz", hash = "sha256:bcf9894f1753969cbac8022a8c2eaee06bfa3724e4192470aaffe7eb6272b0c4"}, +] + +[package.dependencies] +parso = ">=0.8.3,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + [[package]] name = "jinja2" version = "3.1.2" @@ -266,6 +557,48 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jupyter-client" +version = "8.3.1" +description = "Jupyter protocol implementation and client libraries" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_client-8.3.1-py3-none-any.whl", hash = "sha256:5eb9f55eb0650e81de6b7e34308d8b92d04fe4ec41cd8193a913979e33d8e1a5"}, + {file = "jupyter_client-8.3.1.tar.gz", hash = "sha256:60294b2d5b869356c893f57b1a877ea6510d60d45cf4b38057f1672d85699ac9"}, +] + +[package.dependencies] +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +python-dateutil = ">=2.8.2" +pyzmq = ">=23.0" +tornado = ">=6.2" +traitlets = ">=5.3" + +[package.extras] +docs = ["ipykernel", "myst-parser", 
"pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] +test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] + +[[package]] +name = "jupyter-core" +version = "5.3.1" +description = "Jupyter core package. A base package on which Jupyter projects rely." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_core-5.3.1-py3-none-any.whl", hash = "sha256:ae9036db959a71ec1cac33081eeb040a79e681f08ab68b0883e9a676c7a90dce"}, + {file = "jupyter_core-5.3.1.tar.gz", hash = "sha256:5ba5c7938a7f97a6b0481463f7ff0dbac7c15ba48cf46fa4035ca6e838aa1aba"}, +] + +[package.dependencies] +platformdirs = ">=2.5" +pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""} +traitlets = ">=5.3" + +[package.extras] +docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] +test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] + [[package]] name = "lazy-object-proxy" version = "1.9.0" @@ -380,6 +713,20 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.6" +description = "Inline Matplotlib backend for Jupyter" +optional = false +python-versions = ">=3.5" +files = [ + {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, + {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, +] + +[package.dependencies] +traitlets = "*" + [[package]] name = "mccabe" version = "0.7.0" @@ -503,6 +850,17 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nest-asyncio" +version = "1.5.7" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.5.7-py3-none-any.whl", hash = "sha256:5301c82941b550b3123a1ea772ba9a1c80bad3a182be8c1a5ae6ad3be57a9657"}, + {file = "nest_asyncio-1.5.7.tar.gz", hash = "sha256:6a80f7b98f24d9083ed24608977c09dd608d83f91cccc24c9d2cba6d10e01c10"}, +] + [[package]] name = "networkx" version = "3.1" @@ -724,6 +1082,21 @@ files = [ {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] +[[package]] +name = "parso" +version = "0.8.3" +description = "A Python Parser" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, + {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, +] + +[package.extras] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["docopt", "pytest (<6.0.0)"] + [[package]] name = "pathspec" version = "0.11.2" @@ -769,6 +1142,31 @@ files = [ python-dateutil = ">=2.6,<3.0" pytzdata = ">=2020.1" +[[package]] +name = "pexpect" +version = "4.8.0" +description = "Pexpect allows easy control of interactive console applications." 
+optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, + {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + +[[package]] +name = "pickleshare" +version = "0.7.5" +description = "Tiny 'shelve'-like database with concurrency support" +optional = false +python-versions = "*" +files = [ + {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, + {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, +] + [[package]] name = "platformdirs" version = "3.10.0" @@ -795,6 +1193,96 @@ files = [ {file = "pprintpp-0.4.0.tar.gz", hash = "sha256:ea826108e2c7f49dc6d66c752973c3fc9749142a798d6b254e1e301cfdbc6403"}, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + +[[package]] +name = "psutil" +version = "5.9.5" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"}, + {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"}, + {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"}, + {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, + {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, + {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, + {file = 
"psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, + {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +description = "Safely evaluate AST nodes without side effects" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + +[package.extras] +tests = ["pytest"] + +[[package]] +name = "pycparser" +version = "2.21" +description = "C parser in Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, + {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, +] + +[[package]] +name = "pygments" +version = "2.16.1" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"}, + {file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"}, +] + +[package.extras] +plugins = ["importlib-metadata"] + [[package]] name = "pylint" version = "2.17.5" @@ -848,6 +1336,134 @@ files = [ {file = "pytzdata-2020.1.tar.gz", hash = "sha256:3efa13b335a00a8de1d345ae41ec78dd11c9f8807f522d39850f2dd828681540"}, ] +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = 
"pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + +[[package]] +name = "pyzmq" +version = "25.1.1" +description = "Python bindings for 0MQ" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, + {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"}, + {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"}, + {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"}, + {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"}, + {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"}, + {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"}, + {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"}, + {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"}, + {file = 
"pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"}, + {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"}, + {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"}, + {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"}, + {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"}, + {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"}, + {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"}, + {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"}, + {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"}, + {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"}, + {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"}, + {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"}, + {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"}, + {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"}, + {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"}, + {file = 
"pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"}, + {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"}, + {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"}, + {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"}, + {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"}, + {file = "pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"}, + {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"}, + {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"}, + {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"}, + {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"}, + {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"}, + {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"}, + {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"}, + {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"}, + {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"}, + {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"}, + 
{file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"}, + {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"}, + {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"}, + {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"}, + {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"}, + {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"}, + {file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"}, + {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"}, + {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"}, +] + +[package.dependencies] +cffi = {version = "*", markers = "implementation_name == \"pypy\""} + [[package]] name = "ruff" version = "0.0.285" @@ -901,6 +1517,25 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "stack-data" +version = "0.6.2" +description = "Extract data from python stack frames and tracebacks for informative displays" +optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.2-py3-none-any.whl", hash = "sha256:cbb2a53eb64e5785878201a97ed7c7b94883f48b87bfb0bbe8b623c74679e4a8"}, + {file = "stack_data-0.6.2.tar.gz", hash = "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + [[package]] name = "sympy" version = "1.12" @@ -1101,6 +1736,26 @@ files = [ [package.dependencies] torch = "2.0.0" +[[package]] +name = "tornado" +version = "6.3.3" +description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+optional = false +python-versions = ">= 3.8" +files = [ + {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:502fba735c84450974fec147340016ad928d29f1e91f49be168c0a4c18181e1d"}, + {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:805d507b1f588320c26f7f097108eb4023bbaa984d63176d1652e184ba24270a"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bd19ca6c16882e4d37368e0152f99c099bad93e0950ce55e71daed74045908f"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ac51f42808cca9b3613f51ffe2a965c8525cb1b00b7b2d56828b8045354f76a"}, + {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71a8db65160a3c55d61839b7302a9a400074c9c753040455494e2af74e2501f2"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ceb917a50cd35882b57600709dd5421a418c29ddc852da8bcdab1f0db33406b0"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:7d01abc57ea0dbb51ddfed477dfe22719d376119844e33c661d873bf9c0e4a16"}, + {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:9dc4444c0defcd3929d5c1eb5706cbe1b116e762ff3e0deca8b715d14bf6ec17"}, + {file = "tornado-6.3.3-cp38-abi3-win32.whl", hash = "sha256:65ceca9500383fbdf33a98c0087cb975b2ef3bfb874cb35b8de8740cf7f41bd3"}, + {file = "tornado-6.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:22d3c2fa10b5793da13c807e6fc38ff49a4f6e1e3868b0a6f4164768bb8e20f5"}, + {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"}, +] + [[package]] name = "tqdm" version = "4.66.1" @@ -1121,6 +1776,21 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "traitlets" +version = "5.9.0" +description = "Traitlets Python configuration system" +optional = false +python-versions = ">=3.7" +files = [ + {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"}, + {file = "traitlets-5.9.0.tar.gz", hash = "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] + [[package]] name = "triton" version = "2.0.0" @@ -1180,6 +1850,17 @@ files = [ {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] +[[package]] +name = "wcwidth" +version = "0.2.6" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"}, + {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, +] + [[package]] name = "wheel" version = "0.41.1" @@ -1281,4 +1962,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "992e0eb975f6fb490726bc9db470a682986c1a287078bec407a4a6dc96e45b3c" +content-hash = "48badc3ebe08af723b455abd9067ab0e72d990ffaf686509dd3fc04529dfd508" diff --git a/pyproject.toml b/pyproject.toml index dc136f2..2e69211 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ mypy = "^1.5.1" pylint = 
"^2.17.5" ruff = "^0.0.285" types-tqdm = "^4.66.0.1" +ipykernel = "^6.25.1" [tool.ruff] select = ["E", "F", "B", "I"] @@ -33,7 +34,7 @@ line-length = 120 target-version = "py310" [tool.black] -line-length = 120 +line-length = 100 [tool.poetry.scripts] train = "swr2_asr.train:run_cli" diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py index a6b0010..16bd54b 100644 --- a/swr2_asr/inference_test.py +++ b/swr2_asr/inference_test.py @@ -44,9 +44,7 @@ def main() -> None: print(model.__class__) # only do all things for one single sample - dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=True - ) + dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True) print(dataset[0]) diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py index c49cc15..ef37b0a 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/loss_scores.py @@ -54,9 +54,7 @@ def _levenshtein_distance(ref, hyp): return distance[len_ref % 2][len_hyp] -def word_errors( - reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " -): +def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "): """Compute the levenshtein distance between reference sequence and hypothesis sequence in word-level. :param reference: The reference sentence. @@ -176,9 +174,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): :rtype: float :raises ValueError: If the reference length is zero. """ - edit_distance, ref_len = char_errors( - reference, hypothesis, ignore_case, remove_space - ) + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) if ref_len == 0: raise ValueError("Length of reference should be greater than 0.") diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index f00ebd4..dd07ff9 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,3 +1,4 @@ +"""Main definition of model""" import torch.nn.functional as F from torch import nn @@ -30,9 +31,7 @@ class ResidualCNN(nn.Module): ): super().__init__() - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) + self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) self.cnn2 = nn.Conv2d( out_channels, out_channels, @@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module): # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) + ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) for _ in range(n_cnn_layers) ] ) @@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module): data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) + data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) data = data.transpose(1, 2) # (batch, time, feature) data = self.fully_connected(data) data = self.birnn_layers(data) diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 8e3bf09..c8d3793 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -14,16 +14,24 @@ from tqdm import tqdm class TokenizerType: + """Base class for tokenizers. 
+ + exposes the same interface as tokenizers from the huggingface library""" + def encode(self, sequence: str) -> list[int]: + """Encode a sequence to a list of integer labels""" raise NotImplementedError def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + """Decode a list of integer labels to a sequence""" raise NotImplementedError def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Decode a batch of integer labels to a list of sequences""" raise NotImplementedError def get_vocab_size(self) -> int: + """Get the size of the vocabulary""" raise NotImplementedError def enable_padding( @@ -34,17 +42,20 @@ class TokenizerType: pad_type_id: int = 0, pad_token: str = "[PAD]", ) -> None: + """Enable padding for the tokenizer""" raise NotImplementedError def save(self, path: str) -> None: + """Save the tokenizer to a file""" raise NotImplementedError @staticmethod def from_file(path: str) -> "TokenizerType": + """Load the tokenizer from a file""" raise NotImplementedError -tokenizer_type = Type[TokenizerType] +MyTokenizerType = Type[TokenizerType] @dataclass @@ -52,6 +63,7 @@ class Encoding: """Simple dataclass to represent an encoding""" ids: list[int] + tokens: list[str] class CharTokenizer(TokenizerType): @@ -137,7 +149,7 @@ class CharTokenizer(TokenizerType): else: mapped_char = self.char_map[char] int_sequence.append(mapped_char) - return Encoding(ids=int_sequence) + return Encoding(ids=int_sequence, tokens=list(sequence)) def decode(self, labels: list[int], remove_special_tokens: bool = True): """Use a character map and convert integer labels to an text sequence @@ -256,6 +268,7 @@ def train_bpe_tokenizer( bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) initial_alphabet = [ + " ", "a", "b", "c", diff --git a/swr2_asr/train.py b/swr2_asr/train.py index f3efd69..95038c2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,19 +5,19 @@ from typing import TypedDict import click import torch import torch.nn.functional as F -from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader from tqdm import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import train_bpe_tokenizer +from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer # TODO: improve naming of functions + class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -33,10 +33,11 @@ class HParams(TypedDict): epochs: int -# TODO: get blank label from tokenizer -def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): +def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): """Greedily decode a sequence.""" + print("output shape", output.shape) arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = tokenizer.encode(" ").ids[0] decodes = [] targets = [] for i, args in enumerate(arg_maxes): @@ -81,28 +82,27 @@ def train( print(f"Epoch: {epoch}") losses = [] for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, 
_data['input_length'], _data["utterance_length"]) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) loss.backward() optimizer.step() scheduler.step() iter_meter.step() - + losses.append(loss.item()) print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") return sum(losses) / len(losses) -# TODO: profile this function call -# TODO: only calculate wer and cer at the end, or less often -# TODO: change this to only be a sanity check and calculate measures after training + + def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") @@ -111,21 +111,21 @@ def test(model, device, test_loader, criterion, tokenizer): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - # TODO: get rid of this - loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output = output.transpose(0, 1), - labels = labels, - label_lengths= _data["utterance_length"], - tokenizer=tokenizer) + output=output.transpose(0, 1), + labels=labels, + label_lengths=_data["utterance_length"], + tokenizer=tokenizer, + ) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -135,9 +135,11 @@ def test(model, device, test_loader, criterion, tokenizer): print( f"Test set: Average loss:\ - {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + {test_loss}, Average CER: {None} Average WER: {None}\n" ) + return test_loss, avg_cer, avg_wer + def run( learning_rate: float, @@ -152,33 +154,34 @@ def run( use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - device = torch.device("mps") + # device = torch.device("mps") # load dataset - # TODO: change this from dev split to train split again (was faster for development) - train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) - valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) - test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) + train_dataset = MLSDataset( + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None + ) + valid_dataset = MLSDataset( + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + ) - # load tokenizer (bpe by default): - if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): print("There is no tokenizer available. 
Do you want to train it on the dataset?") input("Press Enter to continue...") - train_bpe_tokenizer( + train_char_tokenizer( dataset_path=dataset_path, language=language, split="all", download=False, - out_path="data/tokenizers/bpe_tokenizer_german_3000.json", + out_path="data/tokenizers/char_tokenizer_german.json", vocab_size=3000, ) - - tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - - train_dataset.set_tokenizer(tokenizer) - valid_dataset.set_tokenizer(tokenizer) - test_dataset.set_tokenizer(tokenizer) - + + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + + train_dataset.set_tokenizer(tokenizer) # type: ignore + valid_dataset.set_tokenizer(tokenizer) # type: ignore + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") hparams = HParams( @@ -221,10 +224,10 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - + print(tokenizer.encode(" ")) print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(blank=28).to(device) + criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) if load: checkpoint = torch.load(path) model.load_state_dict(checkpoint["model_state_dict"]) @@ -240,7 +243,7 @@ def run( ) iter_meter = IterMeter() - for epoch in range(1, epochs + 1): + for epoch in range(1, epochs + 1): loss = train( model, device, @@ -252,7 +255,13 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer) + test( + model=model, + device=device, + test_loader=valid_loader, + criterion=criterion, + tokenizer=tokenizer, + ) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -295,5 +304,6 @@ def run_cli( language="mls_german_opus", ) -if __name__ == "__main__": - run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file + +if __name__ == "__main__": + run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 3b9b3ca..efecb56 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -8,7 +8,6 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset -from tqdm import tqdm from swr2_asr.tokenizer import TokenizerType @@ -24,26 +23,26 @@ class MLSSplit(str, Enum): """Enum specifying dataset as they are defined in the Multilingual LibriSpeech dataset""" - train = "train" - test = "test" - dev = "dev" + TRAIN = "train" + TEST = "test" + DEV = "dev" class Split(str, Enum): """Extending the MLSSplit class to allow for a custom validatio split""" - train = "train" - valid = "valid" - test = "test" - dev = "dev" + TRAIN = "train" + VALID = "valid" + TEST = "test" + DEV = "dev" -def split_to_mls_split(split: Split) -> MLSSplit: +def split_to_mls_split(split_name: Split) -> MLSSplit: """Converts the custom split to a MLSSplit""" - if split == Split.valid: - return MLSSplit.train + if split_name == Split.VALID: + return MLSSplit.TRAIN else: - return split # type: ignore + return split_name # type: ignore class Sample(TypedDict): @@ -89,14 +88,21 @@ class MLSDataset(Dataset): __ """ - def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None): + def __init__( + self, + dataset_path: str, + language: str, + split: Split, + 
download: bool, + spectrogram_hparams: dict | None, + ): """Initializes the dataset""" self.dataset_path = dataset_path self.language = language self.file_ext = ".opus" if "opus" in language else ".flac" self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally - + if spectrogram_hparams is None: self.spectrogram_hparams = { "sample_rate": 16000, @@ -110,7 +116,7 @@ class MLSDataset(Dataset): } else: self.spectrogram_hparams = spectrogram_hparams - + self.dataset_lookup = [] self.tokenizer: type[TokenizerType] @@ -118,13 +124,14 @@ class MLSDataset(Dataset): self._validate_local_directory() self.initialize() - def initialize(self) -> None: """Initializes the dataset - + Reads the transcripts.txt file and creates a lookup table """ - transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt") + transcripts_path = os.path.join( + self.dataset_path, self.language, self.mls_split, "transcripts.txt" + ) with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt @@ -135,12 +142,12 @@ class MLSDataset(Dataset): identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore identifier = [path.split("_") for path in identifier] - if self.split == Split.valid: + if self.split == Split.VALID: np.random.seed(42) indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) utterances = [utterances[i] for i in indices] identifier = [identifier[i] for i in indices] - elif self.split == Split.train: + elif self.split == Split.TRAIN: np.random.seed(42) indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) utterances = [utterances[i] for i in indices] @@ -212,13 +219,19 @@ class MLSDataset(Dataset): # resample if necessary if sample_rate != self.spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"]) + resampler = torchaudio.transforms.Resample( + sample_rate, self.spectrogram_hparams["sample_rate"] + ) waveform = resampler(waveform) - spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1) + spec = ( + torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform) + .squeeze(0) + .transpose(0, 1) + ) input_length = spec.shape[0] // 2 - + utterance_length = len(utterance) utterance = self.tokenizer.encode(utterance) @@ -236,11 +249,11 @@ class MLSDataset(Dataset): book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) - + def collate_fn(samples: list[Sample]) -> dict: """Collate function for the dataloader - + pads all tensors within a batch to the same dimensions """ waveforms = [] @@ -248,18 +261,20 @@ def collate_fn(samples: list[Sample]) -> dict: labels = [] input_lengths = [] label_lengths = [] - + for sample in samples: waveforms.append(sample["waveform"].transpose(0, 1)) spectrograms.append(sample["spectrogram"]) labels.append(sample["utterance"]) input_lengths.append(sample["spectrogram"].shape[0] // 2) label_lengths.append(len(sample["utterance"])) - + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3) + spectrograms = ( + torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) + ) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) - + 
return { "waveform": waveforms, "spectrogram": spectrograms, @@ -267,15 +282,15 @@ def collate_fn(samples: list[Sample]) -> dict: "utterance": labels, "utterance_length": label_lengths, } - + if __name__ == "__main__": - dataset_path = "/Volumes/pherkel/SWR2-ASR" - language = "mls_german_opus" - split = Split.train - download = False + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" + LANGUAGE = "mls_german_opus" + split = Split.TRAIN + DOWNLOAD = False - dataset = MLSDataset(dataset_path, language, split, download, None) + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) -- cgit v1.2.3 From acafe88a1a360832b727651b713806ce0404db3f Mon Sep 17 00:00:00 2001 From: Pherkel Date: Sun, 3 Sep 2023 20:52:15 +0200 Subject: fix not unzipping dataset --- swr2_asr/utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index efecb56..0330cd2 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -8,6 +8,7 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from torchaudio.datasets.utils import _extract_tar as extract_archive from swr2_asr.tokenizer import TokenizerType @@ -169,20 +170,27 @@ class MLSDataset(Dataset): def _handle_download_dataset(self, download: bool): """Download the dataset""" - if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: + if not download: + return + # zip exists: + if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: + print(f"Found dataset at {self.dataset_path}. Skipping download") + # zip does not exist: + else: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: - raise ValueError("Dataset not found. 
Set download to True to download it") + + # unzip the dataset + extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz") def _validate_local_directory(self): # check if dataset_path exists if not os.path.exists(self.dataset_path): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): - raise ValueError("Language not found in dataset") + raise ValueError("Language not downloaded!") if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") -- cgit v1.2.3 From cd15a49ccee83c21ada481d6815d004f134147fe Mon Sep 17 00:00:00 2001 From: Philipp Merkel Date: Mon, 4 Sep 2023 14:07:54 +0000 Subject: applied fixes to download and tokenizers --- swr2_asr/tokenizer.py | 28 ++++++++++++++++++++++++++++ swr2_asr/train.py | 3 +-- swr2_asr/utils.py | 14 ++++++++++---- 3 files changed, 39 insertions(+), 6 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index f02d4f5..64227a4 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -187,6 +187,7 @@ class CharTokenizer(TokenizerType): def save(self, path: str): """Save the tokenizer to a file""" + os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w", encoding="utf-8") as file: # save it in the following format: # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} @@ -217,6 +218,23 @@ class CharTokenizer(TokenizerType): @click.option("--download", default=True, help="Whether to download the dataset") @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") +def train_bpe_tokenizer_cli( + dataset_path: str, + language: str, + split: str, + out_path: str, + download: bool, + vocab_size: int, +): + train_bpe_tokenizer( + dataset_path, + language, + split, + out_path, + download, + vocab_size, +) + def train_bpe_tokenizer( dataset_path: str, language: str, @@ -251,6 +269,7 @@ def train_bpe_tokenizer( for s_plit in splits: transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") if download and not os.path.exists(transcripts_path): + # TODO: move to own dataset MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) with open( @@ -337,6 +356,15 @@ def train_bpe_tokenizer( @click.option("--split", default="train", help="Split to use") @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") @click.option("--download", default=True, help="Whether to download the dataset") +def train_char_tokenizer_cli( + dataset_path: str, + language: str, + split: str, + out_path: str, + download: bool, + ): + train_char_tokenizer(dataset_path, language, split, out_path, download) + def train_char_tokenizer( dataset_path: str, language: str, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 8fc0b78..aea99e0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -173,7 +173,6 @@ def run( split="all", download=False, out_path="data/tokenizers/char_tokenizer_german.json", - vocab_size=3000, ) tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") @@ -305,4 +304,4 @@ def run_cli( if __name__ == "__main__": - run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") + run_cli() diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 0330cd2..1c755a6 100644 --- a/swr2_asr/utils.py +++ 
b/swr2_asr/utils.py @@ -168,22 +168,28 @@ class MLSDataset(Dataset): """Sets the tokenizer""" self.tokenizer = tokenizer - def _handle_download_dataset(self, download: bool): + def _handle_download_dataset(self, download: bool) -> None: """Download the dataset""" if not download: + print("Download flag not set, skipping download") return # zip exists: if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: print(f"Found dataset at {self.dataset_path}. Skipping download") # zip does not exist: else: - os.makedirs(self.dataset_path) + os.makedirs(self.dataset_path, exist_ok=True) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - torch.hub.download_url_to_file(url, self.dataset_path) + torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz") # unzip the dataset - extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz") + if not os.path.isdir(os.path.join(self.dataset_path, self.language)): + print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}") + extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + else: + print("Dataset is already unzipped, validating it now") + return def _validate_local_directory(self): # check if dataset_path exists -- cgit v1.2.3 From 0d70a19e1fea6eda3f7b16ad0084591613f2de72 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 4 Sep 2023 22:48:40 +0200 Subject: please the linter --- swr2_asr/tokenizer.py | 45 ++++++++++++++++++++------------------------- swr2_asr/train.py | 2 +- swr2_asr/utils.py | 16 +++++++++++----- 3 files changed, 32 insertions(+), 31 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 64227a4..e4df93b 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -90,7 +90,7 @@ class CharTokenizer(TokenizerType): self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train(self, dataset_path: str, language: str, split: str, download: bool = True): + def train(self, dataset_path: str, language: str, split: str): """Train the tokenizer on the given dataset Args: @@ -110,10 +110,6 @@ class CharTokenizer(TokenizerType): for s_plit in splits: transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - # check if dataset is downloaded, download if not - if download and not os.path.exists(transcript_path): - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) - with open( transcript_path, "r", @@ -215,7 +211,6 @@ class CharTokenizer(TokenizerType): @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") -@click.option("--download", default=True, help="Whether to download the dataset") @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") def train_bpe_tokenizer_cli( @@ -223,24 +218,23 @@ def train_bpe_tokenizer_cli( language: str, split: str, out_path: str, - download: bool, vocab_size: int, ): - train_bpe_tokenizer( - dataset_path, - language, - split, - out_path, - download, - vocab_size, -) + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_bpe_tokenizer( + dataset_path, + language, + split, + out_path, + 
vocab_size, + ) + def train_bpe_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, vocab_size: int, ): """Train a Byte-Pair Encoder tokenizer on the MLS dataset @@ -268,9 +262,11 @@ def train_bpe_tokenizer( for s_plit in splits: transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - if download and not os.path.exists(transcripts_path): - # TODO: move to own dataset - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) + if not os.path.exists(transcripts_path): + raise FileNotFoundError( + f"Could not find transcripts.txt in {transcripts_path}. " + "Please make sure that the dataset is downloaded." + ) with open( transcripts_path, @@ -355,22 +351,21 @@ def train_bpe_tokenizer( @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") -@click.option("--download", default=True, help="Whether to download the dataset") def train_char_tokenizer_cli( dataset_path: str, language: str, split: str, out_path: str, - download: bool, - ): - train_char_tokenizer(dataset_path, language, split, out_path, download) +): + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_char_tokenizer(dataset_path, language, split, out_path) + def train_char_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, ): """Train a Byte-Pair Encoder tokenizer on the MLS dataset @@ -386,7 +381,7 @@ def train_char_tokenizer( """ char_tokenizer = CharTokenizer() - char_tokenizer.train(dataset_path, language, split, download) + char_tokenizer.train(dataset_path, language, split) char_tokenizer.save(out_path) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index aea99e0..63deb72 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -304,4 +304,4 @@ def run_cli( if __name__ == "__main__": - run_cli() + run_cli() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 1c755a6..8a950ab 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -181,15 +181,21 @@ class MLSDataset(Dataset): os.makedirs(self.dataset_path, exist_ok=True) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz") + torch.hub.download_url_to_file( + url, os.path.join(self.dataset_path, self.language) + ".tar.gz" + ) # unzip the dataset if not os.path.isdir(os.path.join(self.dataset_path, self.language)): - print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}") - extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + print( + f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + ) + extract_archive( + os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True + ) else: - print("Dataset is already unzipped, validating it now") - return + print("Dataset is already unzipped, validating it now") + return def _validate_local_directory(self): # check if dataset_path exists -- cgit v1.2.3