aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
author Pherkel 2023-09-11 21:52:42 +0200
committer Pherkel 2023-09-11 21:52:42 +0200
commit 58b30927bd870604a4077a8af9ec3cad7b0be21c (patch)
tree 7dd492fa8f14ff61c88545448972022ead324c31 /swr2_asr/utils
parent 9ca17d8a83369257f4cc42c963e25baf35a28f8f (diff)
changed config to yaml!
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- swr2_asr/utils/data.py 7
-rw-r--r-- swr2_asr/utils/tokenizer.py 8
-rw-r--r-- swr2_asr/utils/visualization.py 8
3 files changed, 9 insertions, 14 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index e939e1d..0e06eec 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -1,13 +1,12 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
-from typing import TypedDict
import numpy as np
import torch
import torchaudio
from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets.utils import _extract_tar
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -125,7 +124,7 @@ class MLSDataset(Dataset):
self._handle_download_dataset(download)
self._validate_local_directory()
- if limited and (split == Split.TRAIN or split == Split.VALID):
+ if limited and split in (Split.TRAIN, Split.VALID):
self.initialize_limited()
else:
self.initialize()
@@ -351,8 +350,6 @@ class MLSDataset(Dataset):
if __name__ == "__main__":
- from torch.utils.data import DataLoader
-
DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
LANGUAGE = "mls_german_opus"
split = Split.DEV
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 5482bbe..22569eb 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
"""Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
import os
+from datetime import datetime
from tqdm.autonotebook import tqdm
@@ -119,8 +117,8 @@ class CharTokenizer:
line = line.strip()
if line:
char, index = line.split()
- tokenizer.char_map[char] = int(index)
- tokenizer.index_map[int(index)] = char
+ load_tokenizer.char_map[char] = int(index)
+ load_tokenizer.index_map[int(index)] = char
return load_tokenizer
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
def plot(epochs, path):
"""Plots the losses over the epochs"""
- losses = list()
- test_losses = list()
- cers = list()
- wers = list()
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
for epoch in range(1, epochs + 1):
current_state = torch.load(path + str(epoch))
losses.append(current_state["loss"])