 .github/workflows/format.yml |   6 +-
 Makefile                     |   6 +-
 mypy.ini                     |   2 +-
 swr2_asr/train.py            | 127 +++++++++++++++++++++++------------------
 4 files changed, 73 insertions(+), 68 deletions(-)
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 3480462..11b8164 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -11,13 +11,9 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: "3.10"
- - name: Install poetry
- run: |
- python -m pip install --upgrade pip
- pip install poetry
- name: Install dependencies
run: |
- poetry lock && poetry install --with cpu
+ pip install -r requirements.txt
- name: Run CI
run: |
make lint
diff --git a/Makefile b/Makefile
index 703405f..4f0ea9c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
format:
- @black .
+ @poetry run black .
lint:
- @mypy --strict swr2_asr
- @pylint swr2_asr
\ No newline at end of file
+ @poetry run mypy --strict swr2_asr
+ @poetry run pylint swr2_asr
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index 9f3a098..f7cfc59 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,4 +8,4 @@ ignore_missing_imports = true
ignore_missing_imports = true
[mypy-click.*]
-ignore_missing_imports = true
\ No newline at end of file
+ignore_missing_imports = true
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8ee96b9..29f9372 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -65,20 +65,20 @@ class TextTransform:
self.char_map = {}
self.index_map = {}
for line in char_map_str.strip().split("\n"):
- ch, index = line.split()
- self.char_map[ch] = int(index)
- self.index_map[int(index)] = ch
+ char, index = line.split()
+ self.char_map[char] = int(index)
+ self.index_map[int(index)] = char
self.index_map[1] = " "
def text_to_int(self, text):
"""Use a character map and convert text to an integer sequence"""
int_sequence = []
- for c in text:
- if c == " ":
- ch = self.char_map["<SPACE>"]
+ for char in text:
+ if char == " ":
+ mapped_char = self.char_map["<SPACE>"]
else:
- ch = self.char_map[c]
- int_sequence.append(ch)
+ mapped_char = self.char_map[char]
+ int_sequence.append(mapped_char)
return int_sequence
def int_to_text(self, labels):
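The rename from single-letter names is purely mechanical; the mapping behaves as before. A minimal round-trip sketch, assuming TextTransform is importable from swr2_asr.train and is built from the module-level character table (an assumption about code not shown in this hunk):

    from swr2_asr.train import TextTransform

    text_transform = TextTransform()
    ints = text_transform.text_to_int("ab ba")          # " " goes through the <SPACE> entry
    assert text_transform.int_to_text(ints) == "ab ba"  # index 1 maps back to " "
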
@@ -106,15 +106,15 @@ def data_processing(data, data_type="train"):
labels = []
input_lengths = []
label_lengths = []
- for x in data:
+ for sample in data:
if data_type == "train":
- spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
+ spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
elif data_type == "valid":
- spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
+ spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
else:
raise ValueError("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
+ label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
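data_processing is the kind of function typically wired in as a DataLoader collate_fn; a sketch of that wiring, assuming the function and train_dataset from this module (the lambda binds data_type):

    from torch.utils.data import DataLoader

    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=lambda batch: data_processing(batch, "train"),
    )
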
@@ -129,11 +129,11 @@ def data_processing(data, data_type="train"):
return spectrograms, labels, input_lengths, label_lengths
-def GreedyDecoder(
+def greedy_decoder(
output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
"""Greedily decode a sequence."""
- arg_maxes = torch.argmax(output, dim=2)
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
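The hunk above only renames GreedyDecoder to snake_case and silences a pylint false positive on torch.argmax; for context, a self-contained sketch of the greedy CTC rule the function applies (collapse repeated labels, drop the blank), written independently of the repository:

    import torch

    def greedy_ctc(output: torch.Tensor, blank_label: int = 28) -> list:
        """output: (batch, time, n_class) scores; returns label index lists."""
        arg_maxes = torch.argmax(output, dim=2)
        decodes = []
        for args in arg_maxes:
            decode, prev = [], None
            for index in args.tolist():
                if index != blank_label and index != prev:  # skip blanks and repeats
                    decode.append(index)
                prev = index
            decodes.append(decode)
        return decodes
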
@@ -157,11 +157,11 @@ class CNNLayerNorm(nn.Module):
super(CNNLayerNorm, self).__init__()
self.layer_norm = nn.LayerNorm(n_feats)
- def forward(self, x):
+ def forward(self, data):
"""x (batch, channel, feature, time)"""
- x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
- x = self.layer_norm(x)
- return x.transpose(2, 3).contiguous() # (batch, channel, feature, time)
+ data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
+ data = self.layer_norm(data)
+ return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
class ResidualCNN(nn.Module):
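nn.LayerNorm normalizes over the trailing dimension, which is why forward transposes feature and time around the call; a quick standalone shape check (sizes are arbitrary):

    import torch
    from torch import nn

    layer_norm = nn.LayerNorm(128)            # normalizes the last dimension
    data = torch.randn(4, 32, 128, 200)       # (batch, channel, feature, time)
    out = layer_norm(data.transpose(2, 3))    # (batch, channel, time, feature)
    assert out.transpose(2, 3).shape == data.shape
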
@@ -193,18 +193,19 @@ class ResidualCNN(nn.Module):
self.layer_norm1 = CNNLayerNorm(n_feats)
self.layer_norm2 = CNNLayerNorm(n_feats)
- def forward(self, x):
- residual = x # (batch, channel, feature, time)
- x = self.layer_norm1(x)
- x = F.gelu(x)
- x = self.dropout1(x)
- x = self.cnn1(x)
- x = self.layer_norm2(x)
- x = F.gelu(x)
- x = self.dropout2(x)
- x = self.cnn2(x)
- x += residual
- return x # (batch, channel, feature, time)
+ def forward(self, data):
+ """x (batch, channel, feature, time)"""
+ residual = data # (batch, channel, feature, time)
+ data = self.layer_norm1(data)
+ data = F.gelu(data)
+ data = self.dropout1(data)
+ data = self.cnn1(data)
+ data = self.layer_norm2(data)
+ data = F.gelu(data)
+ data = self.dropout2(data)
+ data = self.cnn2(data)
+ data += residual
+ return data # (batch, channel, feature, time)
class BidirectionalGRU(nn.Module):
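The data += residual step only works if both convolutions preserve the (feature, time) shape; with stride 1 that means "same" padding. A hypothetical check, not taken from the repository:

    import torch
    from torch import nn

    kernel = 3
    cnn = nn.Conv2d(32, 32, kernel, stride=1, padding=kernel // 2)  # "same" shape
    data = torch.randn(4, 32, 128, 200)   # (batch, channel, feature, time)
    assert cnn(data).shape == data.shape  # safe to add the residual
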
@@ -219,7 +220,7 @@ class BidirectionalGRU(nn.Module):
):
super(BidirectionalGRU, self).__init__()
- self.BiGRU = nn.GRU(
+ self.bi_gru = nn.GRU(
input_size=rnn_dim,
hidden_size=hidden_size,
num_layers=1,
@@ -229,15 +230,18 @@ class BidirectionalGRU(nn.Module):
self.layer_norm = nn.LayerNorm(rnn_dim)
self.dropout = nn.Dropout(dropout)
- def forward(self, x):
- x = self.layer_norm(x)
- x = F.gelu(x)
- x = self.dropout(x)
- x, _ = self.BiGRU(x)
- return x
+ def forward(self, data):
+ """data (batch, time, feature)"""
+ data = self.layer_norm(data)
+ data = F.gelu(data)
+ data = self.dropout(data)
+ data, _ = self.bi_gru(data)
+ return data
class SpeechRecognitionModel(nn.Module):
+ """Speech Recognition Model Inspired by DeepSpeech 2"""
+
def __init__(
self,
n_cnn_layers: int,
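Renaming BiGRU to bi_gru satisfies pylint's attribute naming and changes no behavior, including the detail that a bidirectional GRU emits 2 * hidden_size features, which the layers after it must expect. A standalone sketch with illustrative sizes:

    import torch
    from torch import nn

    bi_gru = nn.GRU(input_size=512, hidden_size=512, num_layers=1,
                    batch_first=True, bidirectional=True)
    data = torch.randn(4, 200, 512)     # (batch, time, feature)
    out, _ = bi_gru(data)
    assert out.shape == (4, 200, 1024)  # both directions concatenated
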
@@ -279,16 +283,19 @@ class SpeechRecognitionModel(nn.Module):
nn.Linear(rnn_dim, n_class),
)
- def forward(self, x):
- x = self.cnn(x)
- x = self.rescnn_layers(x)
- sizes = x.size()
- x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
- x = x.transpose(1, 2) # (batch, time, feature)
- x = self.fully_connected(x)
- x = self.birnn_layers(x)
- x = self.classifier(x)
- return x
+ def forward(self, data):
+ """data (batch, channel, feature, time)"""
+ data = self.cnn(data)
+ data = self.rescnn_layers(data)
+ sizes = data.size()
+ data = data.view(
+ sizes[0], sizes[1] * sizes[2], sizes[3]
+ ) # (batch, feature, time)
+ data = data.transpose(1, 2) # (batch, time, feature)
+ data = self.fully_connected(data)
+ data = self.birnn_layers(data)
+ data = self.classifier(data)
+ return data
class IterMeter(object):
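The reflowed view/transpose pair is the bridge from the CNN stack to the recurrent stack: channel and feature collapse into one axis, and time moves to dimension 1. The same reshape in isolation, with made-up sizes:

    import torch

    data = torch.randn(4, 32, 64, 200)  # (batch, channel, feature, time) after the CNNs
    sizes = data.size()
    data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
    data = data.transpose(1, 2)                                # (batch, time, feature)
    assert data.shape == (4, 200, 32 * 64)
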
@@ -298,9 +305,11 @@ class IterMeter(object):
self.val = 0
def step(self):
+ """step"""
self.val += 1
def get(self):
+ """get"""
return self.val
@@ -314,6 +323,7 @@ def train(
epoch,
iter_meter,
):
+ """Train"""
model.train()
data_len = len(train_loader.dataset)
for batch_idx, _data in enumerate(train_loader):
@@ -334,17 +344,16 @@ def train(
iter_meter.step()
if batch_idx % 100 == 0 or batch_idx == data_len:
print(
- "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
- epoch,
- batch_idx * len(spectrograms),
- data_len,
- 100.0 * batch_idx / len(train_loader),
- loss.item(),
- )
+ f"Train Epoch: \
+ {epoch} \
+ [{batch_idx * len(spectrograms)}/{data_len} \
+ ({100.0 * batch_idx / len(train_loader)}%)]\t \
+ Loss: {loss.item()}"
)
def test(model, device, test_loader, criterion):
+ """Test"""
print("\nevaluating...")
model.eval()
test_loss = 0
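One caveat with the log statement added to train() above: backslash continuations inside an f-string keep the source indentation in the rendered text, and the percentage field lost the original :.0f precision, so the message prints with long runs of spaces and an unrounded float. Adjacent string literals avoid both problems; a hypothetical alternative using the same variables:

    print(
        f"Train Epoch: {epoch} "
        f"[{batch_idx * len(spectrograms)}/{data_len} "
        f"({100.0 * batch_idx / len(train_loader):.0f}%)]\t"
        f"Loss: {loss.item():.6f}"
    )
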
@@ -361,12 +370,12 @@ def test(model, device, test_loader, criterion):
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = GreedyDecoder(
+ decoded_preds, decoded_targets = greedy_decoder(
output.transpose(0, 1), labels, label_lengths
)
- for j in range(len(decoded_preds)):
- test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
- test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+ for j, pred in enumerate(decoded_preds):
+ test_cer.append(cer(decoded_targets[j], pred))
+ test_wer.append(wer(decoded_targets[j], pred))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
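enumerate pairs each prediction with its target by position, exactly as the old index loop did. A minimal sketch of the accumulation, assuming jiwer-style signatures cer(reference, hypothesis) and wer(reference, hypothesis); whether this project uses jiwer is an assumption:

    from jiwer import cer, wer  # assumption: any metrics with this signature work

    decoded_targets = ["guten tag", "hallo welt"]
    decoded_preds = ["guten takk", "hallo welt"]

    test_cer = [cer(decoded_targets[j], pred) for j, pred in enumerate(decoded_preds)]
    test_wer = [wer(decoded_targets[j], pred) for j, pred in enumerate(decoded_preds)]
    avg_cer = sum(test_cer) / len(test_cer)
    avg_wer = sum(test_wer) / len(test_wer)
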
@@ -394,7 +403,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
- device = torch.device("cuda" if use_cuda else "cpu")
+ device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
train_dataset = MultilingualLibriSpeech(