author     Pherkel    2023-09-12 14:19:15 +0200
committer  GitHub     2023-09-12 14:19:15 +0200
commit     7a9a6c783e69b5a537a3d3f5bfe8d5fdc656c807 (patch)
tree       0725631b9b68aeb65b292420a15941dcfa3fc04f /swr2_asr/train.py
parent     f9846193289c81d89342b6a36e951605c2cfa189 (diff)
parent     7b71dab87591e04d874cd636614450b0e65e3f2b (diff)
Merge pull request #37 from Algo-Boys/fix/ultimate
Fix/ultimate
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  424
1 file changed, 216 insertions(+), 208 deletions(-)
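Context for the hunks below: the merge replaces the old CLI flags and hard-coded HParams with a single YAML config read in main(). The following is a minimal sketch of the config layout the new code looks up; key names are taken from the merged train.py, and the values mirror the removed HParams/CLI defaults where those existed (otherwise they are illustrative assumptions, not taken from the repository).

# Sketch of the config.yaml consumed by the new main(); values marked
# "assumed" are illustrative only.
import yaml

EXAMPLE_CONFIG = """
model:
  n_cnn_layers: 3
  n_rnn_layers: 5
  rnn_dim: 512
  n_feats: 128
  stride: 2
  dropout: 0.1
training:
  learning_rate: 0.001
  batch_size: 10
  epochs: 1
  eval_every_n: 1          # assumed: evaluate after every epoch
  num_workers: 4           # assumed
dataset:
  dataset_root_path: data/
  language_name: mls_german_opus
  limited_supervision: true
  dataset_percentage: 1.0  # assumed
  shuffle: true
  download: true
tokenizer:
  tokenizer_path: data/tokenizers/char_tokenizer_german.json
checkpoints:
  model_load_path: null    # set to a checkpoint file to resume training
  model_save_path: model   # epoch number is appended on save
"""

config_dict = yaml.safe_load(EXAMPLE_CONFIG)
training_config = config_dict.get("training", {})
print(training_config["learning_rate"], training_config["epochs"])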
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9f12bcb..ffdae73 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,49 +5,17 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+import yaml
from torch import nn, optim
from torch.utils.data import DataLoader
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
-from swr2_asr.utils import MLSDataset, Split, collate_fn,plot
-
-from .loss_scores import cer, wer
-
-
-class HParams(TypedDict):
- """Type for the hyperparameters of the model."""
-
- n_cnn_layers: int
- n_rnn_layers: int
- rnn_dim: int
- n_class: int
- n_feats: int
- stride: int
- dropout: float
- learning_rate: float
- batch_size: int
- epochs: int
-
-
-def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
- """Greedily decode a sequence."""
- print("output shape", output.shape)
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.encode(" ").ids[0]
- decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
- decode = []
- targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes, targets
+from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
+from swr2_asr.utils.decoder import greedy_decoder
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+from .utils.loss_scores import cer, wer
class IterMeter:
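The hunk above drops the inline greedy_decoder in favour of an import from swr2_asr.utils.decoder. For reference, here is a self-contained, simplified sketch of the greedy CTC decoding the removed helper performed (argmax per frame, drop blanks, collapse repeats); the relocated implementation may differ in detail, and the SimpleTokenizer below is a hypothetical stand-in so the example runs on its own.

# Greedy CTC decoding sketch, modelled on the removed inline helper.
import torch


class SimpleTokenizer:
    """Hypothetical character tokenizer; index 0 is the blank/space token."""

    vocab = [" ", "a", "b", "c"]

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)


def greedy_ctc_decode(output, tokenizer, blank_label=0, collapse_repeated=True):
    """output: (batch, time, n_class) log-probabilities."""
    arg_maxes = torch.argmax(output, dim=2)  # most likely class per frame
    decodes = []
    for args in arg_maxes:
        decode = []
        for j, index in enumerate(args):
            if index == blank_label:
                continue  # drop blank frames
            if collapse_repeated and j != 0 and index == args[j - 1]:
                continue  # collapse repeated frames into one symbol
            decode.append(index.item())
        decodes.append(tokenizer.decode(decode))
    return decodes


tokenizer = SimpleTokenizer()
log_probs = torch.randn(2, 10, len(SimpleTokenizer.vocab)).log_softmax(dim=2)
print(greedy_ctc_decode(log_probs, tokenizer))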
@@ -61,252 +29,292 @@ class IterMeter:
self.val += 1
def get(self):
- """get"""
+ """get steps"""
return self.val
-def train(
- model,
- device,
- train_loader,
- criterion,
- optimizer,
- scheduler,
- epoch,
- iter_meter,
-):
- """Train"""
+class TrainArgs(TypedDict):
+ """Type for the arguments of the training function."""
+
+ model: SpeechRecognitionModel
+ device: torch.device # pylint: disable=no-member
+ train_loader: DataLoader
+ criterion: nn.CTCLoss
+ optimizer: optim.AdamW
+ scheduler: optim.lr_scheduler.OneCycleLR
+ epoch: int
+ iter_meter: IterMeter
+
+
+def train(train_args: TrainArgs) -> float:
+ """Train
+ Args:
+ model: model
+ device: device type
+ train_loader: train dataloader
+ criterion: loss function
+ optimizer: optimizer
+ scheduler: learning rate scheduler
+ epoch: epoch number
+ iter_meter: iteration meter
+
+ Returns:
+ avg_train_loss: avg_train_loss for the epoch
+
+ Information:
+ spectrograms: (batch, time, feature)
+ labels: (batch, label_length)
+
+ model output: (batch,time, n_class)
+
+ """
+ # get values from train_args:
+ (
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+ ) = train_args.values()
+
model.train()
- print(f"Epoch: {epoch}")
- losses = []
- for _data in tqdm(train_loader, desc="batches"):
- spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+    print(f"training epoch {epoch}")
+ train_losses = []
+ for _data in tqdm(train_loader, desc="Training batches"):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+ loss = criterion(output, labels, input_lengths, label_lengths)
+        train_losses.append(loss.item())
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
+ avg_train_loss = sum(train_losses) / len(train_losses)
+ print(f"Train set: Average loss: {avg_train_loss:.2f}")
+ return avg_train_loss
+
- losses.append(loss.item())
+class TestArgs(TypedDict):
+ """Type for the arguments of the test function."""
- print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
- return sum(losses) / len(losses)
+ model: SpeechRecognitionModel
+ device: torch.device # pylint: disable=no-member
+ test_loader: DataLoader
+ criterion: nn.CTCLoss
+ tokenizer: CharTokenizer
+ decoder: str
-def test(model, device, test_loader, criterion, tokenizer):
+def test(test_args: TestArgs) -> tuple[float, float, float]:
"""Test"""
print("\nevaluating...")
+
+ # get values from test_args:
+ model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
+
+ if decoder == "greedy":
+ decoder = greedy_decoder
+
model.eval()
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for _data in test_loader:
- spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+ for _data in tqdm(test_loader, desc="Validation Batches"):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+ loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output=output.transpose(0, 1),
- labels=labels,
- label_lengths=_data["utterance_length"],
- tokenizer=tokenizer,
+ output.transpose(0, 1), labels, label_lengths, tokenizer
)
- for j, pred in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], pred))
- test_wer.append(wer(decoded_targets[j], pred))
+ for j, _ in enumerate(decoded_preds):
+ test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
+ test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
print(
- f"Test set: Average loss:\
- {test_loss}, Average CER: {None} Average WER: {None}\n"
+ f"Test set: \
+ Average loss: {test_loss:.4f}, \
+        Average CER: {avg_cer:.4f} \
+ Average WER: {avg_wer:.4f}\n"
)
return test_loss, avg_cer, avg_wer
-def run(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- load: bool,
- path: str,
- dataset_path: str,
- language: str,
-) -> None:
- """Runs the training script."""
+@click.command()
+@click.option(
+ "--config_path",
+ default="config.yaml",
+ help="Path to yaml config file",
+ type=click.Path(exists=True),
+)
+def main(config_path: str):
+ """Main function for training the model.
+
+ Gets all configuration arguments from yaml config file.
+ """
use_cuda = torch.cuda.is_available()
- torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
- # device = torch.device("mps")
- # load dataset
+ torch.manual_seed(7)
+
+ with open(config_path, "r", encoding="utf-8") as yaml_file:
+ config_dict = yaml.safe_load(yaml_file)
+
+ # Create separate dictionaries for each top-level key
+ model_config = config_dict.get("model", {})
+ training_config = config_dict.get("training", {})
+ dataset_config = config_dict.get("dataset", {})
+ tokenizer_config = config_dict.get("tokenizer", {})
+ checkpoints_config = config_dict.get("checkpoints", {})
+
+ if not os.path.isdir(dataset_config["dataset_root_path"]):
+ os.makedirs(dataset_config["dataset_root_path"])
+
train_dataset = MLSDataset(
- dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
+ Split.TRAIN,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
valid_dataset = MLSDataset(
- dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
+ Split.TEST,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
- # load tokenizer (bpe by default):
- if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
- print("There is no tokenizer available. Do you want to train it on the dataset?")
- input("Press Enter to continue...")
- train_char_tokenizer(
- dataset_path=dataset_path,
- language=language,
- split="all",
- out_path="data/tokenizers/char_tokenizer_german.json",
- )
-
- tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+ kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {}
- train_dataset.set_tokenizer(tokenizer) # type: ignore
- valid_dataset.set_tokenizer(tokenizer) # type: ignore
-
- print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+    if tokenizer_config["tokenizer_path"] is None:
+        print("Tokenizer not found!")
+        if not click.confirm("Do you want to train a new tokenizer?", default=True):
+            return
+        tokenizer = CharTokenizer.train(
+            dataset_config["dataset_root_path"], dataset_config["language_name"]
+        )
+    else:
+        tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
- hparams = HParams(
- n_cnn_layers=3,
- n_rnn_layers=5,
- rnn_dim=512,
- n_class=tokenizer.get_vocab_size(),
- n_feats=128,
- stride=2,
- dropout=0.1,
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- )
+ train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+ valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})
train_loader = DataLoader(
- train_dataset,
- batch_size=hparams["batch_size"],
- shuffle=True,
- collate_fn=lambda x: collate_fn(x),
+ dataset=train_dataset,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
+ collate_fn=train_data_processing,
+ **kwargs,
)
-
valid_loader = DataLoader(
- valid_dataset,
- batch_size=hparams["batch_size"],
- shuffle=True,
- collate_fn=lambda x: collate_fn(x),
+ dataset=valid_dataset,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
+ collate_fn=valid_data_processing,
+ **kwargs,
)
- # enable flag to find the most compatible algorithms in advance
- if use_cuda:
- torch.backends.cudnn.benchmark = True # pylance: disable=no-member
-
model = SpeechRecognitionModel(
- hparams["n_cnn_layers"],
- hparams["n_rnn_layers"],
- hparams["rnn_dim"],
- hparams["n_class"],
- hparams["n_feats"],
- hparams["stride"],
- hparams["dropout"],
+ model_config["n_cnn_layers"],
+ model_config["n_rnn_layers"],
+ model_config["rnn_dim"],
+ tokenizer.get_vocab_size(),
+ model_config["n_feats"],
+ model_config["stride"],
+ model_config["dropout"],
).to(device)
- print(tokenizer.encode(" "))
- print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
- optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
- if load:
- checkpoint = torch.load(path)
- model.load_state_dict(checkpoint["model_state_dict"])
- optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
- epoch = checkpoint["epoch"]
- loss = checkpoint["loss"]
+
+ optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"])
+ criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
- max_lr=hparams["learning_rate"],
+ max_lr=training_config["learning_rate"],
steps_per_epoch=int(len(train_loader)),
- epochs=hparams["epochs"],
+ epochs=training_config["epochs"],
anneal_strategy="linear",
)
+ prev_epoch = 0
+
+ if checkpoints_config["model_load_path"] is not None:
+ checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
- for epoch in range(1, epochs + 1):
- loss = train(
- model,
- device,
- train_loader,
- criterion,
- optimizer,
- scheduler,
- epoch,
- iter_meter,
- )
- test_loss, avg_cer, avg_wer = test(
- model=model,
- device=device,
- test_loader=valid_loader,
- criterion=criterion,
- tokenizer=tokenizer,
- )
- print("saving epoch", str(epoch))
+ for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
+ train_args: TrainArgs = {
+ "model": model,
+ "device": device,
+ "train_loader": train_loader,
+ "criterion": criterion,
+ "optimizer": optimizer,
+ "scheduler": scheduler,
+ "epoch": epoch,
+ "iter_meter": iter_meter,
+ }
+
+ train_loss = train(train_args)
+
+ test_loss, test_cer, test_wer = 0, 0, 0
+
+ test_args: TestArgs = {
+ "model": model,
+ "device": device,
+ "test_loader": valid_loader,
+ "criterion": criterion,
+ "tokenizer": tokenizer,
+ "decoder": "greedy",
+ }
+
+ if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
+ test_loss, test_cer, test_wer = test(test_args)
+
+ if checkpoints_config["model_save_path"] is None:
+ continue
+
+ if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])):
+ os.makedirs(os.path.dirname(checkpoints_config["model_save_path"]))
+
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
- "loss": loss,
+ "optimizer_state_dict": optimizer.state_dict(),
+ "train_loss": train_loss,
"test_loss": test_loss,
- "avg_cer": avg_cer,
- "avg_wer": avg_wer,
+ "avg_cer": test_cer,
+ "avg_wer": test_wer,
},
- path + str(epoch),
- plot(epochs,path)
+ checkpoints_config["model_save_path"] + str(epoch),
)
-@click.command()
-@click.option("--learning_rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=10, help="Batch size")
-@click.option("--epochs", default=1, help="Number of epochs")
-@click.option("--load", default=False, help="Do you want to load a model?")
-@click.option(
- "--path",
- default="model",
- help="Path where the model will be saved to/loaded from",
-)
-@click.option(
- "--dataset_path",
- default="data/",
- help="Path for the dataset directory",
-)
-def run_cli(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- load: bool,
- path: str,
- dataset_path: str,
-) -> None:
- """Runs the training script."""
-
- run(
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- load=load,
- path=path,
- dataset_path=dataset_path,
- language="mls_german_opus",
- )
-
-
if __name__ == "__main__":
- run_cli() # pylint: disable=no-value-for-parameter
+ main() # pylint: disable=no-value-for-parameter
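The checkpoint handling added above saves one dict per epoch (path suffixed with the epoch number) and, when checkpoints.model_load_path is set, resumes from checkpoint["epoch"]. The sketch below shows that round-trip under the same dict keys used in the merged code; the tiny nn.Linear model is a hypothetical stand-in for SpeechRecognitionModel, and the metric values are placeholders.

# Checkpoint save/resume sketch mirroring the merged training loop.
import torch
from torch import nn, optim

model = nn.Linear(4, 2)  # stand-in for SpeechRecognitionModel
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

save_path = "model"      # plays the role of checkpoints_config["model_save_path"]
epoch = 1

torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": 0.0,  # placeholder metrics
        "test_loss": 0.0,
        "avg_cer": 0.0,
        "avg_wer": 0.0,
    },
    save_path + str(epoch),  # one file per epoch, as in the merged code
)

# Resume: load onto the current device and continue from the next epoch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(save_path + str(epoch), map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
print(f"resuming from epoch {prev_epoch + 1}")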