aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
authorPherkel2023-09-03 19:30:33 +0200
committerPherkel2023-09-03 19:30:33 +0200
commit33f09080aee10bddb4797a557d676ee1f7b8de31 (patch)
tree1a35c13c2c91f84f542fe9ed5552bcbe60b437c3 /swr2_asr
parentf3d2ea9a16944434a08e662c5ecfd6ba50e5ea89 (diff)
idk, hopefully this works
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/inference_test.py4
-rw-r--r--swr2_asr/loss_scores.py8
-rw-r--r--swr2_asr/model_deep_speech.py13
-rw-r--r--swr2_asr/tokenizer.py17
-rw-r--r--swr2_asr/train.py90
-rw-r--r--swr2_asr/utils.py85
6 files changed, 122 insertions, 95 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index a6b0010..16bd54b 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -44,9 +44,7 @@ def main() -> None:
print(model.__class__)
# only do all things for one single sample
- dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=True
- )
+ dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
print(dataset[0])
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index c49cc15..ef37b0a 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -54,9 +54,7 @@ def _levenshtein_distance(ref, hyp):
return distance[len_ref % 2][len_hyp]
-def word_errors(
- reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
-):
+def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "):
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in word-level.
:param reference: The reference sentence.
@@ -176,9 +174,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
:rtype: float
:raises ValueError: If the reference length is zero.
"""
- edit_distance, ref_len = char_errors(
- reference, hypothesis, ignore_case, remove_space
- )
+ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space)
if ref_len == 0:
raise ValueError("Length of reference should be greater than 0.")
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index f00ebd4..dd07ff9 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,3 +1,4 @@
+"""Main definition of model"""
import torch.nn.functional as F
from torch import nn
@@ -30,9 +31,7 @@ class ResidualCNN(nn.Module):
):
super().__init__()
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
+ self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
self.cnn2 = nn.Conv2d(
out_channels,
out_channels,
@@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module):
# n residual cnn layers with filter size of 32
self.rescnn_layers = nn.Sequential(
*[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
+ ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
for _ in range(n_cnn_layers)
]
)
@@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module):
data = self.cnn(data)
data = self.rescnn_layers(data)
sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
+ data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
data = data.transpose(1, 2) # (batch, time, feature)
data = self.fully_connected(data)
data = self.birnn_layers(data)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 8e3bf09..c8d3793 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -14,16 +14,24 @@ from tqdm import tqdm
class TokenizerType:
+ """Base class for tokenizers.
+
+ exposes the same interface as tokenizers from the huggingface library"""
+
def encode(self, sequence: str) -> list[int]:
+ """Encode a sequence to a list of integer labels"""
raise NotImplementedError
def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+ """Decode a list of integer labels to a sequence"""
raise NotImplementedError
def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Decode a batch of integer labels to a list of sequences"""
raise NotImplementedError
def get_vocab_size(self) -> int:
+ """Get the size of the vocabulary"""
raise NotImplementedError
def enable_padding(
@@ -34,17 +42,20 @@ class TokenizerType:
pad_type_id: int = 0,
pad_token: str = "[PAD]",
) -> None:
+ """Enable padding for the tokenizer"""
raise NotImplementedError
def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
raise NotImplementedError
@staticmethod
def from_file(path: str) -> "TokenizerType":
+ """Load the tokenizer from a file"""
raise NotImplementedError
-tokenizer_type = Type[TokenizerType]
+MyTokenizerType = Type[TokenizerType]
@dataclass
@@ -52,6 +63,7 @@ class Encoding:
"""Simple dataclass to represent an encoding"""
ids: list[int]
+ tokens: list[str]
class CharTokenizer(TokenizerType):
@@ -137,7 +149,7 @@ class CharTokenizer(TokenizerType):
else:
mapped_char = self.char_map[char]
int_sequence.append(mapped_char)
- return Encoding(ids=int_sequence)
+ return Encoding(ids=int_sequence, tokens=list(sequence))
def decode(self, labels: list[int], remove_special_tokens: bool = True):
"""Use a character map and convert integer labels to an text sequence
@@ -256,6 +268,7 @@ def train_bpe_tokenizer(
bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
initial_alphabet = [
+ " ",
"a",
"b",
"c",
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index f3efd69..95038c2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,19 +5,19 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
-from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import train_bpe_tokenizer
+from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
# TODO: improve naming of functions
+
class HParams(TypedDict):
"""Type for the hyperparameters of the model."""
@@ -33,10 +33,11 @@ class HParams(TypedDict):
epochs: int
-# TODO: get blank label from tokenizer
-def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
+def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
"""Greedily decode a sequence."""
+ print("output shape", output.shape)
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = tokenizer.encode(" ").ids[0]
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
@@ -81,28 +82,27 @@ def train(
print(f"Epoch: {epoch}")
losses = []
for _data in tqdm(train_loader, desc="batches"):
- spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
-
+
losses.append(loss.item())
print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
return sum(losses) / len(losses)
-# TODO: profile this function call
-# TODO: only calculate wer and cer at the end, or less often
-# TODO: change this to only be a sanity check and calculate measures after training
+
+
def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
@@ -111,21 +111,21 @@ def test(model, device, test_loader, criterion, tokenizer):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- # TODO: get rid of this
- loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output = output.transpose(0, 1),
- labels = labels,
- label_lengths= _data["utterance_length"],
- tokenizer=tokenizer)
+ output=output.transpose(0, 1),
+ labels=labels,
+ label_lengths=_data["utterance_length"],
+ tokenizer=tokenizer,
+ )
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -135,9 +135,11 @@ def test(model, device, test_loader, criterion, tokenizer):
print(
f"Test set: Average loss:\
- {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+ {test_loss}, Average CER: {None} Average WER: {None}\n"
)
+ return test_loss, avg_cer, avg_wer
+
def run(
learning_rate: float,
@@ -152,33 +154,34 @@ def run(
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
- device = torch.device("mps")
+ # device = torch.device("mps")
# load dataset
- # TODO: change this from dev split to train split again (was faster for development)
- train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
- valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
- test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
+ train_dataset = MLSDataset(
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+ )
+ valid_dataset = MLSDataset(
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ )
- # load tokenizer (bpe by default):
- if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
+ # load tokenizer (bpe by default):
+ if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
print("There is no tokenizer available. Do you want to train it on the dataset?")
input("Press Enter to continue...")
- train_bpe_tokenizer(
+ train_char_tokenizer(
dataset_path=dataset_path,
language=language,
split="all",
download=False,
- out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
+ out_path="data/tokenizers/char_tokenizer_german.json",
vocab_size=3000,
)
-
- tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-
- train_dataset.set_tokenizer(tokenizer)
- valid_dataset.set_tokenizer(tokenizer)
- test_dataset.set_tokenizer(tokenizer)
-
+
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ train_dataset.set_tokenizer(tokenizer) # type: ignore
+ valid_dataset.set_tokenizer(tokenizer) # type: ignore
+
print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
hparams = HParams(
@@ -221,10 +224,10 @@ def run(
hparams["stride"],
hparams["dropout"],
).to(device)
-
+ print(tokenizer.encode(" "))
print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(blank=28).to(device)
+ criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
if load:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -240,7 +243,7 @@ def run(
)
iter_meter = IterMeter()
- for epoch in range(1, epochs + 1):
+ for epoch in range(1, epochs + 1):
loss = train(
model,
device,
@@ -252,7 +255,13 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer)
+ test(
+ model=model,
+ device=device,
+ test_loader=valid_loader,
+ criterion=criterion,
+ tokenizer=tokenizer,
+ )
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -295,5 +304,6 @@ def run_cli(
language="mls_german_opus",
)
-if __name__ == "__main__":
- run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file
+
+if __name__ == "__main__":
+ run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 3b9b3ca..efecb56 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -8,7 +8,6 @@ import torch
import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
-from tqdm import tqdm
from swr2_asr.tokenizer import TokenizerType
@@ -24,26 +23,26 @@ class MLSSplit(str, Enum):
"""Enum specifying dataset as they are defined in the
Multilingual LibriSpeech dataset"""
- train = "train"
- test = "test"
- dev = "dev"
+ TRAIN = "train"
+ TEST = "test"
+ DEV = "dev"
class Split(str, Enum):
"""Extending the MLSSplit class to allow for a custom validatio split"""
- train = "train"
- valid = "valid"
- test = "test"
- dev = "dev"
+ TRAIN = "train"
+ VALID = "valid"
+ TEST = "test"
+ DEV = "dev"
-def split_to_mls_split(split: Split) -> MLSSplit:
+def split_to_mls_split(split_name: Split) -> MLSSplit:
"""Converts the custom split to a MLSSplit"""
- if split == Split.valid:
- return MLSSplit.train
+ if split_name == Split.VALID:
+ return MLSSplit.TRAIN
else:
- return split # type: ignore
+ return split_name # type: ignore
class Sample(TypedDict):
@@ -89,14 +88,21 @@ class MLSDataset(Dataset):
<speakerid>_<bookid>_<chapterid> <utterance>
"""
- def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None):
+ def __init__(
+ self,
+ dataset_path: str,
+ language: str,
+ split: Split,
+ download: bool,
+ spectrogram_hparams: dict | None,
+ ):
"""Initializes the dataset"""
self.dataset_path = dataset_path
self.language = language
self.file_ext = ".opus" if "opus" in language else ".flac"
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
-
+
if spectrogram_hparams is None:
self.spectrogram_hparams = {
"sample_rate": 16000,
@@ -110,7 +116,7 @@ class MLSDataset(Dataset):
}
else:
self.spectrogram_hparams = spectrogram_hparams
-
+
self.dataset_lookup = []
self.tokenizer: type[TokenizerType]
@@ -118,13 +124,14 @@ class MLSDataset(Dataset):
self._validate_local_directory()
self.initialize()
-
def initialize(self) -> None:
"""Initializes the dataset
-
+
Reads the transcripts.txt file and creates a lookup table
"""
- transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt")
+ transcripts_path = os.path.join(
+ self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+ )
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
@@ -135,12 +142,12 @@ class MLSDataset(Dataset):
identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
identifier = [path.split("_") for path in identifier]
- if self.split == Split.valid:
+ if self.split == Split.VALID:
np.random.seed(42)
indices = np.random.choice(len(utterances), int(len(utterances) * 0.2))
utterances = [utterances[i] for i in indices]
identifier = [identifier[i] for i in indices]
- elif self.split == Split.train:
+ elif self.split == Split.TRAIN:
np.random.seed(42)
indices = np.random.choice(len(utterances), int(len(utterances) * 0.8))
utterances = [utterances[i] for i in indices]
@@ -212,13 +219,19 @@ class MLSDataset(Dataset):
# resample if necessary
if sample_rate != self.spectrogram_hparams["sample_rate"]:
- resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"])
+ resampler = torchaudio.transforms.Resample(
+ sample_rate, self.spectrogram_hparams["sample_rate"]
+ )
waveform = resampler(waveform)
- spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1)
+ spec = (
+ torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
input_length = spec.shape[0] // 2
-
+
utterance_length = len(utterance)
utterance = self.tokenizer.encode(utterance)
@@ -236,11 +249,11 @@ class MLSDataset(Dataset):
book_id=self.dataset_lookup[idx]["bookid"],
chapter_id=self.dataset_lookup[idx]["chapterid"],
)
-
+
def collate_fn(samples: list[Sample]) -> dict:
"""Collate function for the dataloader
-
+
pads all tensors within a batch to the same dimensions
"""
waveforms = []
@@ -248,18 +261,20 @@ def collate_fn(samples: list[Sample]) -> dict:
labels = []
input_lengths = []
label_lengths = []
-
+
for sample in samples:
waveforms.append(sample["waveform"].transpose(0, 1))
spectrograms.append(sample["spectrogram"])
labels.append(sample["utterance"])
input_lengths.append(sample["spectrogram"].shape[0] // 2)
label_lengths.append(len(sample["utterance"]))
-
+
waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
- spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3)
+ spectrograms = (
+ torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
+
return {
"waveform": waveforms,
"spectrogram": spectrograms,
@@ -267,15 +282,15 @@ def collate_fn(samples: list[Sample]) -> dict:
"utterance": labels,
"utterance_length": label_lengths,
}
-
+
if __name__ == "__main__":
- dataset_path = "/Volumes/pherkel/SWR2-ASR"
- language = "mls_german_opus"
- split = Split.train
- download = False
+ DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+ LANGUAGE = "mls_german_opus"
+ split = Split.TRAIN
+ DOWNLOAD = False
- dataset = MLSDataset(dataset_path, language, split, download, None)
+ dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)