aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.vscode/settings.json26
-rw-r--r--pyproject.toml3
-rw-r--r--swr2_asr/model_deep_speech.py25
-rw-r--r--swr2_asr/train.py372
-rw-r--r--swr2_asr/utils/data.py14
-rw-r--r--swr2_asr/utils/loss_scores.py (renamed from swr2_asr/loss_scores.py)0
6 files changed, 232 insertions, 208 deletions
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 0054bca..1adbc18 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,14 +1,14 @@
{
- "[python]": {
- "editor.formatOnType": true,
- "editor.defaultFormatter": "ms-python.black-formatter",
- "editor.formatOnSave": true,
- "editor.rulers": [88, 120],
- },
- "black-formatter.importStrategy": "fromEnvironment",
- "python.analysis.typeCheckingMode": "basic",
- "ruff.organizeImports": true,
- "ruff.importStrategy": "fromEnvironment",
- "ruff.fixAll": true,
- "ruff.run": "onType"
-} \ No newline at end of file
+ "[python]": {
+ "editor.formatOnType": true,
+ "editor.defaultFormatter": "ms-python.black-formatter",
+ "editor.formatOnSave": true,
+ "editor.rulers": [88, 120]
+ },
+ "black-formatter.importStrategy": "fromEnvironment",
+ "python.analysis.typeCheckingMode": "off",
+ "ruff.organizeImports": true,
+ "ruff.importStrategy": "fromEnvironment",
+ "ruff.fixAll": true,
+ "ruff.run": "onType"
+}
diff --git a/pyproject.toml b/pyproject.toml
index 6f74b49..38cc51a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,12 +23,11 @@ black = "^23.7.0"
mypy = "^1.5.1"
pylint = "^2.17.5"
ruff = "^0.0.285"
-types-tqdm = "^4.66.0.1"
[tool.ruff]
select = ["E", "F", "B", "I"]
fixable = ["ALL"]
-line-length = 120
+line-length = 100
target-version = "py310"
[tool.black]
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index dd07ff9..8ddbd99 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,8 +1,29 @@
-"""Main definition of model"""
+"""Main definition of the Deep speech 2 model by Baidu Research.
+
+Following definition by Assembly AI
+(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
+"""
+from typing import TypedDict
+
import torch.nn.functional as F
from torch import nn
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
+
+
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
@@ -60,7 +81,7 @@ class ResidualCNN(nn.Module):
class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
+ """Bidirectional GRU with Layer Normalization and Dropout"""
def __init__(
self,
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9f12bcb..ac7666b 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -7,50 +7,17 @@ import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
-from swr2_asr.utils import MLSDataset, Split, collate_fn,plot
-
-from .loss_scores import cer, wer
-
-
-class HParams(TypedDict):
- """Type for the hyperparameters of the model."""
-
- n_cnn_layers: int
- n_rnn_layers: int
- rnn_dim: int
- n_class: int
- n_feats: int
- stride: int
- dropout: float
- learning_rate: float
- batch_size: int
- epochs: int
-
-
-def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
- """Greedily decode a sequence."""
- print("output shape", output.shape)
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.encode(" ").ids[0]
- decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
- decode = []
- targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes, targets
-
-
-class IterMeter:
+from tqdm.autonotebook import tqdm
+
+from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel
+from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
+from swr2_asr.utils.decoder import greedy_decoder
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+from .utils.loss_scores import cer, wer
+
+
+class IterMeter(object):
"""keeps track of total iterations"""
def __init__(self):
@@ -61,123 +28,195 @@ class IterMeter:
self.val += 1
def get(self):
- """get"""
+ """get steps"""
return self.val
-def train(
- model,
- device,
- train_loader,
- criterion,
- optimizer,
- scheduler,
- epoch,
- iter_meter,
-):
- """Train"""
+class TrainArgs(TypedDict):
+ """Type for the arguments of the training function."""
+
+ model: SpeechRecognitionModel
+ device: torch.device # pylint: disable=no-member
+ train_loader: DataLoader
+ criterion: nn.CTCLoss
+ optimizer: optim.AdamW
+ scheduler: optim.lr_scheduler.OneCycleLR
+ epoch: int
+ iter_meter: IterMeter
+
+
+def train(train_args) -> float:
+ """Train
+ Args:
+ model: model
+ device: device type
+ train_loader: train dataloader
+ criterion: loss function
+ optimizer: optimizer
+ scheduler: learning rate scheduler
+ epoch: epoch number
+ iter_meter: iteration meter
+
+ Returns:
+ avg_train_loss: avg_train_loss for the epoch
+
+ Information:
+ spectrograms: (batch, time, feature)
+ labels: (batch, label_length)
+
+ model output: (batch,time, n_class)
+
+ """
+ # get values from train_args:
+ (
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+ ) = train_args.values()
+
model.train()
- print(f"Epoch: {epoch}")
- losses = []
- for _data in tqdm(train_loader, desc="batches"):
- spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+ print(f"training batch {epoch}")
+ train_losses = []
+ for _data in tqdm(train_loader, desc="Training batches"):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+ loss = criterion(output, labels, input_lengths, label_lengths)
+ train_losses.append(loss)
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
+ avg_train_loss = sum(train_losses) / len(train_losses)
+ print(f"Train set: Average loss: {avg_train_loss:.2f}")
+ return avg_train_loss
+
- losses.append(loss.item())
+class TestArgs(TypedDict):
+ """Type for the arguments of the test function."""
- print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
- return sum(losses) / len(losses)
+ model: SpeechRecognitionModel
+ device: torch.device # pylint: disable=no-member
+ test_loader: DataLoader
+ criterion: nn.CTCLoss
+ tokenizer: CharTokenizer
+ decoder: str
-def test(model, device, test_loader, criterion, tokenizer):
- """Test"""
+def test(test_args: TestArgs) -> tuple[float, float, float]:
print("\nevaluating...")
+
+ # get values from test_args:
+ model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
+
+ if decoder == "greedy":
+ decoder = greedy_decoder
+
model.eval()
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for _data in test_loader:
- spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+ for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+ loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output=output.transpose(0, 1),
- labels=labels,
- label_lengths=_data["utterance_length"],
- tokenizer=tokenizer,
+ output.transpose(0, 1), labels, label_lengths, tokenizer
)
- for j, pred in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], pred))
- test_wer.append(wer(decoded_targets[j], pred))
+ if i == 1:
+ print(f"decoding first sample: {decoded_preds}")
+ for j, _ in enumerate(decoded_preds):
+ test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
+ test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
print(
- f"Test set: Average loss:\
- {test_loss}, Average CER: {None} Average WER: {None}\n"
+ f"Test set: \
+ Average loss: {test_loss:.4f}, \
+ Average CER: {avg_cer:4f} \
+ Average WER: {avg_wer:.4f}\n"
)
return test_loss, avg_cer, avg_wer
-def run(
+def main(
learning_rate: float,
batch_size: int,
epochs: int,
- load: bool,
- path: str,
dataset_path: str,
language: str,
-) -> None:
- """Runs the training script."""
+ limited_supervision: bool,
+ model_load_path: str,
+ model_save_path: str,
+ dataset_percentage: float,
+ eval_every: int,
+ num_workers: int,
+):
+ """Main function for training the model.
+
+ Args:
+ learning_rate: learning rate for the optimizer
+ batch_size: batch size
+ epochs: number of epochs to train
+ dataset_path: path for the dataset
+ language: language of the dataset
+ limited_supervision: whether to use only limited supervision
+ model_load_path: path to load a model from
+ model_save_path: path to save the model to
+ dataset_percentage: percentage of the dataset to use
+ eval_every: evaluate every n epochs
+ num_workers: number of workers for the dataloader
+ """
use_cuda = torch.cuda.is_available()
- torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
- # device = torch.device("mps")
- # load dataset
+ torch.manual_seed(7)
+
+ if not os.path.isdir(dataset_path):
+ os.makedirs(dataset_path)
+
train_dataset = MLSDataset(
- dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
+ dataset_path,
+ language,
+ Split.TEST,
+ download=True,
+ limited=limited_supervision,
+ size=dataset_percentage,
)
valid_dataset = MLSDataset(
- dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
+ dataset_path,
+ language,
+ Split.TRAIN,
+ download=False,
+ limited=limited_supervision,
+ size=dataset_percentage,
)
- # load tokenizer (bpe by default):
- if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
- print("There is no tokenizer available. Do you want to train it on the dataset?")
- input("Press Enter to continue...")
- train_char_tokenizer(
- dataset_path=dataset_path,
- language=language,
- split="all",
- out_path="data/tokenizers/char_tokenizer_german.json",
- )
-
- tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-
- train_dataset.set_tokenizer(tokenizer) # type: ignore
- valid_dataset.set_tokenizer(tokenizer) # type: ignore
+ # TODO: initialize and possibly train tokenizer if none found
- print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+ kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
hparams = HParams(
n_cnn_layers=3,
@@ -192,24 +231,24 @@ def run(
epochs=epochs,
)
+ train_data_processing = DataProcessing("train", tokenizer)
+ valid_data_processing = DataProcessing("valid", tokenizer)
+
train_loader = DataLoader(
- train_dataset,
+ dataset=train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: collate_fn(x),
+ collate_fn=train_data_processing,
+ **kwargs,
)
-
valid_loader = DataLoader(
- valid_dataset,
+ dataset=valid_dataset,
batch_size=hparams["batch_size"],
- shuffle=True,
- collate_fn=lambda x: collate_fn(x),
+ shuffle=False,
+ collate_fn=valid_data_processing,
+ **kwargs,
)
- # enable flag to find the most compatible algorithms in advance
- if use_cuda:
- torch.backends.cudnn.benchmark = True # pylance: disable=no-member
-
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
hparams["n_rnn_layers"],
@@ -219,16 +258,9 @@ def run(
hparams["stride"],
hparams["dropout"],
).to(device)
- print(tokenizer.encode(" "))
- print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
+
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
- if load:
- checkpoint = torch.load(path)
- model.load_state_dict(checkpoint["model_state_dict"])
- optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
- epoch = checkpoint["epoch"]
- loss = checkpoint["loss"]
+ criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
@@ -236,77 +268,63 @@ def run(
epochs=hparams["epochs"],
anneal_strategy="linear",
)
+ prev_epoch = 0
+
+ if model_load_path is not None:
+ checkpoint = torch.load(model_load_path)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
- for epoch in range(1, epochs + 1):
- loss = train(
- model,
- device,
- train_loader,
- criterion,
- optimizer,
- scheduler,
- epoch,
- iter_meter,
+ if not os.path.isdir(os.path.dirname(model_save_path)):
+ os.makedirs(os.path.dirname(model_save_path))
+ for epoch in range(prev_epoch + 1, epochs + 1):
+ train_args: TrainArgs = dict(
+ model=model,
+ device=device,
+ train_loader=train_loader,
+ criterion=criterion,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ epoch=epoch,
+ iter_meter=iter_meter,
)
- test_loss, avg_cer, avg_wer = test(
+ train_loss = train(train_args)
+
+ test_loss, test_cer, test_wer = 0, 0, 0
+
+ test_args: TestArgs = dict(
model=model,
device=device,
test_loader=valid_loader,
criterion=criterion,
tokenizer=tokenizer,
+ decoder="greedy",
)
- print("saving epoch", str(epoch))
+
+ if epoch % eval_every == 0:
+ test_loss, test_cer, test_wer = test(test_args)
+
+ if model_save_path is None:
+ continue
+
+ if not os.path.isdir(os.path.dirname(model_save_path)):
+ os.makedirs(os.path.dirname(model_save_path))
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
- "loss": loss,
+ "optimizer_state_dict": optimizer.state_dict(),
+ "train_loss": train_loss,
"test_loss": test_loss,
- "avg_cer": avg_cer,
- "avg_wer": avg_wer,
+ "avg_cer": test_cer,
+ "avg_wer": test_wer,
},
- path + str(epoch),
- plot(epochs,path)
+ model_save_path + str(epoch),
)
-@click.command()
-@click.option("--learning_rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=10, help="Batch size")
-@click.option("--epochs", default=1, help="Number of epochs")
-@click.option("--load", default=False, help="Do you want to load a model?")
-@click.option(
- "--path",
- default="model",
- help="Path where the model will be saved to/loaded from",
-)
-@click.option(
- "--dataset_path",
- default="data/",
- help="Path for the dataset directory",
-)
-def run_cli(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- load: bool,
- path: str,
- dataset_path: str,
-) -> None:
- """Runs the training script."""
-
- run(
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- load=load,
- path=path,
- dataset_path=dataset_path,
- language="mls_german_opus",
- )
-
-
if __name__ == "__main__":
- run_cli() # pylint: disable=no-value-for-parameter
+ main() # pylint: disable=no-value-for-parameter
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 74d10c9..e939e1d 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -76,20 +76,6 @@ def split_to_mls_split(split_name: Split) -> MLSSplit:
return split_name # type: ignore
-class Sample(TypedDict):
- """Type for a sample in the dataset"""
-
- waveform: torch.Tensor
- spectrogram: torch.Tensor
- input_length: int
- utterance: torch.Tensor
- utterance_length: int
- sample_rate: int
- speaker_id: str
- book_id: str
- chapter_id: str
-
-
class MLSDataset(Dataset):
"""Custom Dataset for reading Multilingual LibriSpeech
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/utils/loss_scores.py
index 80285f6..80285f6 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/utils/loss_scores.py