diff options
author | Pherkel | 2023-09-03 19:30:33 +0200 |
---|---|---|
committer | Pherkel | 2023-09-03 19:30:33 +0200 |
commit | 33f09080aee10bddb4797a557d676ee1f7b8de31 (patch) | |
tree | 1a35c13c2c91f84f542fe9ed5552bcbe60b437c3 /swr2_asr | |
parent | f3d2ea9a16944434a08e662c5ecfd6ba50e5ea89 (diff) |
idk, hopefully this works
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/inference_test.py | 4 | ||||
-rw-r--r-- | swr2_asr/loss_scores.py | 8 | ||||
-rw-r--r-- | swr2_asr/model_deep_speech.py | 13 | ||||
-rw-r--r-- | swr2_asr/tokenizer.py | 17 | ||||
-rw-r--r-- | swr2_asr/train.py | 90 | ||||
-rw-r--r-- | swr2_asr/utils.py | 85 |
6 files changed, 122 insertions, 95 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py index a6b0010..16bd54b 100644 --- a/swr2_asr/inference_test.py +++ b/swr2_asr/inference_test.py @@ -44,9 +44,7 @@ def main() -> None: print(model.__class__) # only do all things for one single sample - dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=True - ) + dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True) print(dataset[0]) diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py index c49cc15..ef37b0a 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/loss_scores.py @@ -54,9 +54,7 @@ def _levenshtein_distance(ref, hyp): return distance[len_ref % 2][len_hyp] -def word_errors( - reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " -): +def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "): """Compute the levenshtein distance between reference sequence and hypothesis sequence in word-level. :param reference: The reference sentence. @@ -176,9 +174,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): :rtype: float :raises ValueError: If the reference length is zero. """ - edit_distance, ref_len = char_errors( - reference, hypothesis, ignore_case, remove_space - ) + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) if ref_len == 0: raise ValueError("Length of reference should be greater than 0.") diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index f00ebd4..dd07ff9 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,3 +1,4 @@ +"""Main definition of model""" import torch.nn.functional as F from torch import nn @@ -30,9 +31,7 @@ class ResidualCNN(nn.Module): ): super().__init__() - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) + self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) self.cnn2 = nn.Conv2d( out_channels, out_channels, @@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module): # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) + ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) for _ in range(n_cnn_layers) ] ) @@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module): data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) + data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) data = data.transpose(1, 2) # (batch, time, feature) data = self.fully_connected(data) data = self.birnn_layers(data) diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 8e3bf09..c8d3793 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -14,16 +14,24 @@ from tqdm import tqdm class TokenizerType: + """Base class for tokenizers. + + exposes the same interface as tokenizers from the huggingface library""" + def encode(self, sequence: str) -> list[int]: + """Encode a sequence to a list of integer labels""" raise NotImplementedError def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + """Decode a list of integer labels to a sequence""" raise NotImplementedError def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Decode a batch of integer labels to a list of sequences""" raise NotImplementedError def get_vocab_size(self) -> int: + """Get the size of the vocabulary""" raise NotImplementedError def enable_padding( @@ -34,17 +42,20 @@ class TokenizerType: pad_type_id: int = 0, pad_token: str = "[PAD]", ) -> None: + """Enable padding for the tokenizer""" raise NotImplementedError def save(self, path: str) -> None: + """Save the tokenizer to a file""" raise NotImplementedError @staticmethod def from_file(path: str) -> "TokenizerType": + """Load the tokenizer from a file""" raise NotImplementedError -tokenizer_type = Type[TokenizerType] +MyTokenizerType = Type[TokenizerType] @dataclass @@ -52,6 +63,7 @@ class Encoding: """Simple dataclass to represent an encoding""" ids: list[int] + tokens: list[str] class CharTokenizer(TokenizerType): @@ -137,7 +149,7 @@ class CharTokenizer(TokenizerType): else: mapped_char = self.char_map[char] int_sequence.append(mapped_char) - return Encoding(ids=int_sequence) + return Encoding(ids=int_sequence, tokens=list(sequence)) def decode(self, labels: list[int], remove_special_tokens: bool = True): """Use a character map and convert integer labels to an text sequence @@ -256,6 +268,7 @@ def train_bpe_tokenizer( bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) initial_alphabet = [ + " ", "a", "b", "c", diff --git a/swr2_asr/train.py b/swr2_asr/train.py index f3efd69..95038c2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,19 +5,19 @@ from typing import TypedDict import click import torch import torch.nn.functional as F -from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader from tqdm import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import train_bpe_tokenizer +from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer # TODO: improve naming of functions + class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -33,10 +33,11 @@ class HParams(TypedDict): epochs: int -# TODO: get blank label from tokenizer -def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): +def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): """Greedily decode a sequence.""" + print("output shape", output.shape) arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = tokenizer.encode(" ").ids[0] decodes = [] targets = [] for i, args in enumerate(arg_maxes): @@ -81,28 +82,27 @@ def train( print(f"Epoch: {epoch}") losses = [] for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) loss.backward() optimizer.step() scheduler.step() iter_meter.step() - + losses.append(loss.item()) print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") return sum(losses) / len(losses) -# TODO: profile this function call -# TODO: only calculate wer and cer at the end, or less often -# TODO: change this to only be a sanity check and calculate measures after training + + def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") @@ -111,21 +111,21 @@ def test(model, device, test_loader, criterion, tokenizer): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - # TODO: get rid of this - loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output = output.transpose(0, 1), - labels = labels, - label_lengths= _data["utterance_length"], - tokenizer=tokenizer) + output=output.transpose(0, 1), + labels=labels, + label_lengths=_data["utterance_length"], + tokenizer=tokenizer, + ) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -135,9 +135,11 @@ def test(model, device, test_loader, criterion, tokenizer): print( f"Test set: Average loss:\ - {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + {test_loss}, Average CER: {None} Average WER: {None}\n" ) + return test_loss, avg_cer, avg_wer + def run( learning_rate: float, @@ -152,33 +154,34 @@ def run( use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - device = torch.device("mps") + # device = torch.device("mps") # load dataset - # TODO: change this from dev split to train split again (was faster for development) - train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) - valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) - test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) + train_dataset = MLSDataset( + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None + ) + valid_dataset = MLSDataset( + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + ) - # load tokenizer (bpe by default): - if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): print("There is no tokenizer available. Do you want to train it on the dataset?") input("Press Enter to continue...") - train_bpe_tokenizer( + train_char_tokenizer( dataset_path=dataset_path, language=language, split="all", download=False, - out_path="data/tokenizers/bpe_tokenizer_german_3000.json", + out_path="data/tokenizers/char_tokenizer_german.json", vocab_size=3000, ) - - tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - - train_dataset.set_tokenizer(tokenizer) - valid_dataset.set_tokenizer(tokenizer) - test_dataset.set_tokenizer(tokenizer) - + + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + + train_dataset.set_tokenizer(tokenizer) # type: ignore + valid_dataset.set_tokenizer(tokenizer) # type: ignore + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") hparams = HParams( @@ -221,10 +224,10 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - + print(tokenizer.encode(" ")) print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(blank=28).to(device) + criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) if load: checkpoint = torch.load(path) model.load_state_dict(checkpoint["model_state_dict"]) @@ -240,7 +243,7 @@ def run( ) iter_meter = IterMeter() - for epoch in range(1, epochs + 1): + for epoch in range(1, epochs + 1): loss = train( model, device, @@ -252,7 +255,13 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer) + test( + model=model, + device=device, + test_loader=valid_loader, + criterion=criterion, + tokenizer=tokenizer, + ) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -295,5 +304,6 @@ def run_cli( language="mls_german_opus", ) -if __name__ == "__main__": - run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
\ No newline at end of file + +if __name__ == "__main__": + run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 3b9b3ca..efecb56 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -8,7 +8,6 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset -from tqdm import tqdm from swr2_asr.tokenizer import TokenizerType @@ -24,26 +23,26 @@ class MLSSplit(str, Enum): """Enum specifying dataset as they are defined in the Multilingual LibriSpeech dataset""" - train = "train" - test = "test" - dev = "dev" + TRAIN = "train" + TEST = "test" + DEV = "dev" class Split(str, Enum): """Extending the MLSSplit class to allow for a custom validatio split""" - train = "train" - valid = "valid" - test = "test" - dev = "dev" + TRAIN = "train" + VALID = "valid" + TEST = "test" + DEV = "dev" -def split_to_mls_split(split: Split) -> MLSSplit: +def split_to_mls_split(split_name: Split) -> MLSSplit: """Converts the custom split to a MLSSplit""" - if split == Split.valid: - return MLSSplit.train + if split_name == Split.VALID: + return MLSSplit.TRAIN else: - return split # type: ignore + return split_name # type: ignore class Sample(TypedDict): @@ -89,14 +88,21 @@ class MLSDataset(Dataset): <speakerid>_<bookid>_<chapterid> <utterance> """ - def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None): + def __init__( + self, + dataset_path: str, + language: str, + split: Split, + download: bool, + spectrogram_hparams: dict | None, + ): """Initializes the dataset""" self.dataset_path = dataset_path self.language = language self.file_ext = ".opus" if "opus" in language else ".flac" self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally - + if spectrogram_hparams is None: self.spectrogram_hparams = { "sample_rate": 16000, @@ -110,7 +116,7 @@ class MLSDataset(Dataset): } else: self.spectrogram_hparams = spectrogram_hparams - + self.dataset_lookup = [] self.tokenizer: type[TokenizerType] @@ -118,13 +124,14 @@ class MLSDataset(Dataset): self._validate_local_directory() self.initialize() - def initialize(self) -> None: """Initializes the dataset - + Reads the transcripts.txt file and creates a lookup table """ - transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt") + transcripts_path = os.path.join( + self.dataset_path, self.language, self.mls_split, "transcripts.txt" + ) with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt @@ -135,12 +142,12 @@ class MLSDataset(Dataset): identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore identifier = [path.split("_") for path in identifier] - if self.split == Split.valid: + if self.split == Split.VALID: np.random.seed(42) indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) utterances = [utterances[i] for i in indices] identifier = [identifier[i] for i in indices] - elif self.split == Split.train: + elif self.split == Split.TRAIN: np.random.seed(42) indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) utterances = [utterances[i] for i in indices] @@ -212,13 +219,19 @@ class MLSDataset(Dataset): # resample if necessary if sample_rate != self.spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"]) + resampler = torchaudio.transforms.Resample( + sample_rate, self.spectrogram_hparams["sample_rate"] + ) waveform = resampler(waveform) - spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1) + spec = ( + torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform) + .squeeze(0) + .transpose(0, 1) + ) input_length = spec.shape[0] // 2 - + utterance_length = len(utterance) utterance = self.tokenizer.encode(utterance) @@ -236,11 +249,11 @@ class MLSDataset(Dataset): book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) - + def collate_fn(samples: list[Sample]) -> dict: """Collate function for the dataloader - + pads all tensors within a batch to the same dimensions """ waveforms = [] @@ -248,18 +261,20 @@ def collate_fn(samples: list[Sample]) -> dict: labels = [] input_lengths = [] label_lengths = [] - + for sample in samples: waveforms.append(sample["waveform"].transpose(0, 1)) spectrograms.append(sample["spectrogram"]) labels.append(sample["utterance"]) input_lengths.append(sample["spectrogram"].shape[0] // 2) label_lengths.append(len(sample["utterance"])) - + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) - spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3) + spectrograms = ( + torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) + ) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) - + return { "waveform": waveforms, "spectrogram": spectrograms, @@ -267,15 +282,15 @@ def collate_fn(samples: list[Sample]) -> dict: "utterance": labels, "utterance_length": label_lengths, } - + if __name__ == "__main__": - dataset_path = "/Volumes/pherkel/SWR2-ASR" - language = "mls_german_opus" - split = Split.train - download = False + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" + LANGUAGE = "mls_german_opus" + split = Split.TRAIN + DOWNLOAD = False - dataset = MLSDataset(dataset_path, language, split, download, None) + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) |