diff options
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 8 | ||||
-rw-r--r-- | swr2_asr/utils/data.py | 31 | ||||
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 12 |
3 files changed, 3 insertions, 48 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ec25918..3ed3ac8 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -187,16 +187,14 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) - - print(training_config["learning_rate"]) - + if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TEST, + Split.TRAIN, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], @@ -204,7 +202,7 @@ def main(config_path: str): valid_dataset = MLSDataset( dataset_config["dataset_root_path"], dataset_config["language_name"], - Split.TRAIN, + Split.TEST, download=dataset_config["download"], limited=dataset_config["limited_supervision"], size=dataset_config["dataset_percentage"], diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 10f0ea8..d551c98 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -134,11 +134,6 @@ class MLSDataset(Dataset): def initialize_limited(self) -> None: """Initializes the limited supervision dataset""" - # get file handles - # get file paths - # get transcripts - # create train or validation split - handles = set() train_root_path = os.path.join(self.dataset_path, self.language, "train") @@ -348,29 +343,3 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.DEV - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) - - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=True, - collate_fn=DataProcessing( - "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - ), - ) - - for batch in dataloader: - print(batch) - break - - print(len(dataset)) - - print(dataset[0]) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 22569eb..1cc7b84 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,15 +120,3 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) |