aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
authorPhilipp Merkel2023-09-11 22:36:28 +0000
committerPhilipp Merkel2023-09-11 22:36:28 +0000
commit4aff1fcd70cd8601541a1dd5bd820b0263ed1362 (patch)
treefe30e408ad30e25e7ea2891e223240e7316986c0 /swr2_asr/train.py
parent3811dc68de2e2572b3656b8f4460553136eb11b4 (diff)
fix: switched up training and test splits in train.py
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ec25918..3ed3ac8 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -187,16 +187,14 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
-
- print(training_config["learning_rate"])
-
+
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
train_dataset = MLSDataset(
dataset_config["dataset_root_path"],
dataset_config["language_name"],
- Split.TEST,
+ Split.TRAIN,
download=dataset_config["download"],
limited=dataset_config["limited_supervision"],
size=dataset_config["dataset_percentage"],
@@ -204,7 +202,7 @@ def main(config_path: str):
valid_dataset = MLSDataset(
dataset_config["dataset_root_path"],
dataset_config["language_name"],
- Split.TRAIN,
+ Split.TEST,
download=dataset_config["download"],
limited=dataset_config["limited_supervision"],
size=dataset_config["dataset_percentage"],