-rw-r--r--  swr2_asr/tokenizer.py  45
-rw-r--r--  swr2_asr/train.py       2
-rw-r--r--  swr2_asr/utils.py      16
3 files changed, 32 insertions(+), 31 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 64227a4..e4df93b 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -90,7 +90,7 @@ class CharTokenizer(TokenizerType):
                 self.char_map[token] = len(self.char_map)
                 self.index_map[len(self.index_map)] = token
 
-    def train(self, dataset_path: str, language: str, split: str, download: bool = True):
+    def train(self, dataset_path: str, language: str, split: str):
         """Train the tokenizer on the given dataset
 
         Args:
@@ -110,10 +110,6 @@ class CharTokenizer(TokenizerType):
         for s_plit in splits:
             transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
 
-            # check if dataset is downloaded, download if not
-            if download and not os.path.exists(transcript_path):
-                MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
-
             with open(
                 transcript_path,
                 "r",
@@ -215,7 +211,6 @@ class CharTokenizer(TokenizerType):
 @click.option("--dataset_path", default="data", help="Path to the MLS dataset")
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--download", default=True, help="Whether to download the dataset")
 @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
 @click.option("--vocab_size", default=2000, help="Size of the vocabulary")
 def train_bpe_tokenizer_cli(
@@ -223,24 +218,23 @@ def train_bpe_tokenizer_cli(
     language: str,
     split: str,
     out_path: str,
-    download: bool,
     vocab_size: int,
 ):
-    train_bpe_tokenizer(
-        dataset_path,
-        language,
-        split,
-        out_path,
-        download,
-        vocab_size,
-)
+    """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+    train_bpe_tokenizer(
+        dataset_path,
+        language,
+        split,
+        out_path,
+        vocab_size,
+    )
+
 
 def train_bpe_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
     vocab_size: int,
 ):
     """Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -268,9 +262,11 @@ def train_bpe_tokenizer(
     for s_plit in splits:
         transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
 
-        if download and not os.path.exists(transcripts_path):
-            # TODO: move to own dataset
-            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+        if not os.path.exists(transcripts_path):
+            raise FileNotFoundError(
+                f"Could not find transcripts.txt in {transcripts_path}. "
+                "Please make sure that the dataset is downloaded."
+            )
 
         with open(
             transcripts_path,
@@ -355,22 +351,21 @@ def train_bpe_tokenizer(
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use")
 @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
-@click.option("--download", default=True, help="Whether to download the dataset")
 def train_char_tokenizer_cli(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
-    ):
-    train_char_tokenizer(dataset_path, language, split, out_path, download)
+):
+    """Train a character-based tokenizer on the MLS dataset"""
+    train_char_tokenizer(dataset_path, language, split, out_path)
+
 
 def train_char_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
 ):
     """Train a character-based tokenizer on the MLS dataset
@@ -386,7 +381,7 @@ def train_char_tokenizer(
     """
     char_tokenizer = CharTokenizer()
-    char_tokenizer.train(dataset_path, language, split, download)
+    char_tokenizer.train(dataset_path, language, split)
     char_tokenizer.save(out_path)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index aea99e0..63deb72 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -304,4 +304,4 @@ def run_cli(
 
 
 if __name__ == "__main__":
-    run_cli()
+    run_cli()  # pylint: disable=no-value-for-parameter
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 1c755a6..8a950ab 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -181,15 +181,21 @@ class MLSDataset(Dataset):
         os.makedirs(self.dataset_path, exist_ok=True)
 
         url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
-        torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz")
+        torch.hub.download_url_to_file(
+            url, os.path.join(self.dataset_path, self.language) + ".tar.gz"
+        )
 
         # unzip the dataset
         if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
-            print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}")
-            extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
+            print(
+                f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
+            )
+            extract_archive(
+                os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
+            )
         else:
-            print("Dataset is already unzipped, validating it now")
-            return
+            print("Dataset is already unzipped, validating it now")
+            return
 
     def _validate_local_directory(self):
         # check if dataset_path exists