From c078ce6789c134aa05607903d3bf9e4be64df45d Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 15:45:35 +0200 Subject: big change! --- swr2_asr/train.py | 372 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 195 insertions(+), 177 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9f12bcb..ac7666b 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -7,50 +7,17 @@ import torch import torch.nn.functional as F from torch import nn, optim from torch.utils.data import DataLoader -from tqdm import tqdm - -from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer -from swr2_asr.utils import MLSDataset, Split, collate_fn,plot - -from .loss_scores import cer, wer - - -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): - """Greedily decode a sequence.""" - print("output shape", output.shape) - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -class IterMeter: +from tqdm.autonotebook import tqdm + +from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.utils.data import DataProcessing, MLSDataset, Split +from swr2_asr.utils.decoder import greedy_decoder +from swr2_asr.utils.tokenizer import CharTokenizer + +from .utils.loss_scores import cer, wer + + +class IterMeter(object): """keeps track of total iterations""" def __init__(self): @@ -61,123 +28,195 @@ class IterMeter: self.val += 1 def get(self): - """get""" + """get steps""" return self.val -def train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, -): - """Train""" +class TrainArgs(TypedDict): + """Type for the arguments of the training function.""" + + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + train_loader: DataLoader + criterion: nn.CTCLoss + optimizer: optim.AdamW + scheduler: optim.lr_scheduler.OneCycleLR + epoch: int + iter_meter: IterMeter + + +def train(train_args) -> float: + """Train + Args: + model: model + device: device type + train_loader: train dataloader + criterion: loss function + optimizer: optimizer + scheduler: learning rate scheduler + epoch: epoch number + iter_meter: iteration meter + + Returns: + avg_train_loss: avg_train_loss for the epoch + + Information: + spectrograms: (batch, time, feature) + labels: (batch, label_length) + + model output: (batch,time, n_class) + + """ + # get values from train_args: + ( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) = train_args.values() + model.train() - print(f"Epoch: {epoch}") - losses = [] - for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = 
_data["spectrogram"].to(device), _data["utterance"].to(device) + print(f"training batch {epoch}") + train_losses = [] + for _data in tqdm(train_loader, desc="Training batches"): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) + train_losses.append(loss) loss.backward() optimizer.step() scheduler.step() iter_meter.step() + avg_train_loss = sum(train_losses) / len(train_losses) + print(f"Train set: Average loss: {avg_train_loss:.2f}") + return avg_train_loss + - losses.append(loss.item()) +class TestArgs(TypedDict): + """Type for the arguments of the test function.""" - print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") - return sum(losses) / len(losses) + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + test_loader: DataLoader + criterion: nn.CTCLoss + tokenizer: CharTokenizer + decoder: str -def test(model, device, test_loader, criterion, tokenizer): - """Test""" +def test(test_args: TestArgs) -> tuple[float, float, float]: print("\nevaluating...") + + # get values from test_args: + model, device, test_loader, criterion, tokenizer, decoder = test_args.values() + + if decoder == "greedy": + decoder = greedy_decoder + model.eval() test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for _data in test_loader: - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output=output.transpose(0, 1), - labels=labels, - label_lengths=_data["utterance_length"], - tokenizer=tokenizer, + output.transpose(0, 1), labels, label_lengths, tokenizer ) - for j, pred in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], pred)) - test_wer.append(wer(decoded_targets[j], pred)) + if i == 1: + print(f"decoding first sample: {decoded_preds}") + for j, _ in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) print( - f"Test set: Average loss:\ - {test_loss}, Average CER: {None} Average WER: {None}\n" + f"Test set: \ + Average loss: {test_loss:.4f}, \ + Average CER: {avg_cer:4f} \ + Average WER: {avg_wer:.4f}\n" ) return test_loss, avg_cer, avg_wer -def run( +def main( learning_rate: float, batch_size: int, epochs: int, - load: bool, - path: str, dataset_path: str, language: str, -) -> None: - """Runs the training script.""" + limited_supervision: bool, + model_load_path: str, + model_save_path: str, + dataset_percentage: float, + eval_every: int, + 
num_workers: int,
+):
+    """Main function for training the model.
+
+    Args:
+        learning_rate: learning rate for the optimizer
+        batch_size: batch size
+        epochs: number of epochs to train
+        dataset_path: path for the dataset
+        language: language of the dataset
+        limited_supervision: whether to use only limited supervision
+        model_load_path: path to load a model from
+        model_save_path: path to save the model to
+        dataset_percentage: percentage of the dataset to use
+        eval_every: evaluate every n epochs
+        num_workers: number of workers for the dataloader
+    """
     use_cuda = torch.cuda.is_available()
-    torch.manual_seed(42)
     device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
-    # device = torch.device("mps")
 
-    # load dataset
+    torch.manual_seed(7)
+
+    if not os.path.isdir(dataset_path):
+        os.makedirs(dataset_path)
+
     train_dataset = MLSDataset(
-        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
+        dataset_path,
+        language,
+        Split.TEST,
+        download=True,
+        limited=limited_supervision,
+        size=dataset_percentage,
     )
     valid_dataset = MLSDataset(
-        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
+        dataset_path,
+        language,
+        Split.TRAIN,
+        download=False,
+        limited=limited_supervision,
+        size=dataset_percentage,
     )
 
-    # load tokenizer (bpe by default):
-    if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
-        print("There is no tokenizer available. Do you want to train it on the dataset?")
-        input("Press Enter to continue...")
-        train_char_tokenizer(
-            dataset_path=dataset_path,
-            language=language,
-            split="all",
-            out_path="data/tokenizers/char_tokenizer_german.json",
-        )
-
-    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-
-    train_dataset.set_tokenizer(tokenizer)  # type: ignore
-    valid_dataset.set_tokenizer(tokenizer)  # type: ignore
+    # TODO: initialize and possibly train tokenizer if none found
 
-    print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+    kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
 
     hparams = HParams(
         n_cnn_layers=3,
@@ -192,24 +231,24 @@ def run(
         epochs=epochs,
     )
 
+    train_data_processing = DataProcessing("train", tokenizer)
+    valid_data_processing = DataProcessing("valid", tokenizer)
+
     train_loader = DataLoader(
-        train_dataset,
+        dataset=train_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        collate_fn=train_data_processing,
+        **kwargs,
     )
-
     valid_loader = DataLoader(
-        valid_dataset,
+        dataset=valid_dataset,
         batch_size=hparams["batch_size"],
-        shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        shuffle=False,
+        collate_fn=valid_data_processing,
+        **kwargs,
     )
 
-    # enable flag to find the most compatible algorithms in advance
-    if use_cuda:
-        torch.backends.cudnn.benchmark = True  # pylance: disable=no-member
-
     model = SpeechRecognitionModel(
         hparams["n_cnn_layers"],
         hparams["n_rnn_layers"],
@@ -219,16 +258,9 @@ def run(
         hparams["stride"],
         hparams["dropout"],
     ).to(device)
-    print(tokenizer.encode(" "))
-    print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
+
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
-    criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
-    if load:
-        checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-        epoch = checkpoint["epoch"]
-        loss = 
checkpoint["loss"] + criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], @@ -236,77 +268,63 @@ def run( epochs=hparams["epochs"], anneal_strategy="linear", ) + prev_epoch = 0 + + if model_load_path is not None: + checkpoint = torch.load(model_load_path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - for epoch in range(1, epochs + 1): - loss = train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, + if not os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) + for epoch in range(prev_epoch + 1, epochs + 1): + train_args: TrainArgs = dict( + model=model, + device=device, + train_loader=train_loader, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + iter_meter=iter_meter, ) - test_loss, avg_cer, avg_wer = test( + train_loss = train(train_args) + + test_loss, test_cer, test_wer = 0, 0, 0 + + test_args: TestArgs = dict( model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer=tokenizer, + decoder="greedy", ) - print("saving epoch", str(epoch)) + + if epoch % eval_every == 0: + test_loss, test_cer, test_wer = test(test_args) + + if model_save_path is None: + continue + + if not os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), - "loss": loss, + "optimizer_state_dict": optimizer.state_dict(), + "train_loss": train_loss, "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer, + "avg_cer": test_cer, + "avg_wer": test_wer, }, - path + str(epoch), - plot(epochs,path) + model_save_path + str(epoch), ) -@click.command() -@click.option("--learning_rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=10, help="Batch size") -@click.option("--epochs", default=1, help="Number of epochs") -@click.option("--load", default=False, help="Do you want to load a model?") -@click.option( - "--path", - default="model", - help="Path where the model will be saved to/loaded from", -) -@click.option( - "--dataset_path", - default="data/", - help="Path for the dataset directory", -) -def run_cli( - learning_rate: float, - batch_size: int, - epochs: int, - load: bool, - path: str, - dataset_path: str, -) -> None: - """Runs the training script.""" - - run( - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - load=load, - path=path, - dataset_path=dataset_path, - language="mls_german_opus", - ) - - if __name__ == "__main__": - run_cli() # pylint: disable=no-value-for-parameter + main() # pylint: disable=no-value-for-parameter -- cgit v1.2.3 From 58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 21:52:42 +0200 Subject: changed config to yaml! 
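This commit replaces the long list of click options with a single
--config_path flag, so all hyperparameters now come from a YAML file.
A minimal sketch of the loading pattern introduced in swr2_asr/train.py
below (the config file name here is just an example; section and key
names follow the config files added in this commit):

    import yaml

    with open("config.philipp.yaml", "r", encoding="utf-8") as yaml_file:
        config_dict = yaml.safe_load(yaml_file)

    # One dict per top-level section; a missing section falls back to {},
    # but a missing key inside a section still raises KeyError on access.
    training_config = config_dict.get("training", {})
    print(training_config["learning_rate"])  # 0.0005 with config.philipp.yaml

    # YAML's "~" (e.g. model_load_path: ~) loads as Python None, which is
    # what the `is not None` checkpoint check in train.py relies on.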
--- config.philipp.yaml | 29 ++++++ config.train.yaml | 28 ++++++ poetry.lock | 51 ++++++++++- pyproject.toml | 1 + requirements.txt | 1 + swr2_asr/__main__.py | 12 --- swr2_asr/inference.py | 16 ++-- swr2_asr/model_deep_speech.py | 17 ---- swr2_asr/train.py | 192 +++++++++++++++++++--------------------- swr2_asr/utils/data.py | 7 +- swr2_asr/utils/tokenizer.py | 8 +- swr2_asr/utils/visualization.py | 8 +- 12 files changed, 218 insertions(+), 152 deletions(-) create mode 100644 config.philipp.yaml create mode 100644 config.train.yaml delete mode 100644 swr2_asr/__main__.py (limited to 'swr2_asr/train.py') diff --git a/config.philipp.yaml b/config.philipp.yaml new file mode 100644 index 0000000..638b5ef --- /dev/null +++ b/config.philipp.yaml @@ -0,0 +1,29 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 0.0005 + batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 1 # evaluate every n epochs + num_workers: 4 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: True # set to True if you want to use limited supervision + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) + shuffle: True + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +checkpoints: + model_load_path: ~ # path to load model from + model_save_path: ~ # path to save model to \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml new file mode 100644 index 0000000..c82439d --- /dev/null +++ b/config.train.yaml @@ -0,0 +1,28 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 3901b8c..a1f916b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1083,6 +1083,55 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = 
"PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash 
= "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "ruff" version = "0.0.285" @@ -1482,4 +1531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b9efbbcd85e7d70578496491d81aa6ef8a610a77ffe134c08446300d5de42ed5" +content-hash = "e45a9c1ba8b67cbe83c4b010c3f4718eee990b064b90a3ccd64380387e734faf" diff --git a/pyproject.toml b/pyproject.toml index 38cc51a..f6d19dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ mido = "^1.3.0" tokenizers = "^0.13.3" click = "^8.1.7" matplotlib = "^3.7.2" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/requirements.txt b/requirements.txt index 3b39b56..040fed0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ platformdirs==3.10.0 pylint==2.17.5 pyparsing==3.0.9 python-dateutil==2.8.2 +PyYAML==6.0.1 ruff==0.0.285 six==1.16.0 sympy==1.12 diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py deleted file mode 100644 index be294fb..0000000 --- a/swr2_asr/__main__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Main entrypoint for swr2-asr.""" -import torch -import torchaudio - -if __name__ == "__main__": - # test if GPU is available - print("GPU available: ", torch.cuda.is_available()) - - # test if torchaudio is installed correctly - print("torchaudio version: ", torchaudio.__version__) - print("torchaudio backend: ", torchaudio.get_audio_backend()) - print("torchaudio info: ", 
torchaudio.get_audio_backend()) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..f8342f7 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" +from typing import TypedDict + import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer class HParams(TypedDict): @@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member blank_label = tokenizer.encode(" ").ids[0] decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for _i, args in enumerate(arg_maxes): decode = [] for j, index in enumerate(args): if index != blank_label: @@ -44,7 +44,7 @@ def main() -> None: """inference function.""" device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) + device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") @@ -90,7 +90,7 @@ def main() -> None: model.load_state_dict(state_dict) # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") + waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member if sample_rate != spectrogram_hparams["sample_rate"]: resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) @@ -103,7 +103,7 @@ def main() -> None: specs = [spec] specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) - output = model(specs) + output = model(specs) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) decodes = greedy_decoder(output, tokenizer) diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index 8ddbd99..77f4c8a 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -3,27 +3,10 @@ Following definition by Assembly AI (https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) """ -from typing import TypedDict - import torch.nn.functional as F from torch import nn -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ac7666b..eb79ee2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,11 +5,12 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split from swr2_asr.utils.decoder import greedy_decoder from swr2_asr.utils.tokenizer import CharTokenizer @@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer from .utils.loss_scores import cer, wer -class IterMeter(object): 
+class IterMeter: """keeps track of total iterations""" def __init__(self): @@ -116,6 +117,7 @@ class TestArgs(TypedDict): def test(test_args: TestArgs) -> tuple[float, float, float]: + """Test""" print("\nevaluating...") # get values from test_args: @@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + for _data in tqdm(test_loader, desc="Validation Batches"): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) @@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths, tokenizer ) - if i == 1: - print(f"decoding first sample: {decoded_preds}") for j, _ in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], decoded_preds[j])) test_wer.append(wer(decoded_targets[j], decoded_preds[j])) @@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: return test_loss, avg_cer, avg_wer -def main( - learning_rate: float, - batch_size: int, - epochs: int, - dataset_path: str, - language: str, - limited_supervision: bool, - model_load_path: str, - model_save_path: str, - dataset_percentage: float, - eval_every: int, - num_workers: int, -): +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +def main(config_path: str): """Main function for training the model. - Args: - learning_rate: learning rate for the optimizer - batch_size: batch size - epochs: number of epochs to train - dataset_path: path for the dataset - language: language of the dataset - limited_supervision: whether to use only limited supervision - model_load_path: path to load a model from - model_save_path: path to save the model to - dataset_percentage: percentage of the dataset to use - eval_every: evaluate every n epochs - num_workers: number of workers for the dataloader + Gets all configuration arguments from yaml config file. 
""" use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = 
optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) prev_epoch = 0 - if model_load_path is not None: - checkpoint = torch.load(model_load_path) + if checkpoints_config["model_load_path"] is not None: + checkpoint = torch.load(checkpoints_config["model_load_path"]) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) - for epoch in range(prev_epoch + 1, epochs + 1): - train_args: TrainArgs = dict( - model=model, - device=device, - train_loader=train_loader, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - iter_meter=iter_meter, - ) + + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = dict( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - decoder="greedy", - ) + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } - if epoch % eval_every == 0: + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: test_loss, test_cer, test_wer = test(test_args) - if model_save_path is None: + if checkpoints_config["model_save_path"] is None: continue - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, @@ -322,7 +314,7 @@ def main( "avg_cer": test_cer, "avg_wer": test_wer, }, - model_save_path + str(epoch), + checkpoints_config["model_save_path"] + str(epoch), ) diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index e939e1d..0e06eec 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -1,13 +1,12 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from typing import TypedDict import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -125,7 +124,7 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): + if limited and split in (Split.TRAIN, Split.VALID): self.initialize_limited() else: self.initialize() @@ -351,8 +350,6 @@ class MLSDataset(Dataset): if __name__ == "__main__": - from torch.utils.data import DataLoader - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" split = Split.DEV diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 5482bbe..22569eb 
100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,8 +1,6 @@ """Tokenizer for Multilingual Librispeech datasets""" - - -from datetime import datetime import os +from datetime import datetime from tqdm.autonotebook import tqdm @@ -119,8 +117,8 @@ class CharTokenizer: line = line.strip() if line: char, index = line.split() - tokenizer.char_map[char] = int(index) - tokenizer.index_map[int(index)] = char + load_tokenizer.char_map[char] = int(index) + load_tokenizer.index_map[int(index)] = char return load_tokenizer diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index 80f942a..a55d0d5 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -6,10 +6,10 @@ import torch def plot(epochs, path): """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() + losses = [] + test_losses = [] + cers = [] + wers = [] for epoch in range(1, epochs + 1): current_state = torch.load(path + str(epoch)) losses.append(current_state["loss"]) -- cgit v1.2.3 From 64dbb9d32a51b1bce6c9de67069dc8f5943a5399 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 22:16:26 +0200 Subject: added n_feats from config --- config.philipp.yaml | 2 +- swr2_asr/train.py | 4 ++-- swr2_asr/utils/data.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/config.philipp.yaml b/config.philipp.yaml index 638b5ef..6b905cd 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -4,7 +4,7 @@ model: rnn_dim: 512 n_feats: 128 # number of mel features stride: 2 - dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets training: learning_rate: 0.0005 diff --git a/swr2_asr/train.py b/swr2_asr/train.py index eb79ee2..ca70d21 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -223,8 +223,8 @@ def main(config_path: str): ) tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) - train_data_processing = DataProcessing("train", tokenizer) - valid_data_processing = DataProcessing("valid", tokenizer) + train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]}) + valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]}) train_loader = DataLoader( dataset=train_dataset, diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 0e06eec..10f0ea8 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -15,18 +15,19 @@ from swr2_asr.utils.tokenizer import CharTokenizer class DataProcessing: """Data processing class for the dataloader""" - def __init__(self, data_type: str, tokenizer: CharTokenizer): + def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict): self.data_type = data_type self.tokenizer = tokenizer + n_features = hparams["n_feats"] if data_type == "train": self.audio_transform = torch.nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features), torchaudio.transforms.FrequencyMasking(freq_mask_param=30), torchaudio.transforms.TimeMasking(time_mask_param=100), ) elif data_type == "valid": - self.audio_transform = torchaudio.transforms.MelSpectrogram() + self.audio_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_features) def __call__(self, data) -> 
tuple[Tensor, Tensor, list, list]: spectrograms = [] -- cgit v1.2.3 From 6f5513140f153206cfa91df3077e67ce58043d35 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 22:58:19 +0200 Subject: model loading is broken :( --- config.philipp.yaml | 9 +++- config.train.yaml | 28 ---------- config.yaml | 34 ++++++++++++ swr2_asr/inference.py | 140 ++++++++++++++++++++++---------------------------- swr2_asr/train.py | 2 +- 5 files changed, 103 insertions(+), 110 deletions(-) delete mode 100644 config.train.yaml create mode 100644 config.yaml (limited to 'swr2_asr/train.py') diff --git a/config.philipp.yaml b/config.philipp.yaml index 6b905cd..4a723c6 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -12,6 +12,7 @@ training: epochs: 3 eval_every_n: 1 # evaluate every n epochs num_workers: 4 # number of workers for dataloader + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically dataset: download: True @@ -25,5 +26,9 @@ tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" checkpoints: - model_load_path: ~ # path to load model from - model_save_path: ~ # path to save model to \ No newline at end of file + model_load_path: "data/runs/epoch30" # path to load model from + model_save_path: ~ # path to save model to + +inference: + model_load_path: "data/runs/epoch30" # path to load model from + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file diff --git a/config.train.yaml b/config.train.yaml deleted file mode 100644 index c82439d..0000000 --- a/config.train.yaml +++ /dev/null @@ -1,28 +0,0 @@ -model: - n_cnn_layers: 3 - n_rnn_layers: 5 - rnn_dim: 512 - n_feats: 128 # number of mel features - stride: 2 - dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 5e-4 - batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 3 # evaluate every n epochs - num_workers: 8 # number of workers for dataloader - -dataset: - download: True - dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" - -checkpoints: - model_load_path: "YOUR/PATH" # path to load model from - model_save_path: "YOUR/PATH" # path to save model to \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..e5ff43a --- /dev/null +++ b/config.yaml @@ -0,0 +1,34 @@ +model: + n_cnn_layers: 3 + n_rnn_layers: 5 + rnn_dim: 512 + n_feats: 128 # number of mel features + stride: 2 + dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets + +training: + learning_rate: 5e-4 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + shuffle: True 
+ +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" + +checkpoints: + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to + +inference: + model_load_path: "YOUR/PATH" # path to load model from + beam_width: 10 # beam width for beam search + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index f8342f7..6495a9a 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,35 +1,20 @@ """Training script for the ASR model.""" -from typing import TypedDict - +import click import torch import torch.nn.functional as F import torchaudio +import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.tokenizer import CharTokenizer -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, collapse_repeated=True): +def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] + blank_label = tokenizer.get_blank_token() decodes = [] - for _i, args in enumerate(arg_maxes): + for args in arg_maxes: decode = [] for j, index in enumerate(args): if index != blank_label: @@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): return decodes -def main() -> None: +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +@click.option( + "--file_path", + help="Path to audio file", + type=click.Path(exists=True), +) +def main(config_path: str, file_path: str) -> None: """inference function.""" - - device = "cuda" if torch.cuda.is_available() else "cpu" + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + inference_config = config_dict.get("inference", {}) + + if inference_config["device"] == "cpu": + device = "cpu" + elif inference_config["device"] == "cuda": + device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # pylint: disable=no-member - tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") - - spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=0.1, - batch_size=30, - epochs=100, - ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + 
model_config["stride"], + model_config["dropout"], ).to(device) - checkpoint = torch.load("model8", map_location=device) - state_dict = { - k[len("module.") :] if k.startswith("module.") else k: v - for k, v in checkpoint["model_state_dict"].items() - } - model.load_state_dict(state_dict) - - # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member - if sample_rate != spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) + checkpoint = torch.load(inference_config["model_load_path"], map_location=device) + print(checkpoint["model_state_dict"].keys()) + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + model.eval() + waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member + if waveform.shape[0] != 1: + waveform = waveform[1] + waveform = waveform.unsqueeze(0) + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) + sample_rate = 16000 + + data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"]) + + spec = data_processing(waveform).squeeze(0).transpose(0, 1) - spec = ( - torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - specs = [spec] - specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) + spec = spec.unsqueeze(0) + spec = spec.transpose(1, 2) + spec = spec.unsqueeze(0) + output = model(spec) # pylint: disable=not-callable + output = F.log_softmax(output, dim=2) # (batch, time, n_class) + decoded_preds = greedy_decoder(output, tokenizer) - output = model(specs) # pylint: disable=not-callable - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - decodes = greedy_decoder(output, tokenizer) - print(decodes) + print(decoded_preds) if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ca70d21..ec25918 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -263,7 +263,7 @@ def main(config_path: str): prev_epoch = 0 if checkpoints_config["model_load_path"] is not None: - checkpoint = torch.load(checkpoints_config["model_load_path"]) + checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] -- cgit v1.2.3 From 4aff1fcd70cd8601541a1dd5bd820b0263ed1362 Mon Sep 17 00:00:00 2001 From: Philipp Merkel Date: Mon, 11 Sep 2023 22:36:28 +0000 Subject: fix: switched up training and test splits in train.py --- config.philipp.yaml | 22 +++++++++++----------- swr2_asr/train.py | 8 +++----- swr2_asr/utils/data.py | 31 ------------------------------- swr2_asr/utils/tokenizer.py | 12 ------------ 4 files changed, 14 insertions(+), 59 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/config.philipp.yaml b/config.philipp.yaml index 4a723c6..f72ce2e 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -4,30 +4,30 @@ model: rnn_dim: 512 n_feats: 128 # number of mel features stride: 2 - dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets + dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets training: 
learning_rate: 0.0005 - batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 1 # evaluate every n epochs + batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 150 + eval_every_n: 5 # evaluate every n epochs num_workers: 4 # number of workers for dataloader device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically dataset: - download: True - dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir + download: true + dataset_root_path: "data" # files will be downloaded into this dir language_name: "mls_german_opus" - limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) - shuffle: True + limited_supervision: false # set to True if you want to use limited supervision + dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%) + shuffle: true tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" checkpoints: - model_load_path: "data/runs/epoch30" # path to load model from - model_save_path: ~ # path to save model to + model_load_path: "data/runs/epoch31" # path to load model from + model_save_path: "data/runs/epoch" # path to save model to inference: model_load_path: "data/runs/epoch30" # path to load model from diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ec25918..3ed3ac8 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,16 +187,14 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - - print(training_config["learning_rate"]) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TEST, + Split.TRAIN, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], @@ -204,7 +202,7 @@ def main(config_path: str): valid_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TRAIN, + Split.TEST, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 10f0ea8..d551c98 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -134,11 +134,6 @@ class MLSDataset(Dataset): def initialize_limited(self) -> None: """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - handles = set() train_root_path = os.path.join(self.dataset_path, self.language, "train") @@ -348,29 +343,3 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.DEV - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) - - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=True, - collate_fn=DataProcessing( - "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - ), - ) - - for batch in dataloader: - print(batch) - 
break - - print(len(dataset)) - - print(dataset[0]) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 22569eb..1cc7b84 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,15 +120,3 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) -- cgit v1.2.3 From 7b71dab87591e04d874cd636614450b0e65e3f2b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Tue, 12 Sep 2023 14:14:19 +0200 Subject: fixed black formatting issue --- swr2_asr/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 3ed3ac8..ffdae73 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,7 +187,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) -- cgit v1.2.3
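A note on the "model loading is broken :(" commit above: the rewritten
inference.py falls back to load_state_dict(..., strict=False), which
silently skips missing or unexpected keys, while the code it deleted
handled DataParallel-style "module." prefixes explicitly. A minimal
sketch that combines the two, assuming the checkpoint layout train.py
saves ("epoch", "model_state_dict", "optimizer_state_dict", ...);
load_checkpoint is a hypothetical helper, not a function in this repo:

    import torch

    def load_checkpoint(model, path, device):
        """Restore a model from a checkpoint written by swr2_asr/train.py."""
        checkpoint = torch.load(path, map_location=device)
        # Checkpoints saved from a DataParallel-wrapped model prefix every
        # parameter name with "module."; strip it before loading.
        state_dict = {
            key.removeprefix("module."): value
            for key, value in checkpoint["model_state_dict"].items()
        }
        # Keep the default strict=True so genuine key or shape mismatches
        # fail loudly instead of being skipped.
        model.load_state_dict(state_dict)
        return checkpoint.get("epoch", 0)

With something like this, inference.py could drop strict=False and still
accept both plain and "module."-prefixed checkpoints, which is what the
prefix-stripping block deleted in that commit used to handle.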