-rw-r--r--  config.philipp.yaml          | 22
-rw-r--r--  swr2_asr/train.py            |  8
-rw-r--r--  swr2_asr/utils/data.py       | 31
-rw-r--r--  swr2_asr/utils/tokenizer.py  | 12
4 files changed, 14 insertions, 59 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 4a723c6..f72ce2e 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -4,30 +4,30 @@ model:
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
 
 training:
   learning_rate: 0.0005
-  batch_size: 2 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
-  epochs: 3
-  eval_every_n: 1 # evaluate every n epochs
+  batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
+  epochs: 150
+  eval_every_n: 5 # evaluate every n epochs
   num_workers: 4 # number of workers for dataloader
   device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
 
 dataset:
-  download: True
-  dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir
+  download: true
+  dataset_root_path: "data" # files will be downloaded into this dir
   language_name: "mls_german_opus"
-  limited_supervision: True # set to True if you want to use limited supervision
-  dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
-  shuffle: True
+  limited_supervision: false # set to True if you want to use limited supervision
+  dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
+  shuffle: true
 
 tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
 
 checkpoints:
-  model_load_path: "data/runs/epoch30" # path to load model from
-  model_save_path: ~ # path to save model to
+  model_load_path: "data/runs/epoch31" # path to load model from
+  model_save_path: "data/runs/epoch" # path to save model to
 
 inference:
   model_load_path: "data/runs/epoch30" # path to load model from
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ec25918..3ed3ac8 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -187,16 +187,14 @@ def main(config_path: str):
     dataset_config = config_dict.get("dataset", {})
     tokenizer_config = config_dict.get("tokenizer", {})
     checkpoints_config = config_dict.get("checkpoints", {})
-
-    print(training_config["learning_rate"])
-
+
     if not os.path.isdir(dataset_config["dataset_root_path"]):
         os.makedirs(dataset_config["dataset_root_path"])
 
     train_dataset = MLSDataset(
         dataset_config["dataset_root_path"],
         dataset_config["language_name"],
-        Split.TEST,
+        Split.TRAIN,
         download=dataset_config["download"],
         limited=dataset_config["limited_supervision"],
         size=dataset_config["dataset_percentage"],
@@ -204,7 +202,7 @@ def main(config_path: str):
     valid_dataset = MLSDataset(
         dataset_config["dataset_root_path"],
         dataset_config["language_name"],
-        Split.TRAIN,
+        Split.TEST,
         download=dataset_config["download"],
         limited=dataset_config["limited_supervision"],
         size=dataset_config["dataset_percentage"],
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 10f0ea8..d551c98 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -134,11 +134,6 @@ class MLSDataset(Dataset):
 
     def initialize_limited(self) -> None:
         """Initializes the limited supervision dataset"""
-        # get file handles
-        # get file paths
-        # get transcripts
-        # create train or validation split
-
         handles = set()
 
         train_root_path = os.path.join(self.dataset_path, self.language, "train")
@@ -348,29 +343,3 @@
                 dataset_lookup_entry["chapterid"],
                 idx,
             )  # type: ignore
-
-
-if __name__ == "__main__":
-    DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
-    LANGUAGE = "mls_german_opus"
-    split = Split.DEV
-    DOWNLOAD = False
-
-    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD)
-
-    dataloader = DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=True,
-        collate_fn=DataProcessing(
-            "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-        ),
-    )
-
-    for batch in dataloader:
-        print(batch)
-        break
-
-    print(len(dataset))
-
-    print(dataset[0])
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 22569eb..1cc7b84 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -120,15 +120,3 @@ class CharTokenizer:
             load_tokenizer.char_map[char] = int(index)
             load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
-
-
-if __name__ == "__main__":
-    tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus")
-    print(tokenizer.char_map)
-    print(tokenizer.index_map)
-    print(tokenizer.get_vocab_size())
-    print(tokenizer.get_blank_token())
-    print(tokenizer.get_unk_token())
-    print(tokenizer.get_space_token())
-    print(tokenizer.encode("hallo welt"))
-    print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))