diff options
author:    Pherkel <2023-09-11 22:16:26 +0200>
committer: Pherkel <2023-09-11 22:16:26 +0200>
commit:    64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (patch)
tree:      74422120bd33ed39814364801493cb03de8029b0
parent:    58b30927bd870604a4077a8af9ec3cad7b0be21c (diff)
added n_feats from config
-rw-r--r--  config.philipp.yaml     | 2
-rw-r--r--  swr2_asr/train.py       | 4
-rw-r--r--  swr2_asr/utils/data.py  | 7
3 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 638b5ef..6b905cd 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -4,7 +4,7 @@ model:
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
 
 training:
   learning_rate: 0.0005
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index eb79ee2..ca70d21 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -223,8 +223,8 @@ def main(config_path: str):
     )
     tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
 
-    train_data_processing = DataProcessing("train", tokenizer)
-    valid_data_processing = DataProcessing("valid", tokenizer)
+    train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+    valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})
 
     train_loader = DataLoader(
         dataset=train_dataset,
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 0e06eec..10f0ea8 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -15,18 +15,19 @@ from swr2_asr.utils.tokenizer import CharTokenizer
 class DataProcessing:
     """Data processing class for the dataloader"""
 
-    def __init__(self, data_type: str, tokenizer: CharTokenizer):
+    def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict):
         self.data_type = data_type
         self.tokenizer = tokenizer
+        n_features = hparams["n_feats"]
 
         if data_type == "train":
             self.audio_transform = torch.nn.Sequential(
-                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features),
                 torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
                 torchaudio.transforms.TimeMasking(time_mask_param=100),
             )
         elif data_type == "valid":
-            self.audio_transform = torchaudio.transforms.MelSpectrogram()
+            self.audio_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_features)
 
     def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
         spectrograms = []