diff options
| author    | Pherkel | 2023-08-20 14:52:15 +0200 |
|-----------|---------|---------------------------|
| committer | Pherkel | 2023-08-20 14:52:15 +0200 |
| commit    | 3ae21cbc432113531aa15e0cebd8a34c3767ba35 (patch) | |
| tree      | 51e766e228777bc2e776d49a8b85b480eacf0afc | |
| parent    | 1d7939c80024b7f55b26a3bd297fc4b8d54dc4f6 (diff) | |
added todos
| -rw-r--r-- | swr2_asr/tokenizer.py | 6 |
| -rw-r--r-- | swr2_asr/train.py     | 5 |
2 files changed, 9 insertions, 2 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d9cd622..79d6727 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -26,7 +26,7 @@ class CharTokenizer:
     Simply checks what characters are in the dataset and uses them as tokens.

     Exposes the same interface as tokenizers from the huggingface library, i.e.
-    encode, decode, decode_batch, save, from_file and train.
+    encode, decode, decode_batch, get_vocab_size, save, from_file and train.
     """

     def __init__(self):
@@ -140,6 +140,10 @@ class CharTokenizer:
             strings.append("".join(string).replace("<SPACE>", " "))
         return strings

+    def get_vocab_size(self):
+        """Get the size of the vocabulary"""
+        return len(self.char_map)
+
     def save(self, path: str):
         """Save the tokenizer to a file"""
         with open(path, "w", encoding="utf-8") as file:
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index d13683f..8943f71 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -57,6 +57,7 @@ def data_processing(data, data_type="train"):
 def greedy_decoder(
     output, labels, label_lengths, blank_label=28, collapse_repeated=True
 ):
+    # TODO: adopt to support both tokenizers
     """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
     decodes = []
@@ -77,6 +78,7 @@ def greedy_decoder(
     return decodes, targets


+# TODO: restructure into own file / class
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""

@@ -280,6 +282,7 @@ def train(
     return loss.item()


+# TODO: check how dataloader can be made more efficient
 def test(model, device, test_loader, criterion):
     """Test"""
     print("\nevaluating...")
@@ -327,7 +330,7 @@ def run(
         "n_cnn_layers": 3,
         "n_rnn_layers": 5,
         "rnn_dim": 512,
-        "n_class": 36,
+        "n_class": 36,  # TODO: dynamically determine this from vocab size
         "n_feats": 128,
         "stride": 2,
         "dropout": 0.1,