aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- swr2_asr/train.py 5
1 files changed, 4 insertions, 1 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index d13683f..8943f71 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -57,6 +57,7 @@ def data_processing(data, data_type="train"):
def greedy_decoder(
output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
+ # TODO: adapt to support both tokenizers
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
@@ -77,6 +78,7 @@ def greedy_decoder(
return decodes, targets
+# TODO: restructure into own file / class
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
@@ -280,6 +282,7 @@ def train(
return loss.item()
+# TODO: check how dataloader can be made more efficient
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -327,7 +330,7 @@ def run(
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 36,
+ "n_class": 36, # TODO: dynamically determine this from vocab size
"n_feats": 128,
"stride": 2,
"dropout": 0.1,