aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/model_deep_speech.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r--swr2_asr/model_deep_speech.py13
1 files changed, 4 insertions, 9 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index f00ebd4..dd07ff9 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,3 +1,4 @@
+"""Main definition of model"""
import torch.nn.functional as F
from torch import nn
@@ -30,9 +31,7 @@ class ResidualCNN(nn.Module):
):
super().__init__()
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
+ self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
self.cnn2 = nn.Conv2d(
out_channels,
out_channels,
@@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module):
# n residual cnn layers with filter size of 32
self.rescnn_layers = nn.Sequential(
*[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
+ ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
for _ in range(n_cnn_layers)
]
)
@@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module):
data = self.cnn(data)
data = self.rescnn_layers(data)
sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
+ data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
data = data.transpose(1, 2) # (batch, time, feature)
data = self.fully_connected(data)
data = self.birnn_layers(data)