about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- .gitignore 7
-rw-r--r-- readme.md 6
-rw-r--r-- swr2_asr/train_2.py 36
-rw-r--r-- tests/__init__.py 0
4 files changed, 30 insertions, 19 deletions
diff --git a/.gitignore b/.gitignore
index 33600ee..8e64e4b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -118,6 +118,8 @@ ipython_config.py
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
+.pdm-python
+.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
@@ -162,14 +164,9 @@ dmypy.json
# Cython debug symbols
cython_debug/
-# linter
-**/.ruff
-
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
-
-data/
diff --git a/readme.md b/readme.md
index 99d741c..47d9a31 100644
--- a/readme.md
+++ b/readme.md
@@ -5,7 +5,7 @@ recogniton 2 (SWR2) in the summer term 2023.
# Installation
```
-pip install -r requirements.txt
+poetry install
```
# Usage
@@ -14,13 +14,13 @@ pip install -r requirements.txt
Train using the provided train script:
- poetry run train --data PATH/TO/DATA --lr 0.01
+ poetry run train
## Evaluation
## Inference
- poetry run recognize --data PATH/TO/FILE
+ poetry run recognize
## CI
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index 2e690e2..b1b597a 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -48,6 +48,20 @@ class TextTransform:
ö 29
ü 30
ß 31
+ - 32
+ é 33
+ è 34
+ à 35
+ ù 36
+ ç 37
+ â 38
+ ê 39
+ î 40
+ ô 41
+ û 42
+ ë 43
+ ï 44
+ ü 45
"""
self.char_map = {}
self.index_map = {}
@@ -93,15 +107,15 @@ def data_processing(data, data_type="train"):
labels = []
input_lengths = []
label_lengths = []
- for waveform, _, utterance, _, _, _ in data:
+ for x in data:
if data_type == "train":
- spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
elif data_type == "valid":
- spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
else:
raise Exception("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
+ label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
@@ -364,12 +378,12 @@ def test(model, device, test_loader, criterion, epoch, iter_meter):
)
-def run(lr: float, batch_size: int, epochs: int) -> None:
+def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
hparams = {
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 33,
+ "n_class": 46,
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
@@ -381,13 +395,13 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu")
- device = torch.device("mps")
+ # device = torch.device("mps")
train_dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=False
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
)
test_dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="test", download=False
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
)
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
@@ -401,7 +415,7 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
)
test_loader = DataLoader(
- train_dataset,
+ test_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
@@ -449,4 +463,4 @@ def run(lr: float, batch_size: int, epochs: int) -> None:
if __name__ == "__main__":
- run(lr=5e-4, batch_size=20, epochs=10)
+ run(lr=5e-4, batch_size=16, epochs=1)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py