Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--  swr2_asr/utils/__init__.py        0
-rw-r--r--  swr2_asr/utils/data.py          345
-rw-r--r--  swr2_asr/utils/decoder.py        26
-rw-r--r--  swr2_asr/utils/loss_scores.py   203
-rw-r--r--  swr2_asr/utils/tokenizer.py     122
-rw-r--r--  swr2_asr/utils/visualization.py  22
6 files changed, 718 insertions, 0 deletions
diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/swr2_asr/utils/__init__.py
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
new file mode 100644
index 0000000..d551c98
--- /dev/null
+++ b/swr2_asr/utils/data.py
@@ -0,0 +1,345 @@
+"""Class containing utils for the ASR system."""
+import os
+from enum import Enum
+
+import numpy as np
+import torch
+import torchaudio
+from torch import Tensor, nn
+from torch.utils.data import DataLoader, Dataset
+from torchaudio.datasets.utils import _extract_tar
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+class DataProcessing:
+ """Data processing class for the dataloader"""
+
+ def __init__(self, data_type: str, tokenizer: CharTokenizer, hparams: dict):
+ self.data_type = data_type
+ self.tokenizer = tokenizer
+ n_features = hparams["n_feats"]
+
+ if data_type == "train":
+ self.audio_transform = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=n_features),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+ )
+        elif data_type == "valid":
+            self.audio_transform = torchaudio.transforms.MelSpectrogram(
+                sample_rate=16000, n_mels=n_features
+            )
+        else:
+            raise ValueError(f"data_type must be 'train' or 'valid', got '{data_type}'")
+
+ def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+ for waveform, _, utterance, _, _, _ in data:
+ spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1)
+ spectrograms.append(spec)
+ label = torch.Tensor(self.tokenizer.encode(utterance.lower()))
+ labels.append(label)
+ input_lengths.append(spec.shape[0] // 2)
+ label_lengths.append(len(label))
+
+ spectrograms = (
+ nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
+ labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return spectrograms, labels, input_lengths, label_lengths
+
+
+# create enum specifying dataset splits
+class MLSSplit(str, Enum):
+ """Enum specifying dataset as they are defined in the
+ Multilingual LibriSpeech dataset"""
+
+ TRAIN = "train"
+ TEST = "test"
+ DEV = "dev"
+
+
+class Split(str, Enum):
+ """Extending the MLSSplit class to allow for a custom validation split"""
+
+ TRAIN = "train"
+ VALID = "valid"
+ TEST = "test"
+ DEV = "dev"
+
+
+def split_to_mls_split(split_name: Split) -> MLSSplit:
+ """Converts the custom split to a MLSSplit"""
+ if split_name == Split.VALID:
+ return MLSSplit.TRAIN
+ return split_name # type: ignore
+
+
+class MLSDataset(Dataset):
+ """Custom Dataset for reading Multilingual LibriSpeech
+
+ Attributes:
+ dataset_path (str):
+ path to the dataset
+ language (str):
+ language of the dataset
+ split (Split):
+ split of the dataset
+ mls_split (MLSSplit):
+ split of the dataset as defined in the Multilingual LibriSpeech dataset
+ dataset_lookup (list):
+ list of dicts containing the speakerid, bookid, chapterid and utterance
+
+ directory structure:
+ <dataset_path>
+ ├── <language>
+ │ ├── train
+ │ │ ├── transcripts.txt
+ │ │ └── audio
+ │ │ └── <speakerid>
+ │ │ └── <bookid>
+ │ │ └── <speakerid>_<bookid>_<chapterid>.opus / .flac
+
+    each line in transcripts.txt is tab-separated and has the following format:
+    <speakerid>_<bookid>_<chapterid>\t<utterance>
+ """
+
+ def __init__(
+ self,
+ dataset_path: str,
+ language: str,
+ split: Split, # pylint: disable=redefined-outer-name
+ limited: bool = False,
+ download: bool = True,
+ size: float = 0.2,
+ ):
+ """Initializes the dataset"""
+ self.dataset_path = dataset_path
+ self.language = language
+ self.file_ext = ".opus" if "opus" in language else ".flac"
+ self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
+ self.split: Split = split # split used internally
+
+ self.dataset_lookup = []
+
+ self._handle_download_dataset(download)
+ self._validate_local_directory()
+ if limited and split in (Split.TRAIN, Split.VALID):
+ self.initialize_limited()
+ else:
+ self.initialize()
+
+ self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)]
+
+ def initialize_limited(self) -> None:
+ """Initializes the limited supervision dataset"""
+ handles = set()
+
+ train_root_path = os.path.join(self.dataset_path, self.language, "train")
+
+ # get file handles for 9h
+ with open(
+ os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"),
+ "r",
+ encoding="utf-8",
+ ) as file:
+ for line in file:
+ handles.add(line.strip())
+
+        # get file handles for the 1h splits (sub-directories "0" to "5")
+        for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")):
+            if handle_path not in [str(i) for i in range(0, 6)]:
+                continue
+ with open(
+ os.path.join(
+ train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt"
+ ),
+ "r",
+ encoding="utf-8",
+ ) as file:
+ for line in file:
+ handles.add(line.strip())
+
+        # map each handle to its transcript
+        transcripts = {}
+        with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file:
+            for line in file:
+                if not line.strip():
+                    continue
+                handle, utterance = line.strip().split("\t", 1)
+                if handle in handles:
+                    transcripts[handle] = utterance
+
+        # keep only handles that have a transcript, in a deterministic order
+        handles = sorted(handle for handle in handles if handle in transcripts)
+
+        # create disjoint train/valid splits (80/20) with seed 42
+        np.random.seed(42)
+        indices = np.random.permutation(len(handles))
+        split_idx = int(len(handles) * 0.8)
+        if self.split == Split.TRAIN:
+            indices = indices[:split_idx]
+        else:
+            indices = indices[split_idx:]
+        handles = [handles[i] for i in indices]
+
+        # create dataset lookup; each handle is <speakerid>_<bookid>_<chapterid>
+        self.dataset_lookup = [
+            {
+                "speakerid": handle.split("_")[0],
+                "bookid": handle.split("_")[1],
+                "chapterid": handle.split("_")[2],
+                "utterance": transcripts[handle],
+            }
+            for handle in handles
+        ]
+
+ def initialize(self) -> None:
+ """Initializes the entire dataset
+
+ Reads the transcripts.txt file and creates a lookup table
+ """
+ transcripts_path = os.path.join(
+ self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+ )
+
+ with open(transcripts_path, "r", encoding="utf-8") as script_file:
+ # read all lines in transcripts.txt
+ transcripts = script_file.readlines()
+ # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
+ transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore
+ utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore
+ identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
+ identifier = [path.split("_") for path in identifier]
+
+        # create disjoint train/valid splits (80/20) with seed 42
+        if self.split in (Split.TRAIN, Split.VALID):
+            np.random.seed(42)
+            indices = np.random.permutation(len(utterances))
+            split_idx = int(len(utterances) * 0.8)
+            if self.split == Split.TRAIN:
+                indices = indices[:split_idx]
+            else:
+                indices = indices[split_idx:]
+            utterances = [utterances[i] for i in indices]
+            identifier = [identifier[i] for i in indices]
+
+ self.dataset_lookup = [
+ {
+ "speakerid": path[0],
+ "bookid": path[1],
+ "chapterid": path[2],
+ "utterance": utterance,
+ }
+ for path, utterance in zip(identifier, utterances, strict=False)
+ ]
+
+ def _handle_download_dataset(self, download: bool) -> None:
+ """Download the dataset"""
+ if not download:
+ print("Download flag not set, skipping download")
+ return
+        # archive already downloaded:
+        if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz"):
+            print(f"Found archive at {self.dataset_path}. Skipping download")
+        # dataset already extracted:
+        elif os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            return
+        else:
+            os.makedirs(self.dataset_path, exist_ok=True)
+            url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
+
+            torch.hub.download_url_to_file(
+                url, os.path.join(self.dataset_path, self.language) + ".tar.gz"
+            )
+
+        # unpack the archive
+        if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            archive_path = os.path.join(self.dataset_path, self.language) + ".tar.gz"
+            print(f"Unpacking the dataset at {archive_path}")
+            _extract_tar(archive_path, overwrite=True)
+        else:
+            print("Dataset is already unpacked, validating it now")
+            return
+
+ def _validate_local_directory(self):
+ # check if dataset_path exists
+ if not os.path.exists(self.dataset_path):
+ raise ValueError("Dataset path does not exist")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language)):
+ raise ValueError("Language not downloaded!")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
+ raise ValueError("Split not found in dataset")
+
+ def __len__(self):
+ """Returns the length of the dataset"""
+ return len(self.dataset_lookup)
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, int, str, str, str, int]:
+        """One sample
+
+        Returns:
+            Tuple of the following items:
+
+            Tensor:
+                Waveform
+            int:
+                Sample rate
+            str:
+                Transcript
+            str:
+                Speaker ID
+            str:
+                Chapter ID
+            int:
+                Utterance ID
+        """
+ # get the utterance
+ dataset_lookup_entry = self.dataset_lookup[idx]
+
+ utterance = dataset_lookup_entry["utterance"]
+
+ # get the audio file
+ audio_path = os.path.join(
+ self.dataset_path,
+ self.language,
+ self.mls_split,
+ "audio",
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ "_".join(
+ [
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ self.dataset_lookup[idx]["chapterid"],
+ ]
+ )
+ + self.file_ext,
+ )
+
+ waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member
+
+        # resample to 16 kHz if necessary
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+            waveform = resampler(waveform)
+            sample_rate = 16000
+
+ return (
+ waveform,
+ sample_rate,
+ utterance,
+ dataset_lookup_entry["speakerid"],
+ dataset_lookup_entry["chapterid"],
+ idx,
+ ) # type: ignore
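
Usage sketch for the module above: a minimal, hypothetical example of wiring MLSDataset and
DataProcessing into a DataLoader. The dataset path, language name, tokenizer file and the
hparams value (n_feats=128) are assumptions for illustration, not values taken from this
patch; the corpus is assumed to be downloaded and unpacked already.

    from torch.utils.data import DataLoader

    from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
    from swr2_asr.utils.tokenizer import CharTokenizer

    # load a previously trained tokenizer (path is an assumption)
    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

    train_set = MLSDataset("data/mls", "mls_german_opus", Split.TRAIN, download=False, size=0.1)
    train_loader = DataLoader(
        train_set,
        batch_size=8,
        shuffle=True,
        collate_fn=DataProcessing("train", tokenizer, {"n_feats": 128}),
    )

    # one padded batch: spectrograms of shape (B, 1, n_feats, T) plus encoded transcripts
    spectrograms, labels, input_lengths, label_lengths = next(iter(train_loader))
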
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
new file mode 100644
index 0000000..fcddb79
--- /dev/null
+++ b/swr2_asr/utils/decoder.py
@@ -0,0 +1,26 @@
+"""Decoder for CTC-based ASR.""" ""
+import torch
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+# TODO: refactor to use torch CTC decoder class
+def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+ """Greedily decode a sequence."""
+ blank_label = tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(tokenizer.decode(decode))
+ return decodes, targets
+
+
+# TODO: add beam search decoder
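
Usage sketch for greedy_decoder: a hedged example with random stand-in tensors. The tokenizer
file, tensor shapes and label values are assumptions; in the training loop the output comes
from the acoustic model and labels/label_lengths from the dataloader.

    import torch

    from swr2_asr.utils.decoder import greedy_decoder
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    batch, time_steps, n_classes = 4, 200, tokenizer.get_vocab_size()

    output = torch.randn(batch, time_steps, n_classes).log_softmax(2)  # stand-in model output
    labels = torch.randint(4, n_classes, (batch, 50))                  # stand-in padded targets
    label_lengths = [50, 42, 48, 50]

    decodes, targets = greedy_decoder(output, labels, label_lengths, tokenizer)
    print(decodes[0], "|", targets[0])
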
diff --git a/swr2_asr/utils/loss_scores.py b/swr2_asr/utils/loss_scores.py
new file mode 100644
index 0000000..80285f6
--- /dev/null
+++ b/swr2_asr/utils/loss_scores.py
@@ -0,0 +1,203 @@
+"""Methods for determining the loss and scores of the model."""
+import numpy as np
+
+
+def avg_wer(wer_scores, combined_ref_len) -> float:
+ """Calculate the average word error rate (WER).
+
+ Args:
+ wer_scores: word error rate scores
+ combined_ref_len: combined length of reference sentences
+
+ Returns:
+ average word error rate (float)
+
+ Usage:
+ >>> avg_wer([0.5, 0.5], 2)
+ 0.5
+ """
+ return float(sum(wer_scores)) / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp) -> int:
+ """Levenshtein distance.
+
+ Args:
+ ref: reference sentence
+ hyp: hypothesis sentence
+
+ Returns:
+ distance: levenshtein distance between reference and hypothesis
+
+ Usage:
+        >>> _levenshtein_distance("hello", "helo")
+        1
+ """
+ len_ref = len(ref)
+ len_hyp = len(hyp)
+
+ # special case
+ if ref == hyp:
+ return 0
+ if len_ref == 0:
+ return len_hyp
+ if len_hyp == 0:
+ return len_ref
+
+ if len_ref < len_hyp:
+ ref, hyp = hyp, ref
+ len_ref, len_hyp = len_hyp, len_ref
+
+ # use O(min(m, n)) space
+ distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
+
+ # initialize distance matrix
+ for j in range(0, len_hyp + 1):
+ distance[0][j] = j
+
+ # calculate levenshtein distance
+ for i in range(1, len_ref + 1):
+ prev_row_idx = (i - 1) % 2
+ cur_row_idx = i % 2
+ distance[cur_row_idx][0] = i
+ for j in range(1, len_hyp + 1):
+ if ref[i - 1] == hyp[j - 1]:
+ distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+ else:
+ s_num = distance[prev_row_idx][j - 1] + 1
+ i_num = distance[cur_row_idx][j - 1] + 1
+ d_num = distance[prev_row_idx][j] + 1
+ distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+    return int(distance[len_ref % 2][len_hyp])
+
+
+def word_errors(
+ reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+) -> tuple[float, int]:
+ """Compute the levenshtein distance between reference sequence and
+    hypothesis sequence at the word level.
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+        >>> word_errors("hello world", "hello")
+        (1.0, 2)
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ ref_words = reference.split(delimiter)
+ hyp_words = hypothesis.split(delimiter)
+
+ edit_distance = _levenshtein_distance(ref_words, hyp_words)
+ return float(edit_distance), len(ref_words)
+
+
+def char_errors(
+ reference: str,
+ hypothesis: str,
+ ignore_case: bool = False,
+ remove_space: bool = False,
+) -> tuple[float, int]:
+ """Compute the levenshtein distance between reference sequence and
+    hypothesis sequence at the character level.
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+        remove_space: Whether to remove internal space characters
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+        >>> char_errors("hello world", "hello")
+        (6.0, 11)
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ join_char = " "
+ if remove_space:
+ join_char = ""
+
+ reference = join_char.join(filter(None, reference.split(" ")))
+ hypothesis = join_char.join(filter(None, hypothesis.split(" ")))
+
+ edit_distance = _levenshtein_distance(reference, hypothesis)
+ return float(edit_distance), len(reference)
+
+
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float:
+ """Calculate word error rate (WER). WER compares reference text and
+    hypothesis text at the word level.
+ WER is defined as:
+ WER = (Sw + Dw + Iw) / Nw
+ with:
+        Sw is the number of words substituted,
+ Dw is the number of words deleted,
+ Iw is the number of words inserted,
+ Nw is the number of words in the reference
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Word error rate (float)
+
+ Usage:
+ >>> wer("hello world", "hello")
+ 0.5
+ """
+ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
+
+ if ref_len == 0:
+ raise ValueError("Reference's word number should be greater than 0.")
+
+ word_error_rate = float(edit_distance) / ref_len
+ return word_error_rate
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+ """Calculate charactor error rate (CER). CER compares reference text and
+ hypothesis text in char-level. CER is defined as:
+ CER = (Sc + Dc + Ic) / Nc
+ with
+ Sc is the number of characters substituted,
+ Dc is the number of characters deleted,
+ Ic is the number of characters inserted
+ Nc is the number of characters in the reference
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+        remove_space: Whether to remove internal space characters
+
+ Returns:
+ Character error rate (float)
+
+ Usage:
+        >>> cer("hello world", "hello")
+        0.5454545454545454
+ """
+ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space)
+
+ if ref_len == 0:
+ raise ValueError("Length of reference should be greater than 0.")
+
+ char_error_rate = float(edit_distance) / ref_len
+ return char_error_rate
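
Usage sketch for the scoring helpers: a short, made-up example showing sentence-level WER/CER
and corpus-level aggregation with avg_wer. The sentences are illustrative only.

    from swr2_asr.utils.loss_scores import avg_wer, cer, wer, word_errors

    references = ["the cat sat on the mat", "hello world"]
    hypotheses = ["the cat sat on mat", "hello word"]

    total_errors = 0.0
    total_ref_words = 0
    for ref, hyp in zip(references, hypotheses):
        errors, ref_len = word_errors(ref, hyp)
        total_errors += errors
        total_ref_words += ref_len

    print(avg_wer([total_errors], total_ref_words))  # corpus WER: 2 errors / 8 words = 0.25
    print(wer("hello world", "hello word"))          # 1 substitution / 2 words = 0.5
    print(cer("hello world", "hello word"))          # 1 deletion / 11 chars ~= 0.09
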
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
new file mode 100644
index 0000000..1cc7b84
--- /dev/null
+++ b/swr2_asr/utils/tokenizer.py
@@ -0,0 +1,122 @@
+"""Tokenizer for Multilingual Librispeech datasets"""
+import os
+from datetime import datetime
+
+from tqdm.autonotebook import tqdm
+
+
+class CharTokenizer:
+ """Maps characters to integers and vice versa"""
+
+ def __init__(self):
+ self.char_map = {}
+ self.index_map = {}
+
+ def encode(self, text: str) -> list[int]:
+ """Use a character map and convert text to an integer sequence"""
+ int_sequence = []
+ for char in text:
+ if char == " ":
+ char = self.char_map["<SPACE>"]
+ elif char not in self.char_map:
+ char = self.char_map["<UNK>"]
+ else:
+ char = self.char_map[char]
+ int_sequence.append(char)
+ return int_sequence
+
+ def decode(self, labels: list[int]) -> str:
+ """Use a character map and convert integer labels to an text sequence"""
+ string = []
+ for i in labels:
+ string.append(self.index_map[i])
+ return "".join(string).replace("<SPACE>", " ")
+
+ def get_vocab_size(self) -> int:
+ """Get the number of unique characters in the dataset"""
+ return len(self.char_map)
+
+ def get_blank_token(self) -> int:
+ """Get the integer representation of the <BLANK> character"""
+ return self.char_map["<BLANK>"]
+
+ def get_unk_token(self) -> int:
+ """Get the integer representation of the <UNK> character"""
+ return self.char_map["<UNK>"]
+
+ def get_space_token(self) -> int:
+ """Get the integer representation of the <SPACE> character"""
+ return self.char_map["<SPACE>"]
+
+ @staticmethod
+ def train(dataset_path: str, language: str) -> "CharTokenizer":
+ """Train the tokenizer on a dataset"""
+ chars = set()
+ root_path = os.path.join(dataset_path, language)
+ for split in os.listdir(root_path):
+ split_dir = os.path.join(root_path, split)
+ if os.path.isdir(split_dir):
+ transcript_path = os.path.join(split_dir, "transcripts.txt")
+
+                with open(transcript_path, "r", encoding="utf-8") as transcript_file:
+                    # transcripts.txt is tab-separated: <handle>\t<utterance>
+                    lines = transcript_file.readlines()
+                    lines = [line.split("\t", 1)[1] for line in lines]
+                    lines = [line.strip() for line in lines]
+                    lines = [line.lower() for line in lines]
+
+ for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"):
+ chars.update(line)
+
+ # sort chars
+        chars.discard(" ")  # space is represented by the <SPACE> token
+ chars = sorted(chars)
+
+ train_tokenizer = CharTokenizer()
+
+ train_tokenizer.char_map["_"] = 0
+ train_tokenizer.char_map["<BLANK>"] = 1
+ train_tokenizer.char_map["<UNK>"] = 2
+ train_tokenizer.char_map["<SPACE>"] = 3
+
+ train_tokenizer.index_map[0] = "_"
+ train_tokenizer.index_map[1] = "<BLANK>"
+ train_tokenizer.index_map[2] = "<UNK>"
+ train_tokenizer.index_map[3] = "<SPACE>"
+
+ offset = 4
+
+ for idx, char in enumerate(chars):
+ idx += offset
+ train_tokenizer.char_map[char] = idx
+ train_tokenizer.index_map[idx] = char
+
+ train_tokenizer_dir = os.path.join("data/tokenizers")
+ train_tokenizer_path = os.path.join(
+ train_tokenizer_dir,
+ f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json",
+ )
+
+        os.makedirs(train_tokenizer_dir, exist_ok=True)
+ train_tokenizer.save(train_tokenizer_path)
+
+ return train_tokenizer
+
+ def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
+ with open(path, "w", encoding="utf-8") as file:
+ for char, index in self.char_map.items():
+ file.write(f"{char} {index}\n")
+
+ @staticmethod
+ def from_file(tokenizer_file: str) -> "CharTokenizer":
+ """Instantiate a CharTokenizer from a file"""
+ load_tokenizer = CharTokenizer()
+ with open(tokenizer_file, "r", encoding="utf-8") as file:
+ for line in file:
+ line = line.strip()
+ if line:
+ char, index = line.split()
+ load_tokenizer.char_map[char] = int(index)
+ load_tokenizer.index_map[int(index)] = char
+ return load_tokenizer
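
Usage sketch for CharTokenizer: the encode/decode round trip plus save/load. The dataset path,
language name and save path are assumptions; train() also writes a timestamped file under
data/tokenizers/ on its own, as shown above.

    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.train("data/mls", "mls_german_opus")

    ids = tokenizer.encode("hallo welt")
    print(tokenizer.decode(ids))  # "hallo welt"
    print(tokenizer.get_vocab_size(), tokenizer.get_blank_token())

    tokenizer.save("data/tokenizers/char_tokenizer_german.json")
    restored = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    assert restored.encode("hallo welt") == ids
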
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..a55d0d5
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+ """Plots the losses over the epochs"""
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
+ for epoch in range(1, epochs + 1):
+ current_state = torch.load(path + str(epoch))
+ losses.append(current_state["loss"])
+ test_losses.append(current_state["test_loss"])
+ cers.append(current_state["avg_cer"])
+ wers.append(current_state["avg_wer"])
+
+    plt.plot(losses, label="train loss")
+    plt.plot(test_losses, label="test loss")
+    plt.xlabel("epoch")
+    plt.ylabel("loss")
+    plt.legend()
+    plt.savefig("losses.svg")
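
Usage sketch for plot(): it expects one checkpoint file per epoch at path + str(epoch), each a
dict containing the four keys read above. The checkpoint directory and metric values below are
made up for illustration.

    import os

    import torch

    from swr2_asr.utils.visualization import plot

    os.makedirs("checkpoints", exist_ok=True)
    for epoch in range(1, 4):
        state = {
            "loss": 1.0 / epoch,
            "test_loss": 1.2 / epoch,
            "avg_cer": 0.4 / epoch,
            "avg_wer": 0.6 / epoch,
        }
        torch.save(state, "checkpoints/model" + str(epoch))

    plot(3, "checkpoints/model")  # writes losses.svg to the working directory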