about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- config.philipp.yaml 22
-rw-r--r-- swr2_asr/train.py 8
-rw-r--r-- swr2_asr/utils/data.py 31
-rw-r--r-- swr2_asr/utils/tokenizer.py 12
4 files changed, 14 insertions, 59 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 4a723c6..f72ce2e 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -4,30 +4,30 @@ model:
rnn_dim: 512
n_feats: 128 # number of mel features
stride: 2
- dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
+ dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
training:
learning_rate: 0.0005
- batch_size: 2 # recommended to be the maximum number that fits on the GPU (a batch size of 32 fits on a 12GB GPU)
- epochs: 3
- eval_every_n: 1 # evaluate every n epochs
+ batch_size: 32 # recommended to be the maximum number that fits on the GPU (a batch size of 32 fits on a 12GB GPU)
+ epochs: 150
+ eval_every_n: 5 # evaluate every n epochs
num_workers: 4 # number of workers for dataloader
device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
dataset:
- download: True
- dataset_root_path: "/Volumes/pherkel 1/SWR2-ASR" # files will be downloaded into this dir
+ download: true
+ dataset_root_path: "data" # files will be downloaded into this dir
language_name: "mls_german_opus"
- limited_supervision: True # set to True if you want to use limited supervision
- dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
- shuffle: True
+ limited_supervision: false # set to True if you want to use limited supervision
+ dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
+ shuffle: true
tokenizer:
tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
checkpoints:
- model_load_path: "data/runs/epoch30" # path to load model from
- model_save_path: ~ # path to save model to
+ model_load_path: "data/runs/epoch31" # path to load model from
+ model_save_path: "data/runs/epoch" # path to save model to
inference:
model_load_path: "data/runs/epoch30" # path to load model from
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ec25918..3ed3ac8 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -187,16 +187,14 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
-
- print(training_config["learning_rate"])
-
+
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
train_dataset = MLSDataset(
dataset_config["dataset_root_path"],
dataset_config["language_name"],
- Split.TEST,
+ Split.TRAIN,
download=dataset_config["download"],
limited=dataset_config["limited_supervision"],
size=dataset_config["dataset_percentage"],
@@ -204,7 +202,7 @@ def main(config_path: str):
valid_dataset = MLSDataset(
dataset_config["dataset_root_path"],
dataset_config["language_name"],
- Split.TRAIN,
+ Split.TEST,
download=dataset_config["download"],
limited=dataset_config["limited_supervision"],
size=dataset_config["dataset_percentage"],
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 10f0ea8..d551c98 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -134,11 +134,6 @@ class MLSDataset(Dataset):
def initialize_limited(self) -> None:
"""Initializes the limited supervision dataset"""
- # get file handles
- # get file paths
- # get transcripts
- # create train or validation split
-
handles = set()
train_root_path = os.path.join(self.dataset_path, self.language, "train")
@@ -348,29 +343,3 @@ class MLSDataset(Dataset):
dataset_lookup_entry["chapterid"],
idx,
) # type: ignore
-
-
-if __name__ == "__main__":
- DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
- LANGUAGE = "mls_german_opus"
- split = Split.DEV
- DOWNLOAD = False
-
- dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD)
-
- dataloader = DataLoader(
- dataset,
- batch_size=1,
- shuffle=True,
- collate_fn=DataProcessing(
- "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
- ),
- )
-
- for batch in dataloader:
- print(batch)
- break
-
- print(len(dataset))
-
- print(dataset[0])
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 22569eb..1cc7b84 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -120,15 +120,3 @@ class CharTokenizer:
load_tokenizer.char_map[char] = int(index)
load_tokenizer.index_map[int(index)] = char
return load_tokenizer
-
-
-if __name__ == "__main__":
- tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus")
- print(tokenizer.char_map)
- print(tokenizer.index_map)
- print(tokenizer.get_vocab_size())
- print(tokenizer.get_blank_token())
- print(tokenizer.get_unk_token())
- print(tokenizer.get_space_token())
- print(tokenizer.encode("hallo welt"))
- print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))