about summary refs log tree commit diff
path: root/swr2_asr/model_deep_speech.py
diff options
context:
space:
mode:
authorPherkel2023-09-11 23:08:45 +0200
committerPherkel2023-09-11 23:08:45 +0200
commit96fee5f59f67187292ddf37db4660c5085fb66b5 (patch)
tree0c5d9b8520c1e3655a337fb9a3877adeedca6766 /swr2_asr/model_deep_speech.py
parent6f5513140f153206cfa91df3077e67ce58043d35 (diff)
changed name to match pre-trained weights
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r--swr2_asr/model_deep_speech.py68
1 file changed, 23 insertions, 45 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 77f4c8a..73f5a81 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -10,8 +10,8 @@ from torch import nn
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
- def __init__(self, n_feats: int):
- super().__init__()
+ def __init__(self, n_feats):
+ super(CNNLayerNorm, self).__init__()
self.layer_norm = nn.LayerNorm(n_feats)
def forward(self, data):
@@ -22,34 +22,22 @@ class CNNLayerNorm(nn.Module):
class ResidualCNN(nn.Module):
- """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+ """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
+ except with layer norm instead of batch norm
+ """
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel: int,
- stride: int,
- dropout: float,
- n_feats: int,
- ):
- super().__init__()
+ def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
+ super(ResidualCNN, self).__init__()
self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
- self.cnn2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel,
- stride,
- padding=kernel // 2,
- )
+ self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.layer_norm1 = CNNLayerNorm(n_feats)
self.layer_norm2 = CNNLayerNorm(n_feats)
def forward(self, data):
- """x (batch, channel, feature, time)"""
+ """data (batch, channel, feature, time)"""
residual = data # (batch, channel, feature, time)
data = self.layer_norm1(data)
data = F.gelu(data)
@@ -64,18 +52,12 @@ class ResidualCNN(nn.Module):
class BidirectionalGRU(nn.Module):
- """Bidirectional GRU with Layer Normalization and Dropout"""
+ """Bidirectional GRU layer"""
- def __init__(
- self,
- rnn_dim: int,
- hidden_size: int,
- dropout: float,
- batch_first: bool,
- ):
- super().__init__()
+ def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
+ super(BidirectionalGRU, self).__init__()
- self.bi_gru = nn.GRU(
+ self.BiGRU = nn.GRU( # pylint: disable=invalid-name
input_size=rnn_dim,
hidden_size=hidden_size,
num_layers=1,
@@ -86,11 +68,11 @@ class BidirectionalGRU(nn.Module):
self.dropout = nn.Dropout(dropout)
def forward(self, data):
- """data (batch, time, feature)"""
+ """x (batch, time, feature)"""
data = self.layer_norm(data)
data = F.gelu(data)
+ data, _ = self.BiGRU(data)
data = self.dropout(data)
- data, _ = self.bi_gru(data)
return data
@@ -98,18 +80,14 @@ class SpeechRecognitionModel(nn.Module):
"""Speech Recognition Model Inspired by DeepSpeech 2"""
def __init__(
- self,
- n_cnn_layers: int,
- n_rnn_layers: int,
- rnn_dim: int,
- n_class: int,
- n_feats: int,
- stride: int = 2,
- dropout: float = 0.1,
+ self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
):
- super().__init__()
- n_feats //= 2
- self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+ super(SpeechRecognitionModel, self).__init__()
+ n_feats = n_feats // 2
+ self.cnn = nn.Conv2d(
+ 1, 32, 3, stride=stride, padding=3 // 2
+ ) # cnn for extracting hierarchical features
+
# n residual cnn layers with filter size of 32
self.rescnn_layers = nn.Sequential(
*[
@@ -137,7 +115,7 @@ class SpeechRecognitionModel(nn.Module):
)
def forward(self, data):
- """data (batch, channel, feature, time)"""
+ """x (batch, channel, feature, time)"""
data = self.cnn(data)
data = self.rescnn_layers(data)
sizes = data.size()