diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 5 |
1 file changed, 4 insertions, 1 deletion
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index d13683f..8943f71 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -57,6 +57,7 @@ def data_processing(data, data_type="train"):
 def greedy_decoder(
     output, labels, label_lengths, blank_label=28, collapse_repeated=True
 ):
+    # TODO: adapt to support both tokenizers
     """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
     decodes = []
@@ -77,6 +78,7 @@ def greedy_decoder(
     return decodes, targets
 
 
+# TODO: restructure into own file / class
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
@@ -280,6 +282,7 @@ def train(
     return loss.item()
 
 
+# TODO: check how dataloader can be made more efficient
 def test(model, device, test_loader, criterion):
     """Test"""
     print("\nevaluating...")
@@ -327,7 +330,7 @@ def run(
         "n_cnn_layers": 3,
         "n_rnn_layers": 5,
         "rnn_dim": 512,
-        "n_class": 36,
+        "n_class": 36,  # TODO: dynamically determine this from vocab size
         "n_feats": 128,
         "stride": 2,
         "dropout": 0.1,