Diffstat (limited to 'swr2_asr')
-rw-r--r--  swr2_asr/train.py            8
-rw-r--r--  swr2_asr/utils/data.py      31
-rw-r--r--  swr2_asr/utils/tokenizer.py 12
3 files changed, 3 insertions, 48 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ec25918..3ed3ac8 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -187,16 +187,14 @@ def main(config_path: str):
     dataset_config = config_dict.get("dataset", {})
     tokenizer_config = config_dict.get("tokenizer", {})
     checkpoints_config = config_dict.get("checkpoints", {})
-
-    print(training_config["learning_rate"])
-
+
     if not os.path.isdir(dataset_config["dataset_root_path"]):
         os.makedirs(dataset_config["dataset_root_path"])
 
     train_dataset = MLSDataset(
         dataset_config["dataset_root_path"],
         dataset_config["language_name"],
-        Split.TEST,
+        Split.TRAIN,
         download=dataset_config["download"],
         limited=dataset_config["limited_supervision"],
         size=dataset_config["dataset_percentage"],
@@ -204,7 +202,7 @@ def main(config_path: str):
     valid_dataset = MLSDataset(
         dataset_config["dataset_root_path"],
         dataset_config["language_name"],
-        Split.TRAIN,
+        Split.TEST,
         download=dataset_config["download"],
         limited=dataset_config["limited_supervision"],
         size=dataset_config["dataset_percentage"],
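
Both MLSDataset calls above read their arguments from the "dataset" section of the parsed config. A hypothetical sketch of that section, for orientation only (the key names are taken from the lookups above; the values and the dict form are placeholders, not the project's actual config format):

    dataset_config = {
        "dataset_root_path": "data/mls",    # created via os.makedirs if missing
        "language_name": "mls_german_opus",
        "download": True,
        "limited_supervision": False,
        "dataset_percentage": 1.0,
    }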
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 10f0ea8..d551c98 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -134,11 +134,6 @@ class MLSDataset(Dataset):
 
     def initialize_limited(self) -> None:
         """Initializes the limited supervision dataset"""
-        # get file handles
-        # get file paths
-        # get transcripts
-        # create train or validation split
-
         handles = set()
 
         train_root_path = os.path.join(self.dataset_path, self.language, "train")
@@ -348,29 +343,3 @@ class MLSDataset(Dataset):
             dataset_lookup_entry["chapterid"],
             idx,
         ) # type: ignore
-
-
-if __name__ == "__main__":
-    DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
-    LANGUAGE = "mls_german_opus"
-    split = Split.DEV
-    DOWNLOAD = False
-
-    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD)
-
-    dataloader = DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=True,
-        collate_fn=DataProcessing(
-            "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-        ),
-    )
-
-    for batch in dataloader:
-        print(batch)
-        break
-
-    print(len(dataset))
-
-    print(dataset[0])
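
The deleted __main__ block hard-coded a developer's local volume path. If an equivalent smoke test is still wanted, it can live in a standalone script along these lines, mirroring the removed code (import paths are assumed from the repo layout; the dataset root is a placeholder):

    from torch.utils.data import DataLoader

    from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
    from swr2_asr.utils.tokenizer import CharTokenizer

    dataset = MLSDataset("data/mls", "mls_german_opus", Split.DEV, download=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=DataProcessing(
            "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
        ),
    )
    print(next(iter(dataloader)))  # fetch a single batch, then stop
    print(len(dataset))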
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 22569eb..1cc7b84 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -120,15 +120,3 @@ class CharTokenizer:
                 load_tokenizer.char_map[char] = int(index)
                 load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
-
-
-if __name__ == "__main__":
-    tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus")
-    print(tokenizer.char_map)
-    print(tokenizer.index_map)
-    print(tokenizer.get_vocab_size())
-    print(tokenizer.get_blank_token())
-    print(tokenizer.get_unk_token())
-    print(tokenizer.get_space_token())
-    print(tokenizer.encode("hallo welt"))
-    print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
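
The tokenizer's removed smoke test can likewise be reproduced outside the module. A minimal sketch using only the methods the deleted block exercised (the dataset root is a placeholder; train() is assumed to take the same path/language pair as above):

    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.train("data/mls", "mls_german_opus")
    print(tokenizer.get_vocab_size())
    print(tokenizer.encode("hallo welt"))
    print(tokenizer.decode(tokenizer.encode("hallo welt")))  # round-trip check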