From 9dc3bc07424908dd7cf3f052708f506fd58b6e2c Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Mon, 11 Sep 2023 14:49:28 +0200
Subject: refactor utilities (data, vis, tokenizer)

---
 swr2_asr/utils/__init__.py      |   0
 swr2_asr/utils/data.py          | 371 ++++++++++++++++++++++++++++++++++++++++
 swr2_asr/utils/decoder.py       |  26 +++
 swr2_asr/utils/tokenizer.py     | 126 ++++++++++++++
 swr2_asr/utils/visualization.py |  22 +++
 5 files changed, 545 insertions(+)
 create mode 100644 swr2_asr/utils/__init__.py
 create mode 100644 swr2_asr/utils/data.py
 create mode 100644 swr2_asr/utils/decoder.py
 create mode 100644 swr2_asr/utils/tokenizer.py
 create mode 100644 swr2_asr/utils/visualization.py

(limited to 'swr2_asr/utils')

diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
new file mode 100644
index 0000000..93f4a9a
--- /dev/null
+++ b/swr2_asr/utils/data.py
@@ -0,0 +1,371 @@
+"""Class containing utils for the ASR system."""
+import os
+from enum import Enum
+from typing import TypedDict
+
+import numpy as np
+import torch
+import torchaudio
+from torch import Tensor, nn
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import _extract_tar
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+class DataProcessing:
+    """Data processing class for the dataloader"""
+
+    def __init__(self, data_type: str, tokenizer: CharTokenizer):
+        self.data_type = data_type
+        self.tokenizer = tokenizer
+
+        if data_type == "train":
+            self.audio_transform = torch.nn.Sequential(
+                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+                torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+                torchaudio.transforms.TimeMasking(time_mask_param=100),
+            )
+        elif data_type == "valid":
+            # same mel spectrogram as for training, but without SpecAugment masking
+            self.audio_transform = torchaudio.transforms.MelSpectrogram(
+                sample_rate=16000, n_mels=128
+            )
+        else:
+            raise ValueError(f"data_type must be 'train' or 'valid', but got {data_type}")
+
+    def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
+        spectrograms = []
+        labels = []
+        input_lengths = []
+        label_lengths = []
+        for waveform, _, utterance, _, _, _ in data:
+            spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1)
+            spectrograms.append(spec)
+            label = torch.Tensor(self.tokenizer.encode(utterance.lower()))
+            labels.append(label)
+            input_lengths.append(spec.shape[0] // 2)
+            label_lengths.append(len(label))
+
+        spectrograms = (
+            nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+        )
+        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+        return spectrograms, labels, input_lengths, label_lengths
+
+
+# create enum specifying dataset splits
+class MLSSplit(str, Enum):
+    """Enum specifying the dataset splits as they are defined in the
+    Multilingual LibriSpeech dataset"""
+
+    TRAIN = "train"
+    TEST = "test"
+    DEV = "dev"
+
+
+class Split(str, Enum):
+    """Extending the MLSSplit class to allow for a custom validation split"""
+
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+    DEV = "dev"
+
+
+def split_to_mls_split(split_name: Split) -> MLSSplit:
+    """Converts the custom split to an MLSSplit"""
+    if split_name == Split.VALID:
+        return MLSSplit.TRAIN
+    return MLSSplit(split_name.value)
+
+
+class Sample(TypedDict):
+    """Type for a sample in the dataset"""
+
+    waveform: torch.Tensor
+    spectrogram: torch.Tensor
+    input_length: int
+    utterance: torch.Tensor
+    utterance_length: int
+    sample_rate: int
+    speaker_id: str
+    book_id: str
+    chapter_id: str
+
+
+class MLSDataset(Dataset):
+    """Custom Dataset for reading Multilingual LibriSpeech
+
+    Attributes:
+        dataset_path (str):
+            path to the dataset
+        language (str):
+            language of the dataset
+        split (Split):
+            split of the dataset
+        mls_split (MLSSplit):
+            split of the dataset as defined in the Multilingual LibriSpeech dataset
+        dataset_lookup (list):
+            list of dicts containing the speakerid, bookid, chapterid and utterance
+
+    directory structure:
+        <dataset_path>
+        ├── <language>
+        │  ├── train
+        │  │  ├── transcripts.txt
+        │  │  └── audio
+        │  │     └── <speakerid>
+        │  │        └── <bookid>
+        │  │           └── <speakerid>_<bookid>_<chapterid>.opus / .flac
+
+        each line in transcripts.txt has the following format:
+        <speakerid>_<bookid>_<chapterid> <utterance>
+    """
+
+    def __init__(
+        self,
+        dataset_path: str,
+        language: str,
+        split: Split,
+        limited: bool,
+        download: bool,
+        size: float = 0.2,
+    ):
+        """Initializes the dataset"""
+        self.dataset_path = dataset_path
+        self.language = language
+        self.file_ext = ".opus" if "opus" in language else ".flac"
+        self.mls_split: MLSSplit = split_to_mls_split(split)  # split path on disk
+        self.split: Split = split  # split used internally
+
+        self.dataset_lookup = []
+
+        self._handle_download_dataset(download)
+        self._validate_local_directory()
+        if limited and (split == Split.TRAIN or split == Split.VALID):
+            self.initialize_limited()
+        else:
+            self.initialize()
+
+        self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)]
+
+    def initialize_limited(self) -> None:
+        """Initializes the limited supervision dataset"""
+        # get file handles
+        # get file paths
+        # get transcripts
+        # create train or validation split
+
+        handles = set()
+
+        train_root_path = os.path.join(self.dataset_path, self.language, "train")
+
+        # get file handles for 9h
+        with open(
+            os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"),
+            "r",
+            encoding="utf-8",
+        ) as file:
+            for line in file:
+                handles.add(line.strip())
+
+        # get file handles for 1h splits
+        for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")):
+            # the 1h split is organized into subfolders "0" through "5"
+            if handle_path not in [str(i) for i in range(0, 6)]:
+                continue
+            with open(
+                os.path.join(
+                    train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt"
+                ),
+                "r",
+                encoding="utf-8",
+            ) as file:
+                for line in file:
+                    handles.add(line.strip())
+
+        # get transcripts, keyed by handle, so they can be paired with the file paths
+        transcripts_by_handle = {}
+        with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file:
+            for line in file:
+                handle, _, utterance = line.strip().partition("\t")
+                if handle in handles:
+                    transcripts_by_handle[handle] = utterance
+
+        # get file paths and transcripts for handles (sorted for a reproducible order)
+        file_paths = []
+        transcripts = []
+        for handle in sorted(handles):
+            if handle not in transcripts_by_handle:
+                continue
+            speaker_id, book_id, _ = handle.split("_")
+            file_paths.append(
+                os.path.join(
+                    train_root_path, "audio", speaker_id, book_id, handle + self.file_ext
+                )
+            )
+            transcripts.append(transcripts_by_handle[handle])
+
+        # create train or valid split randomly with seed 42 (sampling without replacement)
+        if self.split == Split.TRAIN:
+            np.random.seed(42)
+            indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8), replace=False)
+            file_paths = [file_paths[i] for i in indices]
+            transcripts = [transcripts[i] for i in indices]
+        elif self.split == Split.VALID:
+            np.random.seed(42)
+            indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2), replace=False)
+            file_paths = [file_paths[i] for i in indices]
+            transcripts = [transcripts[i] for i in indices]
+
+        # create dataset lookup
+        self.dataset_lookup = [
+            {
+                "speakerid": path.split("/")[-3],
+                "bookid": path.split("/")[-2],
+                "chapterid": path.split("/")[-1].split("_")[2].split(".")[0],
+                "utterance": utterance,
+            }
+            for path, utterance in zip(file_paths, transcripts)
+        ]
+
+    def initialize(self) -> None:
+        """Initializes the entire dataset
+
+        Reads the transcripts.txt file and creates a lookup table
+        """
+        transcripts_path = os.path.join(
+            self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+        )
+
+        with open(transcripts_path, "r", encoding="utf-8") as script_file:
+            # read all lines in transcripts.txt
+            transcripts = script_file.readlines()
+            # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
+            transcripts = [line.strip().split("\t", 1) for line in transcripts]  # type: ignore
+            utterances = [utterance.strip() for _, utterance in transcripts]  # type: ignore
+            identifier = [identifier.strip() for identifier, _ in transcripts]  # type: ignore
+            identifier = [path.split("_") for path in identifier]
+
+            if self.split == Split.VALID:
+                np.random.seed(42)
+                indices = np.random.choice(
+                    len(utterances), int(len(utterances) * 0.2), replace=False
+                )
+                utterances = [utterances[i] for i in indices]
+                identifier = [identifier[i] for i in indices]
+            elif self.split == Split.TRAIN:
+                np.random.seed(42)
+                indices = np.random.choice(
+                    len(utterances), int(len(utterances) * 0.8), replace=False
+                )
+                utterances = [utterances[i] for i in indices]
+                identifier = [identifier[i] for i in indices]
+
+            self.dataset_lookup = [
+                {
+                    "speakerid": path[0],
+                    "bookid": path[1],
+                    "chapterid": path[2],
+                    "utterance": utterance,
+                }
+                for path, utterance in zip(identifier, utterances, strict=False)
+            ]
+
+    def _handle_download_dataset(self, download: bool) -> None:
+        """Download the dataset"""
+        if not download:
+            print("Download flag not set, skipping download")
+            return
+        # archive already downloaded:
+        if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz"):
+            print(f"Found dataset archive at {self.dataset_path}, skipping download")
+        # dataset already extracted:
+        elif os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            return
+        else:
+            os.makedirs(self.dataset_path, exist_ok=True)
+            url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
+
+            torch.hub.download_url_to_file(
+                url, os.path.join(self.dataset_path, self.language) + ".tar.gz"
+            )
+
+        # unzip the dataset
+        tar_path = os.path.join(self.dataset_path, self.language) + ".tar.gz"
+        if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            print(f"Unzipping the dataset at {tar_path}")
+            _extract_tar(tar_path, overwrite=True)
+        else:
+            print("Dataset is already unzipped, validating it now")
+            return
+
+    def _validate_local_directory(self):
+        # check if dataset_path exists
+        if not os.path.exists(self.dataset_path):
+            raise ValueError("Dataset path does not exist")
+        if not os.path.exists(os.path.join(self.dataset_path, self.language)):
+            raise ValueError("Language not downloaded!")
+        if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
+            raise ValueError("Split not found in dataset")
+
+    def __len__(self):
+        """Returns the length of the dataset"""
+        return len(self.dataset_lookup)
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, int, str, str, str, int]:
+        """One sample
+
+        Returns:
+            Tuple of the following items:
+
+            Tensor:
+                Waveform
+            int:
+                Sample rate
+            str:
+                Transcript
+            str:
+                Speaker ID
+            str:
+                Chapter ID
+            int:
+                Index of the sample in the dataset
+        """
+        # get the utterance
+        dataset_lookup_entry = self.dataset_lookup[idx]
+
+        utterance = dataset_lookup_entry["utterance"]
+
+        # get the audio file
+        audio_path = os.path.join(
+            self.dataset_path,
+            self.language,
+            self.mls_split,
+            "audio",
+            self.dataset_lookup[idx]["speakerid"],
+            self.dataset_lookup[idx]["bookid"],
+            "_".join(
+                [
+                    self.dataset_lookup[idx]["speakerid"],
+                    self.dataset_lookup[idx]["bookid"],
+                    self.dataset_lookup[idx]["chapterid"],
+                ]
+            )
+            + self.file_ext,
+        )
+
+        waveform, sample_rate = torchaudio.load(audio_path)  # pylint: disable=no-member
+
+        # resample to 16 kHz if necessary
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+            waveform = resampler(waveform)
+            sample_rate = 16000
+
+        return (
+            waveform,
+            sample_rate,
+            utterance,
+            dataset_lookup_entry["speakerid"],
+            dataset_lookup_entry["chapterid"],
+            idx,
+        )  # type: ignore
+
+
+if __name__ == "__main__":
+    DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+    LANGUAGE = "mls_german_opus"
+    SPLIT = Split.TRAIN
+    DOWNLOAD = False
+
+    dataset = MLSDataset(DATASET_PATH, LANGUAGE, SPLIT, limited=False, download=DOWNLOAD)
+    print(f"Dataset length: {len(dataset)}")
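+
+    # Sketch (assumption, not part of the original demo): wiring the dataset and
+    # DataProcessing into a DataLoader, the way a training loop would consume
+    # padded batches. Batch size and shuffling are illustrative choices.
+    from torch.utils.data import DataLoader
+
+    data_processing = DataProcessing("train", CharTokenizer())
+    train_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=data_processing)
+
+    spectrograms, encoded_labels, input_lengths, label_lengths = next(iter(train_loader))
+    # spectrograms: (batch, 1, n_mels, time), encoded_labels: (batch, max_label_length)
+    print(spectrograms.shape, encoded_labels.shape, input_lengths, label_lengths)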
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
new file mode 100644
index 0000000..fcddb79
--- /dev/null
+++ b/swr2_asr/utils/decoder.py
@@ -0,0 +1,26 @@
+"""Decoder for CTC-based ASR.""" ""
+import torch
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+# TODO: refactor to use torch CTC decoder class
+def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+    """Greedily decode a sequence."""
+    blank_label = tokenizer.get_blank_token()
+    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+    decodes = []
+    targets = []
+    for i, args in enumerate(arg_maxes):
+        decode = []
+        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+        for j, index in enumerate(args):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == args[j - 1]:
+                    continue
+                decode.append(index.item())
+        decodes.append(tokenizer.decode(decode))
+    return decodes, targets
+
+
+# TODO: add beam search decoder
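+
+
+if __name__ == "__main__":
+    # Minimal sketch of how greedy_decoder is expected to be called; the random
+    # emissions and dummy labels below are illustrative assumptions only.
+    tokenizer = CharTokenizer()
+    batch_size, time_steps, vocab_size = 2, 50, tokenizer.get_vocab_size()
+
+    # model output: (batch, time, vocab) log-probabilities
+    output = torch.randn(batch_size, time_steps, vocab_size).log_softmax(dim=2)
+
+    # dummy targets: padded label sequences and their true lengths
+    labels = torch.randint(4, vocab_size, (batch_size, 10))
+    label_lengths = [10, 7]
+
+    decoded, targets = greedy_decoder(output, labels, label_lengths, tokenizer)
+    print("decoded:", decoded)
+    print("targets:", targets)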
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
new file mode 100644
index 0000000..d92465a
--- /dev/null
+++ b/swr2_asr/utils/tokenizer.py
@@ -0,0 +1,126 @@
+"""Tokenizer for Multilingual Librispeech datasets"""
+
+
+class CharTokenizer:
+    """Maps characters to integers and vice versa"""
+
+    def __init__(self):
+        char_map_str = """
+        _
+        <BLANK>
+        <UNK>
+        <SPACE>
+        a
+        b
+        c
+        d
+        e
+        f
+        g
+        h
+        i
+        j
+        k
+        l
+        m
+        n
+        o
+        p
+        q
+        r
+        s
+        t
+        u
+        v
+        w
+        x
+        y
+        z
+        é
+        à
+        ä
+        ö
+        ß
+        ü
+        -
+        '
+        """
+
+        self.char_map = {}
+        self.index_map = {}
+        for idx, char in enumerate(char_map_str.strip().split("\n")):
+            char = char.strip()
+            self.char_map[char] = idx
+            self.index_map[idx] = char
+        self.index_map[self.char_map["<SPACE>"]] = " "
+
+    def encode(self, text: str) -> list[int]:
+        """Use a character map and convert text to an integer sequence"""
+        int_sequence = []
+        for char in text:
+            if char == " ":
+                char = self.char_map["<SPACE>"]
+            elif char not in self.char_map:
+                char = self.char_map["<UNK>"]
+            else:
+                char = self.char_map[char]
+            int_sequence.append(char)
+        return int_sequence
+
+    def decode(self, labels: list[int]) -> str:
+        """Use a character map and convert integer labels to an text sequence"""
+        string = []
+        for i in labels:
+            string.append(self.index_map[i])
+        return "".join(string).replace("<SPACE>", " ")
+
+    def get_vocab_size(self) -> int:
+        """Get the number of unique characters in the dataset"""
+        return len(self.char_map)
+
+    def get_blank_token(self) -> int:
+        """Get the integer representation of the <BLANK> character"""
+        return self.char_map["<BLANK>"]
+
+    def get_unk_token(self) -> int:
+        """Get the integer representation of the <UNK> character"""
+        return self.char_map["<UNK>"]
+
+    def get_space_token(self) -> int:
+        """Get the integer representation of the <SPACE> character"""
+        return self.char_map["<SPACE>"]
+
+    # TODO: add train function
+
+    def save(self, path: str) -> None:
+        """Save the tokenizer to a file"""
+        with open(path, "w", encoding="utf-8") as file:
+            for char, index in self.char_map.items():
+                file.write(f"{char} {index}\n")
+
+    @staticmethod
+    def from_file(tokenizer_file: str) -> "CharTokenizer":
+        """Instantiate a CharTokenizer from a file"""
+        load_tokenizer = CharTokenizer()
+        with open(tokenizer_file, "r", encoding="utf-8") as file:
+            for line in file:
+                line = line.strip()
+                if line:
+                    char, index = line.split()
+                    load_tokenizer.char_map[char] = int(index)
+                    load_tokenizer.index_map[int(index)] = char
+        return load_tokenizer
+
+
+if __name__ == "__main__":
+    import os
+
+    tokenizer = CharTokenizer()
+    os.makedirs("data/tokenizers", exist_ok=True)
+    tokenizer.save("data/tokenizers/char_tokenizer_german.json")
+    print(tokenizer.char_map)
+    print(tokenizer.index_map)
+    print(tokenizer.get_vocab_size())
+    print(tokenizer.get_blank_token())
+    print(tokenizer.get_unk_token())
+    print(tokenizer.get_space_token())
+    print(tokenizer.encode("hallo welt"))
+    print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
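+
+    # Round-trip sketch: reload the vocabulary written by save() above and check
+    # that the character map survives serialization unchanged.
+    loaded_tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+    assert loaded_tokenizer.char_map == tokenizer.char_map
+    print(loaded_tokenizer.decode(loaded_tokenizer.encode("hallo welt")))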
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..80f942a
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+    """Plots the losses over the epochs"""
+    losses = []
+    test_losses = []
+    cers = []
+    wers = []
+    for epoch in range(1, epochs + 1):
+        current_state = torch.load(path + str(epoch), map_location="cpu")
+        losses.append(current_state["loss"])
+        test_losses.append(current_state["test_loss"])
+        cers.append(current_state["avg_cer"])
+        wers.append(current_state["avg_wer"])
+
+    plt.plot(losses, label="train loss")
+    plt.plot(test_losses, label="test loss")
+    plt.legend()
+    plt.savefig("losses.svg")
+
+    # error rates on a separate figure (output filename is an assumption)
+    plt.clf()
+    plt.plot(cers, label="avg CER")
+    plt.plot(wers, label="avg WER")
+    plt.legend()
+    plt.savefig("metrics.svg")
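+
+
+if __name__ == "__main__":
+    # Hypothetical usage: assumes checkpoints are named <path><epoch>, e.g.
+    # "checkpoints/model_epoch_3", and contain the keys read in plot() above.
+    plot(epochs=10, path="checkpoints/model_epoch_")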
-- 
cgit v1.2.3