about summary refs log tree commit diff
diff options
context:
space:
mode:
author Pherkel 2023-09-11 22:16:26 +0200
committer Pherkel 2023-09-11 22:16:26 +0200
commit 64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (patch)
tree 74422120bd33ed39814364801493cb03de8029b0
parent 58b30927bd870604a4077a8af9ec3cad7b0be21c (diff)
added n_feats from config
-rw-r--r-- config.philipp.yaml | 2
-rw-r--r-- swr2_asr/train.py | 4
-rw-r--r-- swr2_asr/utils/data.py | 7
3 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 638b5ef..6b905cd 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -4,7 +4,7 @@ model:
rnn_dim: 512
n_feats: 128 # number of mel features
stride: 2
- dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+ dropout: 0.25 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
training:
learning_rate: 0.0005
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index eb79ee2..ca70d21 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -223,8 +223,8 @@ def main(config_path: str):
)
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
- train_data_processing = DataProcessing("train", tokenizer)
- valid_data_processing = DataProcessing("valid", tokenizer)
+ train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+ valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})
train_loader = DataLoader(
dataset=train_dataset,
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 0e06eec..10f0ea8 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -15,18 +15,19 @@ from swr2_asr.utils.tokenizer import CharTokenizer
class DataProcessing:
"""Data processing class for the dataloader"""
- def __init__(self, data_type: str, tokenizer: CharTokenizer):
+ def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict):
self.data_type = data_type
self.tokenizer = tokenizer
+ n_features = hparams["n_feats"]
if data_type == "train":
self.audio_transform = torch.nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features),
torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
torchaudio.transforms.TimeMasking(time_mask_param=100),
)
elif data_type == "valid":
- self.audio_transform = torchaudio.transforms.MelSpectrogram()
+ self.audio_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_features)
def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
spectrograms = []