aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
authorPherkel2023-09-11 22:16:26 +0200
committerPherkel2023-09-11 22:16:26 +0200
commit64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (patch)
tree74422120bd33ed39814364801493cb03de8029b0 /swr2_asr/train.py
parent58b30927bd870604a4077a8af9ec3cad7b0be21c (diff)
added n_feats from config
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index eb79ee2..ca70d21 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -223,8 +223,8 @@ def main(config_path: str):
)
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
- train_data_processing = DataProcessing("train", tokenizer)
- valid_data_processing = DataProcessing("valid", tokenizer)
+ train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+ valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})
train_loader = DataLoader(
dataset=train_dataset,