aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--swr2_asr/model_deep_speech.py2
-rw-r--r--swr2_asr/train.py49
-rw-r--r--swr2_asr/utils.py164
3 files changed, 95 insertions, 120 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index ea0b667..f00ebd4 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,5 +1,5 @@
-from torch import nn
import torch.nn.functional as F
+from torch import nn
class CNNLayerNorm(nn.Module):
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 53cdac1..bae8c7c 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,13 +5,14 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+from AudioLoader.speech import MultilingualLibriSpeech
from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.tokenizer import train_bpe_tokenizer
-from swr2_asr.utils import MLSDataset, Split
+from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
@@ -31,20 +32,20 @@ class HParams(TypedDict):
epochs: int
-def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
+def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
+ targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.decode(decode))
+ decodes.append(tokenizer.decode(decode))
return decodes, targets
@@ -77,15 +78,14 @@ def train(
model.train()
data_len = len(train_loader.dataset)
for batch_idx, _data in enumerate(train_loader):
- _, spectrograms, input_lengths, labels, label_lengths, *_ = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
loss.backward()
optimizer.step()
@@ -102,7 +102,7 @@ def train(
return loss.item()
-def test(model, device, test_loader, criterion):
+def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
model.eval()
@@ -110,17 +110,20 @@ def test(model, device, test_loader, criterion):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths)
+ decoded_preds, decoded_targets = greedy_decoder(
+ output = output.transpose(0, 1),
+ labels = labels,
+ label_lengths= _data["utterance_length"],
+ tokenizer=tokenizer)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -150,12 +153,11 @@ def run(
# device = torch.device("mps")
# load dataset
- train_dataset = MLSDataset(dataset_path, language, Split.train, download=True)
- valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True)
- test_dataset = MLSDataset(dataset_path, language, Split.test, download=True)
+ train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None)
+ valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None)
+ test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
- # TODO: add flag to choose tokenizer
- # load tokenizer (bpe by default):
+ # load tokenizer (bpe by default):
if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
print("There is no tokenizer available. Do you want to train it on the dataset?")
input("Press Enter to continue...")
@@ -167,12 +169,14 @@ def run(
out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
vocab_size=3000,
)
-
+
tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-
+
train_dataset.set_tokenizer(tokenizer)
valid_dataset.set_tokenizer(tokenizer)
test_dataset.set_tokenizer(tokenizer)
+
+ print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
hparams = HParams(
n_cnn_layers=3,
@@ -191,12 +195,14 @@ def run(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ collate_fn=lambda x: collate_fn(x),
)
valid_loader = DataLoader(
valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ collate_fn=lambda x: collate_fn(x),
)
# enable flag to find the most compatible algorithms in advance
@@ -243,7 +249,7 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=valid_loader, criterion=criterion)
+ test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer)
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -285,3 +291,6 @@ def run_cli(
dataset_path=dataset_path,
language="mls_german_opus",
)
+
+if __name__ == "__main__":
+ run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 404661d..4c751d5 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -1,7 +1,6 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
-from multiprocessing import Pool
from typing import TypedDict
import numpy as np
@@ -10,9 +9,8 @@ import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
-import audio_metadata
-from swr2_asr.tokenizer import CharTokenizer, TokenizerType
+from swr2_asr.tokenizer import TokenizerType
train_audio_transforms = torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -91,20 +89,42 @@ class MLSDataset(Dataset):
<speakerid>_<bookid>_<chapterid> <utterance>
"""
- def __init__(self, dataset_path: str, language: str, split: Split, download: bool):
+ def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None):
"""Initializes the dataset"""
self.dataset_path = dataset_path
self.language = language
self.file_ext = ".opus" if "opus" in language else ".flac"
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
+
+ if spectrogram_hparams is None:
+ self.spectrogram_hparams = {
+ "sample_rate": 16000,
+ "n_fft": 400,
+ "win_length": 400,
+ "hop_length": 160,
+ "n_mels": 128,
+ "f_min": 0,
+ "f_max": 8000,
+ "power": 2.0,
+ }
+ else:
+ self.spectrogram_hparams = spectrogram_hparams
+
self.dataset_lookup = []
self.tokenizer: type[TokenizerType]
self._handle_download_dataset(download)
self._validate_local_directory()
+ self.initialize()
- transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt")
+
+ def initialize(self) -> None:
+ """Initializes the dataset
+
+ Reads the transcripts.txt file and creates a lookup table
+ """
+ transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt")
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
@@ -136,13 +156,9 @@ class MLSDataset(Dataset):
for path, utterance in zip(identifier, utterances, strict=False)
]
- self.max_spec_length = 0
- self.max_utterance_length = 0
-
def set_tokenizer(self, tokenizer: type[TokenizerType]):
"""Sets the tokenizer"""
self.tokenizer = tokenizer
- # self.calc_paddings()
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
@@ -163,80 +179,14 @@ class MLSDataset(Dataset):
if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")
- def _calculate_max_length(self, chunk):
- """Calculates the maximum length of the spectrogram and the utterance
-
- to be called in a multiprocessing pool
- """
- max_spec_length = 0
- max_utterance_length = 0
-
- for sample in chunk:
- audio_path = os.path.join(
- self.dataset_path,
- self.language,
- self.mls_split,
- "audio",
- sample["speakerid"],
- sample["bookid"],
- "_".join(
- [
- sample["speakerid"],
- sample["bookid"],
- sample["chapterid"],
- ]
- )
- + self.file_ext,
- )
- metadata = audio_metadata.load(audio_path)
- audio_duration = metadata.streaminfo.duration
- sample_rate = metadata.streaminfo.sample_rate
-
- max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200))
- max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids))
-
- return max_spec_length, max_utterance_length
-
- def calc_paddings(self) -> None:
- """Sets the maximum length of the spectrogram and the utterance"""
- # check if dataset has been loaded and tokenizer has been set
- if not self.dataset_lookup:
- raise ValueError("Dataset not loaded")
- if not self.tokenizer:
- raise ValueError("Tokenizer not set")
- # check if paddings have been calculated already
- if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")):
- print("Paddings already calculated")
- with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f:
- self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()]
- return
- else:
- print("Calculating paddings...")
-
- thread_count = os.cpu_count()
- if thread_count is None:
- thread_count = 4
- chunk_size = len(self.dataset_lookup) // thread_count
- chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)]
-
- with Pool(thread_count) as p:
- results = list(p.imap(self._calculate_max_length, chunks))
-
- for spec, utterance in results:
- self.max_spec_length = max(self.max_spec_length, spec)
- self.max_utterance_length = max(self.max_utterance_length, utterance)
-
- # write to file
- with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f:
- f.write(f"{self.max_spec_length}\n")
- f.write(f"{self.max_utterance_length}")
-
def __len__(self):
"""Returns the length of the dataset"""
return len(self.dataset_lookup)
def __getitem__(self, idx: int) -> Sample:
"""One sample"""
+ if self.tokenizer is None:
+ raise ValueError("No tokenizer set")
# get the utterance
utterance = self.dataset_lookup[idx]["utterance"]
@@ -261,42 +211,62 @@ class MLSDataset(Dataset):
waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
# resample if necessary
- if sample_rate != 16000:
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+ if sample_rate != self.spectrogram_hparams["sample_rate"]:
+ resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"])
waveform = resampler(waveform)
- sample_rate = 16000
- spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1)
+ spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1)
input_length = spec.shape[0] // 2
+
utterance_length = len(utterance)
- self.tokenizer.enable_padding()
- utterance = self.tokenizer.encode(
- utterance,
- ).ids
+ utterance = self.tokenizer.encode(utterance)
- utterance = torch.Tensor(utterance)
+ utterance = torch.LongTensor(utterance.ids)
return Sample(
- # TODO: add flag to only return spectrogram or waveform or both
waveform=waveform,
spectrogram=spec,
input_length=input_length,
utterance=utterance,
utterance_length=utterance_length,
- sample_rate=sample_rate,
+ sample_rate=self.spectrogram_hparams["sample_rate"],
speaker_id=self.dataset_lookup[idx]["speakerid"],
book_id=self.dataset_lookup[idx]["bookid"],
chapter_id=self.dataset_lookup[idx]["chapterid"],
)
- def download(self, dataset_path: str, language: str):
- """Download the dataset"""
- os.makedirs(dataset_path)
- url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz"
-
- torch.hub.download_url_to_file(url, dataset_path)
+def collate_fn(samples: list[Sample]) -> dict:
+ """Collate function for the dataloader
+
+ pads all tensors within a batch to the same dimensions
+ """
+ waveforms = []
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+
+ for sample in samples:
+ waveforms.append(sample["waveform"].transpose(0, 1))
+ spectrograms.append(sample["spectrogram"])
+ labels.append(sample["utterance"])
+ input_lengths.append(sample["spectrogram"].shape[0] // 2)
+ label_lengths.append(len(sample["utterance"]))
+
+ waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
+ spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3)
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return {
+ "waveform": waveforms,
+ "spectrogram": spectrograms,
+ "input_length": input_lengths,
+ "utterance": labels,
+ "utterance_length": label_lengths,
+ }
+
if __name__ == "__main__":
@@ -305,11 +275,7 @@ if __name__ == "__main__":
split = Split.train
download = False
- dataset = MLSDataset(dataset_path, language, split, download)
+ dataset = MLSDataset(dataset_path, language, split, download, None)
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)
- dataset.calc_paddings()
-
- print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}")
- print(f"Utterance shape: {dataset[41]['utterance'].shape}")