about summary refs log tree commit diff
diff options
context:
space:
mode:
authorPherkel2023-08-20 14:52:15 +0200
committerPherkel2023-08-20 14:52:15 +0200
commit3ae21cbc432113531aa15e0cebd8a34c3767ba35 (patch)
tree51e766e228777bc2e776d49a8b85b480eacf0afc
parent1d7939c80024b7f55b26a3bd297fc4b8d54dc4f6 (diff)
added todos
-rw-r--r--swr2_asr/tokenizer.py6
-rw-r--r--swr2_asr/train.py5
2 files changed, 9 insertions, 2 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d9cd622..79d6727 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -26,7 +26,7 @@ class CharTokenizer:
Simply checks what characters are in the dataset and uses them as tokens.
Exposes the same interface as tokenizers from the huggingface library, i.e.
- encode, decode, decode_batch, save, from_file and train.
+ encode, decode, decode_batch, get_vocab_size, save, from_file and train.
"""
def __init__(self):
@@ -140,6 +140,10 @@ class CharTokenizer:
strings.append("".join(string).replace("<SPACE>", " "))
return strings
+ def get_vocab_size(self):
+ """Get the size of the vocabulary"""
+ return len(self.char_map)
+
def save(self, path: str):
"""Save the tokenizer to a file"""
with open(path, "w", encoding="utf-8") as file:
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index d13683f..8943f71 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -57,6 +57,7 @@ def data_processing(data, data_type="train"):
def greedy_decoder(
output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
+	# TODO: adapt to support both tokenizers
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
@@ -77,6 +78,7 @@ def greedy_decoder(
return decodes, targets
+# TODO: restructure into own file / class
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
@@ -280,6 +282,7 @@ def train(
return loss.item()
+# TODO: check how dataloader can be made more efficient
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -327,7 +330,7 @@ def run(
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 36,
+ "n_class": 36, # TODO: dynamically determine this from vocab size
"n_feats": 128,
"stride": 2,
"dropout": 0.1,