From 58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Mon, 11 Sep 2023 21:52:42 +0200
Subject: Changed config to YAML!

---
 swr2_asr/utils/data.py          | 7 ++-----
 swr2_asr/utils/tokenizer.py     | 8 +++-----
 swr2_asr/utils/visualization.py | 8 ++++----
 3 files changed, 9 insertions(+), 14 deletions(-)

(limited to 'swr2_asr/utils')

diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index e939e1d..0e06eec 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -1,13 +1,12 @@
 """Class containing utils for the ASR system."""
 import os
 from enum import Enum
-from typing import TypedDict
 
 import numpy as np
 import torch
 import torchaudio
 from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 from torchaudio.datasets.utils import _extract_tar
 
 from swr2_asr.utils.tokenizer import CharTokenizer
@@ -125,7 +124,7 @@ class MLSDataset(Dataset):
 
         self._handle_download_dataset(download)
         self._validate_local_directory()
-        if limited and (split == Split.TRAIN or split == Split.VALID):
+        if limited and split in (Split.TRAIN, Split.VALID):
             self.initialize_limited()
         else:
             self.initialize()
@@ -351,8 +350,6 @@ class MLSDataset(Dataset):
 
 
 if __name__ == "__main__":
-    from torch.utils.data import DataLoader
-
     DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
     LANGUAGE = "mls_german_opus"
     split = Split.DEV
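
Note: the two data.py hunks above hoist the DataLoader import to module level and replace a chained equality check with a membership test. A minimal, self-contained sketch of the resulting pattern (the helper name, the Split member values, and the meaning of the limited flag are illustrative assumptions, not taken from this patch):

    from enum import Enum


    class Split(Enum):
        TRAIN = "train"
        VALID = "valid"
        DEV = "dev"
        TEST = "test"


    def needs_limited_init(limited: bool, split: Split) -> bool:
        # Membership test replaces `split == Split.TRAIN or split == Split.VALID`.
        return limited and split in (Split.TRAIN, Split.VALID)


    # Hypothetical usage:
    assert needs_limited_init(True, Split.TRAIN)
    assert not needs_limited_init(True, Split.DEV)
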
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 5482bbe..22569eb 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
 """Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
 import os
+from datetime import datetime
 
 from tqdm.autonotebook import tqdm
 
@@ -119,8 +117,8 @@ class CharTokenizer:
                 line = line.strip()
                 if line:
                     char, index = line.split()
-                    tokenizer.char_map[char] = int(index)
-                    tokenizer.index_map[int(index)] = char
+                    load_tokenizer.char_map[char] = int(index)
+                    load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
 
 
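Note: the second tokenizer.py hunk is a bug fix. The mappings read from disk are now written into load_tokenizer, the instance the method actually returns, rather than into another name. A minimal sketch of such a loader, assuming a from_file classmethod and the one-pair-per-line file format implied by the hunk; everything not shown in the hunk is illustrative:

    class CharTokenizer:
        """Character tokenizer; only the loading path from the hunk above is sketched."""

        def __init__(self):
            self.char_map: dict[str, int] = {}
            self.index_map: dict[int, str] = {}

        @classmethod
        def from_file(cls, tokenizer_file: str) -> "CharTokenizer":
            # Fill the maps of the instance that is returned; writing into any
            # other name would leave the returned tokenizer empty.
            load_tokenizer = cls()
            with open(tokenizer_file, "r", encoding="utf-8") as file:
                for line in file:
                    line = line.strip()
                    if line:
                        char, index = line.split()
                        load_tokenizer.char_map[char] = int(index)
                        load_tokenizer.index_map[int(index)] = char
            return load_tokenizer
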
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
 
 def plot(epochs, path):
     """Plots the losses over the epochs"""
-    losses = list()
-    test_losses = list()
-    cers = list()
-    wers = list()
+    losses = []
+    test_losses = []
+    cers = []
+    wers = []
     for epoch in range(1, epochs + 1):
         current_state = torch.load(path + str(epoch))
         losses.append(current_state["loss"])
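Note: the visualization.py hunk only swaps list() calls for list literals, and the diff cuts plot() off after the first append. For orientation, a sketch of how such a function might continue; the checkpoint keys other than "loss" and the matplotlib output are assumptions, not taken from this patch:

    import matplotlib.pyplot as plt
    import torch


    def plot(epochs, path):
        """Plots the losses over the epochs"""
        losses = []
        test_losses = []
        cers = []
        wers = []
        for epoch in range(1, epochs + 1):
            current_state = torch.load(path + str(epoch))
            losses.append(current_state["loss"])
            test_losses.append(current_state["test_loss"])  # assumed key
            cers.append(current_state["avg_cer"])            # assumed key
            wers.append(current_state["avg_wer"])            # assumed key

        xs = range(1, epochs + 1)
        plt.plot(xs, losses, label="train loss")
        plt.plot(xs, test_losses, label="test loss")
        plt.xlabel("epoch")
        plt.legend()
        plt.savefig("losses.svg")
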
-- 
cgit v1.2.3