about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--swr2_asr/tokenizer.py28
-rw-r--r--swr2_asr/train.py3
-rw-r--r--swr2_asr/utils.py14
3 files changed, 39 insertions, 6 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index f02d4f5..64227a4 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -187,6 +187,7 @@ class CharTokenizer(TokenizerType):
def save(self, path: str):
"""Save the tokenizer to a file"""
+ os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as file:
# save it in the following format:
# {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
@@ -217,6 +218,23 @@ class CharTokenizer(TokenizerType):
@click.option("--download", default=True, help="Whether to download the dataset")
@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+ vocab_size: int,
+):
+ train_bpe_tokenizer(
+ dataset_path,
+ language,
+ split,
+ out_path,
+ download,
+ vocab_size,
+)
+
def train_bpe_tokenizer(
dataset_path: str,
language: str,
@@ -251,6 +269,7 @@ def train_bpe_tokenizer(
for s_plit in splits:
transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
if download and not os.path.exists(transcripts_path):
+ # TODO: move to own dataset
MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
with open(
@@ -337,6 +356,15 @@ def train_bpe_tokenizer(
@click.option("--split", default="train", help="Split to use")
@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
@click.option("--download", default=True, help="Whether to download the dataset")
+def train_char_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+ ):
+ train_char_tokenizer(dataset_path, language, split, out_path, download)
+
def train_char_tokenizer(
dataset_path: str,
language: str,
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8fc0b78..aea99e0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -173,7 +173,6 @@ def run(
split="all",
download=False,
out_path="data/tokenizers/char_tokenizer_german.json",
- vocab_size=3000,
)
tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
@@ -305,4 +304,4 @@ def run_cli(
if __name__ == "__main__":
- run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
+ run_cli()
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 0330cd2..1c755a6 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -168,22 +168,28 @@ class MLSDataset(Dataset):
"""Sets the tokenizer"""
self.tokenizer = tokenizer
- def _handle_download_dataset(self, download: bool):
+ def _handle_download_dataset(self, download: bool) -> None:
"""Download the dataset"""
if not download:
+ print("Download flag not set, skipping download")
return
# zip exists:
if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
print(f"Found dataset at {self.dataset_path}. Skipping download")
# zip does not exist:
else:
- os.makedirs(self.dataset_path)
+ os.makedirs(self.dataset_path, exist_ok=True)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
- torch.hub.download_url_to_file(url, self.dataset_path)
+ torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz")
# unzip the dataset
- extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz")
+ if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+ print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}")
+ extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
+ else:
+ print("Dataset is already unzipped, validating it now")
+ return
def _validate_local_directory(self):
# check if dataset_path exists