author    | Pherkel | 2023-08-20 15:23:02 +0200
committer | Pherkel | 2023-08-20 15:23:02 +0200
commit    | 33744072d8c0906950cdc9cd00fc1f345a51d9d4 (patch)
tree      | cd6771a7196a1ad39e8aa8d0f371480f1b6f3df0 /swr2_asr
parent    | 3939a51657814712500073eaa2830ef8cdde12e4 (diff)
please the linters
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/loss_scores.py | 36
-rw-r--r-- | swr2_asr/tokenizer.py   | 18
-rw-r--r-- | swr2_asr/train.py       | 12
3 files changed, 35 insertions, 31 deletions
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index 977462d..c49cc15 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -1,7 +1,9 @@
+"""Methods for determining the loss and scores of the model."""
 import numpy as np


 def avg_wer(wer_scores, combined_ref_len):
+    """Calculate the average word error rate (WER) of the model."""
     return float(sum(wer_scores)) / float(combined_ref_len)


@@ -13,34 +15,34 @@ def _levenshtein_distance(ref, hyp):
     extend the edits to word level when calculate levenshtein disctance for two
     sentences.
     """
-    m = len(ref)
-    n = len(hyp)
+    len_ref = len(ref)
+    len_hyp = len(hyp)

     # special case
     if ref == hyp:
         return 0
-    if m == 0:
-        return n
-    if n == 0:
-        return m
+    if len_ref == 0:
+        return len_hyp
+    if len_hyp == 0:
+        return len_ref

-    if m < n:
+    if len_ref < len_hyp:
         ref, hyp = hyp, ref
-        m, n = n, m
+        len_ref, len_hyp = len_hyp, len_ref

     # use O(min(m, n)) space
-    distance = np.zeros((2, n + 1), dtype=np.int32)
+    distance = np.zeros((2, len_hyp + 1), dtype=np.int32)

     # initialize distance matrix
-    for j in range(0, n + 1):
+    for j in range(0, len_hyp + 1):
         distance[0][j] = j

     # calculate levenshtein distance
-    for i in range(1, m + 1):
+    for i in range(1, len_ref + 1):
         prev_row_idx = (i - 1) % 2
         cur_row_idx = i % 2
         distance[cur_row_idx][0] = i
-        for j in range(1, n + 1):
+        for j in range(1, len_hyp + 1):
             if ref[i - 1] == hyp[j - 1]:
                 distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
             else:
@@ -49,7 +51,7 @@ def _levenshtein_distance(ref, hyp):
                 d_num = distance[prev_row_idx][j] + 1
                 distance[cur_row_idx][j] = min(s_num, i_num, d_num)

-    return distance[m % 2][n]
+    return distance[len_ref % 2][len_hyp]


 def word_errors(
@@ -143,8 +145,8 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
     if ref_len == 0:
         raise ValueError("Reference's word number should be greater than 0.")

-    wer = float(edit_distance) / ref_len
-    return wer
+    word_error_rate = float(edit_distance) / ref_len
+    return word_error_rate


 def cer(reference, hypothesis, ignore_case=False, remove_space=False):
@@ -181,5 +183,5 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
     if ref_len == 0:
         raise ValueError("Length of reference should be greater than 0.")

-    cer = float(edit_distance) / ref_len
-    return cer
+    char_error_rate = float(edit_distance) / ref_len
+    return char_error_rate
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 79d6727..a665159 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -63,15 +63,15 @@ class CharTokenizer:
         else:
             splits = [split]

-        chars = set()
-        for sp in splits:
+        chars: set = set()
+        for s_plit in splits:
             transcript_path = os.path.join(
-                dataset_path, language, sp, "transcripts.txt"
+                dataset_path, language, s_plit, "transcripts.txt"
             )

             # check if dataset is downloaded, download if not
             if download and not os.path.exists(transcript_path):
-                MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+                MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)

             with open(
                 transcript_path,
@@ -82,7 +82,7 @@ class CharTokenizer:
             lines = [line.split(" ", 1)[1] for line in lines]
             lines = [line.strip() for line in lines]

-            for line in tqdm(lines, desc=f"Training tokenizer on {sp} split"):
+            for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"):
                 chars.update(line)
         offset = len(self.char_map)
         for i, char in enumerate(chars):
@@ -205,10 +205,12 @@ def train_bpe_tokenizer(
     lines = []

-    for sp in splits:
-        transcripts_path = os.path.join(dataset_path, language, sp, "transcripts.txt")
+    for s_plit in splits:
+        transcripts_path = os.path.join(
+            dataset_path, language, s_plit, "transcripts.txt"
+        )

         if download and not os.path.exists(transcripts_path):
-            MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)

         with open(
             transcripts_path,
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8943f71..6af1e80 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -83,7 +83,7 @@ class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""

     def __init__(self, n_feats: int):
-        super(CNNLayerNorm, self).__init__()
+        super().__init__()
         self.layer_norm = nn.LayerNorm(n_feats)

     def forward(self, data):
@@ -105,7 +105,7 @@ class ResidualCNN(nn.Module):
         dropout: float,
         n_feats: int,
     ):
-        super(ResidualCNN, self).__init__()
+        super().__init__()

         self.cnn1 = nn.Conv2d(
             in_channels, out_channels, kernel, stride, padding=kernel // 2
@@ -147,7 +147,7 @@ class BidirectionalGRU(nn.Module):
         dropout: float,
         batch_first: bool,
     ):
-        super(BidirectionalGRU, self).__init__()
+        super().__init__()

         self.bi_gru = nn.GRU(
             input_size=rnn_dim,
@@ -181,7 +181,7 @@ class SpeechRecognitionModel(nn.Module):
         stride: int = 2,
         dropout: float = 0.1,
     ):
-        super(SpeechRecognitionModel, self).__init__()
+        super().__init__()
         n_feats //= 2
         self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
         # n residual cnn layers with filter size of 32
@@ -227,7 +227,7 @@ class SpeechRecognitionModel(nn.Module):
         return data


-class IterMeter(object):
+class IterMeter:
     """keeps track of total iterations"""

     def __init__(self):
@@ -381,7 +381,7 @@ def run(
     ).to(device)

     print(
-        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+        "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
     )
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
     criterion = nn.CTCLoss(blank=28).to(device)
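For context, a minimal usage sketch of the renamed error-rate helpers touched by this commit, assuming wer() and cer() keep the signatures shown in the hunks above and that swr2_asr is importable as a package (the example strings below are made up, not from the repository):

    # Sketch only: exercises wer() and cer() from swr2_asr/loss_scores.py.
    from swr2_asr.loss_scores import cer, wer

    reference = "das wetter ist heute schoen"   # hypothetical reference transcript
    hypothesis = "das wetter ist heute schon"   # hypothetical ASR output

    # Word-level edit distance divided by the number of reference words.
    word_error_rate = wer(reference, hypothesis, ignore_case=True, delimiter=" ")

    # Character-level edit distance divided by the reference length.
    char_error_rate = cer(reference, hypothesis, ignore_case=True, remove_space=False)

    print(f"WER: {word_error_rate:.3f}, CER: {char_error_rate:.3f}")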