From 0a629be04f7a27c531b671921e1a445de34895b4 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 20 Aug 2023 13:11:58 +0200
Subject: added tokenizer training

---
 pyproject.toml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 8490aa5..b7e6ffb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,14 +15,18 @@ audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
 tqdm = "^4.66.1"
 numpy = "^1.25.2"
 mido = "^1.3.0"
+tokenizers = "^0.13.3"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.7.0"
 mypy = "^1.5.1"
 pylint = "^2.17.5"
+ruff = "^0.0.285"
 
 [tool.poetry.scripts]
 train = "swr2_asr.train:run_cli"
+train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
+train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer"
 
 [build-system]
 requires = ["poetry-core"]
-- 
cgit v1.2.3
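
The two new [tool.poetry.scripts] entries map console commands onto
module:callable paths, so `poetry run train-bpe-tokenizer` imports
swr2_asr.tokenizer and calls train_bpe_tokenizer. A minimal sketch of such an
entry point, assuming a click-style CLI (the repo already ships mypy
overrides for click; the option names echo the dataset_path/language
parameters used later in tokenizer.py, but required flags and help texts are
illustrative):

    import click


    @click.command()
    @click.option("--dataset_path", required=True, help="Root of the MLS dataset.")
    @click.option("--language", required=True, help="Dataset subset to read transcripts from.")
    def train_bpe_tokenizer(dataset_path: str, language: str) -> None:
        """Collect transcripts and train a BPE tokenizer (body elided)."""
        ...


    if __name__ == "__main__":
        train_bpe_tokenizer()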


From 33744072d8c0906950cdc9cd00fc1f345a51d9d4 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 20 Aug 2023 15:23:02 +0200
Subject: please the linters

---
 .github/workflows/format.yml | 38 +++++++++++++++++++++-----------------
 Makefile                     |  3 +++
 mypy.ini                     |  6 ++++++
 poetry.lock                  | 25 +++++++++++++++++++++----
 pyproject.toml               |  1 +
 swr2_asr/loss_scores.py      | 36 +++++++++++++++++++-----------------
 swr2_asr/tokenizer.py        | 18 ++++++++++--------
 swr2_asr/train.py            | 12 ++++++------
 8 files changed, 87 insertions(+), 52 deletions(-)

diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 4a5a509..8411d60 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -6,20 +6,24 @@ jobs:
   build:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@master
-    - name: Set up Python
-      uses: actions/setup-python@v3
-      with:
-        python-version: "3.10"
-    - name: Install dependencies
-      run: |
-        python -m pip install -U pip poetry
-        poetry --version
-        poetry check --no-interaction
-        poetry config virtualenvs.in-project true
-        poetry install --no-interaction
-    - name: Run CI
-      run: |
-        make lint
-
-        
+      - uses: actions/checkout@master
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install -U pip poetry
+          poetry --version
+          poetry check --no-interaction
+          poetry config virtualenvs.in-project true
+          poetry install --no-interaction
+      - name: Check for format issues
+        run: |
+          make format-check
+      - name: Run pylint
+        run: |
+          poetry run pylint swr2_asr
+      - name: Run mypy
+        run: |
+          poetry run mypy --strict swr2_asr
diff --git a/Makefile b/Makefile
index 4f0ea9c..a37644c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,9 @@
 format:
 	@poetry run black .
 
+format-check:
+	@poetry run black --check .
+
 lint:
 	@poetry run mypy --strict swr2_asr
 	@poetry run pylint swr2_asr
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index f7cfc59..c13aa05 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -9,3 +9,9 @@ ignore_missing_imports = true
 
 [mypy-click.*]
 ignore_missing_imports = true
+
+[mypy-tokenizers.*]
+ignore_missing_imports = true
+
+[mypy-tqdm.*]
+ignore_missing_imports = true
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 49d37d1..1f3609a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -14,7 +14,10 @@ files = [
 [package.dependencies]
 lazy-object-proxy = ">=1.4.0"
 typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
-wrapt = {version = ">=1.11,<2", markers = "python_version < \"3.11\""}
+wrapt = [
+    {version = ">=1.11,<2", markers = "python_version < \"3.11\""},
+    {version = ">=1.14,<2", markers = "python_version >= \"3.11\""},
+]
 
 [[package]]
 name = "AudioLoader"
@@ -680,7 +683,10 @@ files = [
 [package.dependencies]
 astroid = ">=2.15.6,<=2.17.0-dev0"
 colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
-dill = {version = ">=0.2", markers = "python_version < \"3.11\""}
+dill = [
+    {version = ">=0.2", markers = "python_version < \"3.11\""},
+    {version = ">=0.3.6", markers = "python_version >= \"3.11\""},
+]
 isort = ">=4.2.5,<6"
 mccabe = ">=0.6,<0.8"
 platformdirs = ">=2.2.0"
@@ -967,6 +973,17 @@ torch = "*"
 tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
 tutorials = ["matplotlib", "pandas", "tabulate"]
 
+[[package]]
+name = "types-tqdm"
+version = "4.66.0.1"
+description = "Typing stubs for tqdm"
+optional = false
+python-versions = "*"
+files = [
+    {file = "types-tqdm-4.66.0.1.tar.gz", hash = "sha256:6457c90f03cc5a0fe8dd11839c8cbf5572bf542b438b1af74233801728b5dfbc"},
+    {file = "types_tqdm-4.66.0.1-py3-none-any.whl", hash = "sha256:6a1516788cbb33d725803439b79c25bfed7e8176b8d782020b5c24aedac1649b"},
+]
+
 [[package]]
 name = "typing-extensions"
 version = "4.7.1"
@@ -1078,5 +1095,5 @@ files = [
 
 [metadata]
 lock-version = "2.0"
-python-versions = "~3.10"
-content-hash = "a72b4e5791a6216b58b53a72bf68d97dbdbc95978b3974fddd9e5f9b76e36321"
+python-versions = "^3.10"
+content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2"
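
The poetry.lock churn above is plain PEP 508 environment markers: relaxing
python-versions from "~3.10" to "^3.10" admits 3.11 interpreters, so wrapt
and dill each grow a second, 3.11-specific constraint. Such markers can be
evaluated by hand with the packaging library (illustrative snippet, not part
of the repo):

    from packaging.markers import Marker

    # Evaluates against the currently running interpreter's environment.
    marker = Marker('python_version >= "3.11"')
    print(marker.evaluate())  # True on 3.11+, False on 3.10
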
diff --git a/pyproject.toml b/pyproject.toml
index eb17479..fabe364 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ black = "^23.7.0"
 mypy = "^1.5.1"
 pylint = "^2.17.5"
 ruff = "^0.0.285"
+types-tqdm = "^4.66.0.1"
 
 [tool.poetry.scripts]
 train = "swr2_asr.train:run_cli"
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index 977462d..c49cc15 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -1,7 +1,9 @@
+"""Methods for determining the loss and scores of the model."""
 import numpy as np
 
 
 def avg_wer(wer_scores, combined_ref_len):
+    """Calculate the average word error rate (WER) of the model."""
     return float(sum(wer_scores)) / float(combined_ref_len)
 
 
@@ -13,34 +15,34 @@ def _levenshtein_distance(ref, hyp):
     extend the edits to word level when calculating the Levenshtein distance
     for two sentences.
     """
-    m = len(ref)
-    n = len(hyp)
+    len_ref = len(ref)
+    len_hyp = len(hyp)
 
     # special case
     if ref == hyp:
         return 0
-    if m == 0:
-        return n
-    if n == 0:
-        return m
+    if len_ref == 0:
+        return len_hyp
+    if len_hyp == 0:
+        return len_ref
 
-    if m < n:
+    if len_ref < len_hyp:
         ref, hyp = hyp, ref
-        m, n = n, m
+        len_ref, len_hyp = len_hyp, len_ref
 
     # use O(min(m, n)) space
-    distance = np.zeros((2, n + 1), dtype=np.int32)
+    distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
 
     # initialize distance matrix
-    for j in range(0, n + 1):
+    for j in range(0, len_hyp + 1):
         distance[0][j] = j
 
     # calculate levenshtein distance
-    for i in range(1, m + 1):
+    for i in range(1, len_ref + 1):
         prev_row_idx = (i - 1) % 2
         cur_row_idx = i % 2
         distance[cur_row_idx][0] = i
-        for j in range(1, n + 1):
+        for j in range(1, len_hyp + 1):
             if ref[i - 1] == hyp[j - 1]:
                 distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
             else:
@@ -49,7 +51,7 @@ def _levenshtein_distance(ref, hyp):
                 d_num = distance[prev_row_idx][j] + 1
                 distance[cur_row_idx][j] = min(s_num, i_num, d_num)
 
-    return distance[m % 2][n]
+    return distance[len_ref % 2][len_hyp]
 
 
 def word_errors(
@@ -143,8 +145,8 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
     if ref_len == 0:
         raise ValueError("Reference's word number should be greater than 0.")
 
-    wer = float(edit_distance) / ref_len
-    return wer
+    word_error_rate = float(edit_distance) / ref_len
+    return word_error_rate
 
 
 def cer(reference, hypothesis, ignore_case=False, remove_space=False):
@@ -181,5 +183,5 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
     if ref_len == 0:
         raise ValueError("Length of reference should be greater than 0.")
 
-    cer = float(edit_distance) / ref_len
-    return cer
+    char_error_rate = float(edit_distance) / ref_len
+    return char_error_rate
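
The loss_scores.py hunks are renames and docstrings; the algorithm itself is
untouched. _levenshtein_distance still keeps only two rows of the DP table,
which is what the O(min(m, n)) space comment refers to, and wer/cer simply
divide that distance by the reference length. A self-contained sketch of the
same two-row trick (a standalone rewrite, not the module's exact code):

    import numpy as np


    def levenshtein(ref: str, hyp: str) -> int:
        """Edit distance using two DP rows, i.e. O(min(m, n)) extra space."""
        if ref == hyp:
            return 0
        if len(ref) < len(hyp):
            ref, hyp = hyp, ref  # keep the shorter sequence as the columns
        len_ref, len_hyp = len(ref), len(hyp)
        distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
        distance[0] = np.arange(len_hyp + 1)  # row 0: j insertions from ""
        for i in range(1, len_ref + 1):
            prev, cur = (i - 1) % 2, i % 2
            distance[cur][0] = i  # column 0: i deletions
            for j in range(1, len_hyp + 1):
                if ref[i - 1] == hyp[j - 1]:
                    distance[cur][j] = distance[prev][j - 1]
                else:
                    distance[cur][j] = 1 + min(
                        distance[prev][j - 1],  # substitution
                        distance[cur][j - 1],  # insertion
                        distance[prev][j],  # deletion
                    )
        return int(distance[len_ref % 2][len_hyp])


    assert levenshtein("kitten", "sitting") == 3
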
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 79d6727..a665159 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -63,15 +63,15 @@ class CharTokenizer:
         else:
             splits = [split]
 
-        chars = set()
-        for sp in splits:
+        chars: set = set()
+        for s_plit in splits:
             transcript_path = os.path.join(
-                dataset_path, language, sp, "transcripts.txt"
+                dataset_path, language, s_plit, "transcripts.txt"
             )
 
             # check if dataset is downloaded, download if not
             if download and not os.path.exists(transcript_path):
-                MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+                MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
 
             with open(
                 transcript_path,
@@ -82,7 +82,7 @@ class CharTokenizer:
             lines = [line.split(" ", 1)[1] for line in lines]
             lines = [line.strip() for line in lines]
 
-            for line in tqdm(lines, desc=f"Training tokenizer on {sp} split"):
+            for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"):
                 chars.update(line)
         offset = len(self.char_map)
         for i, char in enumerate(chars):
@@ -205,10 +205,12 @@ def train_bpe_tokenizer(
 
     lines = []
 
-    for sp in splits:
-        transcripts_path = os.path.join(dataset_path, language, sp, "transcripts.txt")
+    for s_plit in splits:
+        transcripts_path = os.path.join(
+            dataset_path, language, s_plit, "transcripts.txt"
+        )
         if download and not os.path.exists(transcripts_path):
-            MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
 
         with open(
             transcripts_path,
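
The tokenizer.py hunks only touch the transcript-collection side of
train_bpe_tokenizer. The training itself presumably goes through the
tokenizers package pinned in the first commit; a minimal sketch of BPE
training on the collected lines with that library (vocab size, special
tokens, and the output path are illustrative assumptions, not the project's
settings):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import BpeTrainer

    lines = ["an example transcript", "another line of text"]  # stand-in data

    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(vocab_size=2000, special_tokens=["[UNK]", "[PAD]"])
    tokenizer.train_from_iterator(lines, trainer=trainer)
    tokenizer.save("tokenizer.json")  # hypothetical output path
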
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8943f71..6af1e80 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -83,7 +83,7 @@ class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
     def __init__(self, n_feats: int):
-        super(CNNLayerNorm, self).__init__()
+        super().__init__()
         self.layer_norm = nn.LayerNorm(n_feats)
 
     def forward(self, data):
@@ -105,7 +105,7 @@ class ResidualCNN(nn.Module):
         dropout: float,
         n_feats: int,
     ):
-        super(ResidualCNN, self).__init__()
+        super().__init__()
 
         self.cnn1 = nn.Conv2d(
             in_channels, out_channels, kernel, stride, padding=kernel // 2
@@ -147,7 +147,7 @@ class BidirectionalGRU(nn.Module):
         dropout: float,
         batch_first: bool,
     ):
-        super(BidirectionalGRU, self).__init__()
+        super().__init__()
 
         self.bi_gru = nn.GRU(
             input_size=rnn_dim,
@@ -181,7 +181,7 @@ class SpeechRecognitionModel(nn.Module):
         stride: int = 2,
         dropout: float = 0.1,
     ):
-        super(SpeechRecognitionModel, self).__init__()
+        super().__init__()
         n_feats //= 2
         self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
         # n residual cnn layers with filter size of 32
@@ -227,7 +227,7 @@ class SpeechRecognitionModel(nn.Module):
         return data
 
 
-class IterMeter(object):
+class IterMeter:
     """keeps track of total iterations"""
 
     def __init__(self):
@@ -381,7 +381,7 @@ def run(
     ).to(device)
 
     print(
-        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+        "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
     )
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
     criterion = nn.CTCLoss(blank=28).to(device)
-- 
cgit v1.2.3
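
For reference, the train.py hunks are stylistic: since Python 3, a bare
super() call resolves the class and instance automatically, inheriting from
object is implicit, and sum() consumes a generator directly without building
an intermediate list. A tiny standalone illustration (requires torch, which
the project already uses):

    import torch.nn as nn


    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()  # short for super(TinyModel, self).__init__()
            self.proj = nn.Linear(8, 2)


    class IterMeter:  # inheriting from object is implicit in Python 3
        def __init__(self) -> None:
            self.val = 0


    model = TinyModel()
    # sum() consumes the generator lazily; no intermediate list is built.
    print("Num Model Parameters", sum(p.nelement() for p in model.parameters()))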