diff options
-rw-r--r-- | .github/workflows/format.yml | 6 | ||||
-rw-r--r-- | Makefile | 6 | ||||
-rw-r--r-- | mypy.ini | 2 | ||||
-rw-r--r-- | swr2_asr/train.py | 127 |
4 files changed, 73 insertions, 68 deletions
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 3480462..11b8164 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -11,13 +11,9 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" - - name: Install poetry - run: | - python -m pip install --upgrade pip - pip install poetry - name: Install dependencies run: | - poetry lock && poetry install --with cpu + pip install -r requirements.txt - name: Run CI run: | make lint @@ -1,6 +1,6 @@ format: - @black . + @poetry run black . lint: - @mypy --strict swr2_asr - @pylint swr2_asr
\ No newline at end of file + @poetry run mypy --strict swr2_asr + @poetry run pylint swr2_asr
\ No newline at end of file @@ -8,4 +8,4 @@ ignore_missing_imports = true ignore_missing_imports = true [mypy-click.*] -ignore_missing_imports = true
\ No newline at end of file +ignore_missing_imports = true diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 8ee96b9..29f9372 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -65,20 +65,20 @@ class TextTransform: self.char_map = {} self.index_map = {} for line in char_map_str.strip().split("\n"): - ch, index = line.split() - self.char_map[ch] = int(index) - self.index_map[int(index)] = ch + char, index = line.split() + self.char_map[char] = int(index) + self.index_map[int(index)] = char self.index_map[1] = " " def text_to_int(self, text): """Use a character map and convert text to an integer sequence""" int_sequence = [] - for c in text: - if c == " ": - ch = self.char_map["<SPACE>"] + for char in text: + if char == " ": + mapped_char = self.char_map["<SPACE>"] else: - ch = self.char_map[c] - int_sequence.append(ch) + mapped_char = self.char_map[char] + int_sequence.append(mapped_char) return int_sequence def int_to_text(self, labels): @@ -106,15 +106,15 @@ def data_processing(data, data_type="train"): labels = [] input_lengths = [] label_lengths = [] - for x in data: + for sample in data: if data_type == "train": - spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) + spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) elif data_type == "valid": - spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) + spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) else: raise ValueError("data_type should be train or valid") spectrograms.append(spec) - label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower())) + label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower())) labels.append(label) input_lengths.append(spec.shape[0] // 2) label_lengths.append(len(label)) @@ -129,11 +129,11 @@ def data_processing(data, data_type="train"): return spectrograms, labels, input_lengths, label_lengths -def GreedyDecoder( +def greedy_decoder( output, labels, label_lengths, blank_label=28, collapse_repeated=True ): """Greedily decode a sequence.""" - arg_maxes = torch.argmax(output, dim=2) + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): @@ -157,11 +157,11 @@ class CNNLayerNorm(nn.Module): super(CNNLayerNorm, self).__init__() self.layer_norm = nn.LayerNorm(n_feats) - def forward(self, x): + def forward(self, data): """x (batch, channel, feature, time)""" - x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature) - x = self.layer_norm(x) - return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) + data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) + data = self.layer_norm(data) + return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) class ResidualCNN(nn.Module): @@ -193,18 +193,19 @@ class ResidualCNN(nn.Module): self.layer_norm1 = CNNLayerNorm(n_feats) self.layer_norm2 = CNNLayerNorm(n_feats) - def forward(self, x): - residual = x # (batch, channel, feature, time) - x = self.layer_norm1(x) - x = F.gelu(x) - x = self.dropout1(x) - x = self.cnn1(x) - x = self.layer_norm2(x) - x = F.gelu(x) - x = self.dropout2(x) - x = self.cnn2(x) - x += residual - return x # (batch, channel, feature, time) + def forward(self, data): + """x (batch, channel, feature, time)""" + residual = data # (batch, channel, feature, time) + data = self.layer_norm1(data) + data = F.gelu(data) + data = self.dropout1(data) + data = self.cnn1(data) + data = self.layer_norm2(data) + data = F.gelu(data) + data = self.dropout2(data) + data = self.cnn2(data) + data += residual + return data # (batch, channel, feature, time) class BidirectionalGRU(nn.Module): @@ -219,7 +220,7 @@ class BidirectionalGRU(nn.Module): ): super(BidirectionalGRU, self).__init__() - self.BiGRU = nn.GRU( + self.bi_gru = nn.GRU( input_size=rnn_dim, hidden_size=hidden_size, num_layers=1, @@ -229,15 +230,18 @@ class BidirectionalGRU(nn.Module): self.layer_norm = nn.LayerNorm(rnn_dim) self.dropout = nn.Dropout(dropout) - def forward(self, x): - x = self.layer_norm(x) - x = F.gelu(x) - x = self.dropout(x) - x, _ = self.BiGRU(x) - return x + def forward(self, data): + """data (batch, time, feature)""" + data = self.layer_norm(data) + data = F.gelu(data) + data = self.dropout(data) + data, _ = self.bi_gru(data) + return data class SpeechRecognitionModel(nn.Module): + """Speech Recognition Model Inspired by DeepSpeech 2""" + def __init__( self, n_cnn_layers: int, @@ -279,16 +283,19 @@ class SpeechRecognitionModel(nn.Module): nn.Linear(rnn_dim, n_class), ) - def forward(self, x): - x = self.cnn(x) - x = self.rescnn_layers(x) - sizes = x.size() - x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) - x = x.transpose(1, 2) # (batch, time, feature) - x = self.fully_connected(x) - x = self.birnn_layers(x) - x = self.classifier(x) - return x + def forward(self, data): + """data (batch, channel, feature, time)""" + data = self.cnn(data) + data = self.rescnn_layers(data) + sizes = data.size() + data = data.view( + sizes[0], sizes[1] * sizes[2], sizes[3] + ) # (batch, feature, time) + data = data.transpose(1, 2) # (batch, time, feature) + data = self.fully_connected(data) + data = self.birnn_layers(data) + data = self.classifier(data) + return data class IterMeter(object): @@ -298,9 +305,11 @@ class IterMeter(object): self.val = 0 def step(self): + """step""" self.val += 1 def get(self): + """get""" return self.val @@ -314,6 +323,7 @@ def train( epoch, iter_meter, ): + """Train""" model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): @@ -334,17 +344,16 @@ def train( iter_meter.step() if batch_idx % 100 == 0 or batch_idx == data_len: print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(spectrograms), - data_len, - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) + f"Train Epoch: \ + {epoch} \ + [{batch_idx * len(spectrograms)}/{data_len} \ + ({100.0 * batch_idx / len(train_loader)}%)]\t \ + Loss: {loss.item()}" ) def test(model, device, test_loader, criterion): + """Test""" print("\nevaluating...") model.eval() test_loss = 0 @@ -361,12 +370,12 @@ def test(model, device, test_loader, criterion): loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = GreedyDecoder( + decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths ) - for j in range(len(decoded_preds)): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + for j, pred in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], pred)) + test_wer.append(wer(decoded_targets[j], pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -394,7 +403,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No use_cuda = torch.cuda.is_available() torch.manual_seed(42) - device = torch.device("cuda" if use_cuda else "cpu") + device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") train_dataset = MultilingualLibriSpeech( |