aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
authorPherkel2023-09-11 15:45:35 +0200
committerPherkel2023-09-11 15:45:35 +0200
commitc078ce6789c134aa05607903d3bf9e4be64df45d (patch)
treeafff5a3dd3e19a1cf906c096c3938f8b70fa683d /swr2_asr/utils
parenteffde1d9e71864a2c5bd8464db0958f5bf2d1733 (diff)
big change!
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/data.py14
-rw-r--r--swr2_asr/utils/loss_scores.py203
2 files changed, 203 insertions, 14 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 74d10c9..e939e1d 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -76,20 +76,6 @@ def split_to_mls_split(split_name: Split) -> MLSSplit:
return split_name # type: ignore
-class Sample(TypedDict):
- """Type for a sample in the dataset"""
-
- waveform: torch.Tensor
- spectrogram: torch.Tensor
- input_length: int
- utterance: torch.Tensor
- utterance_length: int
- sample_rate: int
- speaker_id: str
- book_id: str
- chapter_id: str
-
-
class MLSDataset(Dataset):
"""Custom Dataset for reading Multilingual LibriSpeech
diff --git a/swr2_asr/utils/loss_scores.py b/swr2_asr/utils/loss_scores.py
new file mode 100644
index 0000000..80285f6
--- /dev/null
+++ b/swr2_asr/utils/loss_scores.py
@@ -0,0 +1,203 @@
+"""Methods for determining the loss and scores of the model."""
+import numpy as np
+
+
+def avg_wer(wer_scores, combined_ref_len) -> float:
+ """Calculate the average word error rate (WER).
+
+ Args:
+ wer_scores: word error rate scores
+ combined_ref_len: combined length of reference sentences
+
+ Returns:
+ average word error rate (float)
+
+ Usage:
+ >>> avg_wer([0.5, 0.5], 2)
+ 0.5
+ """
+ return float(sum(wer_scores)) / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp) -> int:
+ """Levenshtein distance.
+
+ Args:
+ ref: reference sentence
+ hyp: hypothesis sentence
+
+ Returns:
+ distance: levenshtein distance between reference and hypothesis
+
+ Usage:
+ >>> _levenshtein_distance("hello", "helo")
+ 2
+ """
+ len_ref = len(ref)
+ len_hyp = len(hyp)
+
+ # special case
+ if ref == hyp:
+ return 0
+ if len_ref == 0:
+ return len_hyp
+ if len_hyp == 0:
+ return len_ref
+
+ if len_ref < len_hyp:
+ ref, hyp = hyp, ref
+ len_ref, len_hyp = len_hyp, len_ref
+
+ # use O(min(m, n)) space
+ distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
+
+ # initialize distance matrix
+ for j in range(0, len_hyp + 1):
+ distance[0][j] = j
+
+ # calculate levenshtein distance
+ for i in range(1, len_ref + 1):
+ prev_row_idx = (i - 1) % 2
+ cur_row_idx = i % 2
+ distance[cur_row_idx][0] = i
+ for j in range(1, len_hyp + 1):
+ if ref[i - 1] == hyp[j - 1]:
+ distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+ else:
+ s_num = distance[prev_row_idx][j - 1] + 1
+ i_num = distance[cur_row_idx][j - 1] + 1
+ d_num = distance[prev_row_idx][j] + 1
+ distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+ return distance[len_ref % 2][len_hyp]
+
+
+def word_errors(
+ reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+) -> tuple[float, int]:
+ """Compute the levenshtein distance between reference sequence and
+ hypothesis sequence in word-level.
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+ >>> word_errors("hello world", "hello")
+ 1, 2
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ ref_words = reference.split(delimiter)
+ hyp_words = hypothesis.split(delimiter)
+
+ edit_distance = _levenshtein_distance(ref_words, hyp_words)
+ return float(edit_distance), len(ref_words)
+
+
+def char_errors(
+ reference: str,
+ hypothesis: str,
+ ignore_case: bool = False,
+ remove_space: bool = False,
+) -> tuple[float, int]:
+ """Compute the levenshtein distance between reference sequence and
+ hypothesis sequence in char-level.
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ remove_space: Whether remove internal space characters
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+ >>> char_errors("hello world", "hello")
+ 1, 10
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ join_char = " "
+ if remove_space:
+ join_char = ""
+
+ reference = join_char.join(filter(None, reference.split(" ")))
+ hypothesis = join_char.join(filter(None, hypothesis.split(" ")))
+
+ edit_distance = _levenshtein_distance(reference, hypothesis)
+ return float(edit_distance), len(reference)
+
+
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float:
+ """Calculate word error rate (WER). WER compares reference text and
+ hypothesis text in word-level.
+ WER is defined as:
+ WER = (Sw + Dw + Iw) / Nw
+ with:
+ Sw is the number of words subsituted,
+ Dw is the number of words deleted,
+ Iw is the number of words inserted,
+ Nw is the number of words in the reference
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Word error rate (float)
+
+ Usage:
+ >>> wer("hello world", "hello")
+ 0.5
+ """
+ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
+
+ if ref_len == 0:
+ raise ValueError("Reference's word number should be greater than 0.")
+
+ word_error_rate = float(edit_distance) / ref_len
+ return word_error_rate
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+ """Calculate charactor error rate (CER). CER compares reference text and
+ hypothesis text in char-level. CER is defined as:
+ CER = (Sc + Dc + Ic) / Nc
+ with
+ Sc is the number of characters substituted,
+ Dc is the number of characters deleted,
+ Ic is the number of characters inserted
+ Nc is the number of characters in the reference
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ remove_space: Whether remove internal space characters
+
+ Returns:
+ Character error rate (float)
+
+ Usage:
+ >>> cer("hello world", "hello")
+ 0.2727272727272727
+ """
+ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space)
+
+ if ref_len == 0:
+ raise ValueError("Length of reference should be greater than 0.")
+
+ char_error_rate = float(edit_distance) / ref_len
+ return char_error_rate