aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/loss_scores.py
diff options
context:
space:
mode:
authorPherkel2023-09-11 14:11:26 +0200
committerPherkel2023-09-11 14:11:26 +0200
commitb8a4cdc6673787333cac282de744fd11604ca161 (patch)
tree69ad8f463fc707d767b5163c428ca8ceabd6afac /swr2_asr/loss_scores.py
parentf9846193289c81d89342b6a36e951605c2cfa189 (diff)
improved loss_scores
Diffstat (limited to 'swr2_asr/loss_scores.py')
-rw-r--r--swr2_asr/loss_scores.py154
1 files changed, 87 insertions, 67 deletions
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index 63c8a8f..80285f6 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -2,18 +2,36 @@
import numpy as np
-def avg_wer(wer_scores, combined_ref_len):
- """Calculate the average word error rate (WER) of the model."""
+def avg_wer(wer_scores, combined_ref_len) -> float:
+ """Calculate the average word error rate (WER).
+
+ Args:
+ wer_scores: word error rate scores
+ combined_ref_len: combined length of reference sentences
+
+ Returns:
+ average word error rate (float)
+
+ Usage:
+ >>> avg_wer([0.5, 0.5], 2)
+ 0.5
+ """
return float(sum(wer_scores)) / float(combined_ref_len)
-def _levenshtein_distance(ref, hyp):
- """Levenshtein distance is a string metric for measuring the difference
- between two sequences. Informally, the levenshtein disctance is defined as
- the minimum number of single-character edits (substitutions, insertions or
- deletions) required to change one word into the other. We can naturally
- extend the edits to word level when calculating levenshtein disctance for
- two sentences.
+def _levenshtein_distance(ref, hyp) -> int:
+ """Levenshtein distance.
+
+ Args:
+ ref: reference sentence
+ hyp: hypothesis sentence
+
+ Returns:
+ distance: levenshtein distance between reference and hypothesis
+
+ Usage:
+ >>> _levenshtein_distance("hello", "helo")
+ 2
"""
len_ref = len(ref)
len_hyp = len(hyp)
@@ -54,19 +72,24 @@ def _levenshtein_distance(ref, hyp):
return distance[len_ref % 2][len_hyp]
-def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "):
+def word_errors(
+ reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+) -> tuple[float, int]:
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in word-level.
- :param reference: The reference sentence.
- :type reference: basestring
- :param hypothesis: The hypothesis sentence.
- :type hypothesis: basestring
- :param ignore_case: Whether case-sensitive or not.
- :type ignore_case: bool
- :param delimiter: Delimiter of input sentences.
- :type delimiter: char
- :return: Levenshtein distance and word number of reference sentence.
- :rtype: list
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+ >>> word_errors("hello world", "hello")
+ 1, 2
"""
if ignore_case:
reference = reference.lower()
@@ -84,19 +107,21 @@ def char_errors(
hypothesis: str,
ignore_case: bool = False,
remove_space: bool = False,
-):
+) -> tuple[float, int]:
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in char-level.
- :param reference: The reference sentence.
- :type reference: basestring
- :param hypothesis: The hypothesis sentence.
- :type hypothesis: basestring
- :param ignore_case: Whether case-sensitive or not.
- :type ignore_case: bool
- :param remove_space: Whether remove internal space characters
- :type remove_space: bool
- :return: Levenshtein distance and length of reference sentence.
- :rtype: list
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ remove_space: Whether remove internal space characters
+
+ Returns:
+ Levenshtein distance and length of reference sentence.
+
+ Usage:
+ >>> char_errors("hello world", "hello")
+ 1, 10
"""
if ignore_case:
reference = reference.lower()
@@ -113,30 +138,29 @@ def char_errors(
return float(edit_distance), len(reference)
-def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" ") -> float:
"""Calculate word error rate (WER). WER compares reference text and
- hypothesis text in word-level. WER is defined as:
- .. math::
+ hypothesis text in word-level.
+ WER is defined as:
WER = (Sw + Dw + Iw) / Nw
- where
- .. code-block:: text
+ with:
Sw is the number of words subsituted,
Dw is the number of words deleted,
Iw is the number of words inserted,
Nw is the number of words in the reference
- We can use levenshtein distance to calculate WER. Please draw an attention
- that empty items will be removed when splitting sentences by delimiter.
- :param reference: The reference sentence.
- :type reference: basestring
- :param hypothesis: The hypothesis sentence.
- :type hypothesis: basestring
- :param ignore_case: Whether case-sensitive or not.
- :type ignore_case: bool
- :param delimiter: Delimiter of input sentences.
- :type delimiter: char
- :return: Word error rate.
- :rtype: float
- :raises ValueError: If word number of reference is zero.
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ delimiter: Delimiter of input sentences.
+
+ Returns:
+ Word error rate (float)
+
+ Usage:
+ >>> wer("hello world", "hello")
+ 0.5
"""
edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
@@ -150,29 +174,25 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
"""Calculate charactor error rate (CER). CER compares reference text and
hypothesis text in char-level. CER is defined as:
- .. math::
CER = (Sc + Dc + Ic) / Nc
- where
- .. code-block:: text
+ with
Sc is the number of characters substituted,
Dc is the number of characters deleted,
Ic is the number of characters inserted
Nc is the number of characters in the reference
- We can use levenshtein distance to calculate CER. Chinese input should be
- encoded to unicode. Please draw an attention that the leading and tailing
- space characters will be truncated and multiple consecutive space
- characters in a sentence will be replaced by one space character.
- :param reference: The reference sentence.
- :type reference: basestring
- :param hypothesis: The hypothesis sentence.
- :type hypothesis: basestring
- :param ignore_case: Whether case-sensitive or not.
- :type ignore_case: bool
- :param remove_space: Whether remove internal space characters
- :type remove_space: bool
- :return: Character error rate.
- :rtype: float
- :raises ValueError: If the reference length is zero.
+
+ Args:
+ reference: The reference sentence.
+ hypothesis: The hypothesis sentence.
+ ignore_case: Whether case-sensitive or not.
+ remove_space: Whether remove internal space characters
+
+ Returns:
+ Character error rate (float)
+
+ Usage:
+ >>> cer("hello world", "hello")
+ 0.2727272727272727
"""
edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space)