aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.philipp.yaml10
-rw-r--r--pyproject.toml2
-rw-r--r--swr2_asr/train.py32
-rw-r--r--swr2_asr/utils/decoder.py2
-rw-r--r--swr2_asr/utils/tokenizer.py8
5 files changed, 33 insertions, 21 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 38a68f8..608720f 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -3,7 +3,7 @@ dataset:
dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
language_name: "mls_german_opus"
limited_supervision: True # set to True if you want to use limited supervision
- dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%)
+ dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
shuffle: True
model:
@@ -33,13 +33,13 @@ decoder:
training:
learning_rate: 0.0005
batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
- epochs: 3
- eval_every_n: 3 # evaluate every n epochs
+ epochs: 100
+ eval_every_n: 1 # evaluate every n epochs
num_workers: 8 # number of workers for dataloader
checkpoints: # use "~" to disable saving/loading
- model_load_path: "YOUR/PATH" # path to load model from
- model_save_path: "YOUR/PATH" # path to save model to
+ model_load_path: "data/epoch67" # path to load model from
+ model_save_path: ~ # path to save model to
inference:
model_load_path: "data/epoch67" # path to load model from \ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index f6d19dd..2f26e5e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ target-version = "py310"
line-length = 100
[tool.poetry.scripts]
-train = "swr2_asr.train:run_cli"
+train = "swr2_asr.train:main"
train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer"
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9c7ede9..1e57ba0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
-from swr2_asr.utils.decoder import greedy_decoder
-from swr2_asr.utils.tokenizer import CharTokenizer
-
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
+from swr2_asr.utils.tokenizer import CharTokenizer
class IterMeter:
@@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
# get values from test_args:
model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
- if decoder == "greedy":
- decoder = greedy_decoder
-
model.eval()
test_loss = 0
test_cer, test_wer = [], []
@@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths, tokenizer
- )
+ decoded_targets = tokenizer.decode_batch(labels)
+ decoded_preds = decoder(output.transpose(0, 1))
for j, _ in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
- test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+ if j >= len(decoded_targets):
+ break
+ test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0]))
+ test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0]))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
@@ -187,6 +184,7 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
+ decoder_config = config_dict.get("decoder", {})
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
@@ -262,12 +260,19 @@ def main(config_path: str):
if checkpoints_config["model_load_path"] is not None:
checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
- model.load_state_dict(checkpoint["model_state_dict"])
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+
+ model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
+ decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)
+
for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
train_args: TrainArgs = {
"model": model,
@@ -283,14 +288,13 @@ def main(config_path: str):
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
-
test_args: TestArgs = {
"model": model,
"device": device,
"test_loader": valid_loader,
"criterion": criterion,
"tokenizer": tokenizer,
- "decoder": "greedy",
+ "decoder": decoder,
}
if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index 2b6d29b..6ffdef2 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,6 +1,6 @@
"""Decoder for CTC-based ASR.""" ""
-from dataclasses import dataclass
import os
+from dataclasses import dataclass
import torch
from torchaudio.datasets.utils import _extract_tar
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index b1de83e..4e3fddd 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ class CharTokenizer:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
+ i = int(i)
string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
        """Decode a batch of integer label sequences into text sequences."""
+ string = []
+ for label in labels:
+ string.append(self.decode(label))
+ return string
+
def get_vocab_size(self) -> int:
"""Get the number of unique characters in the dataset"""
return len(self.char_map)