aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/loss_scores.py
diff options
context:
space:
mode:
authorPherkel2023-08-20 15:23:02 +0200
committerPherkel2023-08-20 15:23:02 +0200
commit33744072d8c0906950cdc9cd00fc1f345a51d9d4 (patch)
treecd6771a7196a1ad39e8aa8d0f371480f1b6f3df0 /swr2_asr/loss_scores.py
parent3939a51657814712500073eaa2830ef8cdde12e4 (diff)
please the linters
Diffstat (limited to 'swr2_asr/loss_scores.py')
-rw-r--r--swr2_asr/loss_scores.py36
1 files changed, 19 insertions, 17 deletions
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index 977462d..c49cc15 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -1,7 +1,9 @@
+"""Methods for determining the loss and scores of the model."""
import numpy as np
def avg_wer(wer_scores, combined_ref_len):
+ """Calculate the average word error rate (WER) of the model."""
return float(sum(wer_scores)) / float(combined_ref_len)
@@ -13,34 +15,34 @@ def _levenshtein_distance(ref, hyp):
extend the edits to word level when calculating the Levenshtein distance for
two sentences.
"""
- m = len(ref)
- n = len(hyp)
+ len_ref = len(ref)
+ len_hyp = len(hyp)
# special case
if ref == hyp:
return 0
- if m == 0:
- return n
- if n == 0:
- return m
+ if len_ref == 0:
+ return len_hyp
+ if len_hyp == 0:
+ return len_ref
- if m < n:
+ if len_ref < len_hyp:
ref, hyp = hyp, ref
- m, n = n, m
+ len_ref, len_hyp = len_hyp, len_ref
# use O(min(m, n)) space
- distance = np.zeros((2, n + 1), dtype=np.int32)
+ distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
# initialize distance matrix
- for j in range(0, n + 1):
+ for j in range(0, len_hyp + 1):
distance[0][j] = j
# calculate levenshtein distance
- for i in range(1, m + 1):
+ for i in range(1, len_ref + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
- for j in range(1, n + 1):
+ for j in range(1, len_hyp + 1):
if ref[i - 1] == hyp[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
@@ -49,7 +51,7 @@ def _levenshtein_distance(ref, hyp):
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
- return distance[m % 2][n]
+ return distance[len_ref % 2][len_hyp]
def word_errors(
@@ -143,8 +145,8 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
if ref_len == 0:
raise ValueError("Reference's word number should be greater than 0.")
- wer = float(edit_distance) / ref_len
- return wer
+ word_error_rate = float(edit_distance) / ref_len
+ return word_error_rate
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
@@ -181,5 +183,5 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
if ref_len == 0:
raise ValueError("Length of reference should be greater than 0.")
- cer = float(edit_distance) / ref_len
- return cer
+ char_error_rate = float(edit_distance) / ref_len
+ return char_error_rate