author    | Pherkel | 2023-09-11 23:08:45 +0200
committer | Pherkel | 2023-09-11 23:08:45 +0200
commit    | 96fee5f59f67187292ddf37db4660c5085fb66b5 (patch)
tree      | 0c5d9b8520c1e3655a337fb9a3877adeedca6766 /swr2_asr/model_deep_speech.py
parent    | 6f5513140f153206cfa91df3077e67ce58043d35 (diff)
changed name to match pre-trained weights
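The rename (`bi_gru` → `BiGRU`) matters because PyTorch derives `state_dict` keys from module attribute paths, so `load_state_dict()` only finds a pre-trained tensor when the local attribute name matches the checkpoint's key. A minimal sketch of the failure mode the rename avoids; the two toy modules here are hypothetical, not from this repository:

```python
from torch import nn

class LocalModel(nn.Module):
    """Uses the old attribute name; its keys come out as 'bi_gru.*'."""
    def __init__(self):
        super().__init__()
        self.bi_gru = nn.GRU(input_size=4, hidden_size=4)

class PretrainedModel(nn.Module):
    """Matches the checkpoint's attribute name; its keys are 'BiGRU.*'."""
    def __init__(self):
        super().__init__()
        self.BiGRU = nn.GRU(input_size=4, hidden_size=4)

state = PretrainedModel().state_dict()   # keys like 'BiGRU.weight_ih_l0'
try:
    LocalModel().load_state_dict(state)  # local attribute is 'bi_gru', so keys miss
except RuntimeError as err:
    print(err)  # reports missing 'bi_gru.*' and unexpected 'BiGRU.*' keys
```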
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r-- | swr2_asr/model_deep_speech.py | 68
1 file changed, 23 insertions(+), 45 deletions(-)
```diff
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 77f4c8a..73f5a81 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -10,8 +10,8 @@ from torch import nn
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
-    def __init__(self, n_feats: int):
-        super().__init__()
+    def __init__(self, n_feats):
+        super(CNNLayerNorm, self).__init__()
         self.layer_norm = nn.LayerNorm(n_feats)
 
     def forward(self, data):
@@ -22,34 +22,22 @@ class CNNLayerNorm(nn.Module):
 
 
 class ResidualCNN(nn.Module):
-    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
+    except with layer norm instead of batch norm
+    """
 
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel: int,
-        stride: int,
-        dropout: float,
-        n_feats: int,
-    ):
-        super().__init__()
+    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
+        super(ResidualCNN, self).__init__()
 
         self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
-        self.cnn2 = nn.Conv2d(
-            out_channels,
-            out_channels,
-            kernel,
-            stride,
-            padding=kernel // 2,
-        )
+        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
         self.layer_norm1 = CNNLayerNorm(n_feats)
         self.layer_norm2 = CNNLayerNorm(n_feats)
 
     def forward(self, data):
-        """x (batch, channel, feature, time)"""
+        """data (batch, channel, feature, time)"""
         residual = data  # (batch, channel, feature, time)
         data = self.layer_norm1(data)
         data = F.gelu(data)
@@ -64,18 +52,12 @@ class ResidualCNN(nn.Module):
 
 
 class BidirectionalGRU(nn.Module):
-    """Bidirectional GRU with Layer Normalization and Dropout"""
+    """Bidirectional GRU layer"""
 
-    def __init__(
-        self,
-        rnn_dim: int,
-        hidden_size: int,
-        dropout: float,
-        batch_first: bool,
-    ):
-        super().__init__()
+    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
+        super(BidirectionalGRU, self).__init__()
 
-        self.bi_gru = nn.GRU(
+        self.BiGRU = nn.GRU(  # pylint: disable=invalid-name
             input_size=rnn_dim,
             hidden_size=hidden_size,
             num_layers=1,
@@ -86,11 +68,11 @@ class BidirectionalGRU(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, data):
-        """data (batch, time, feature)"""
+        """x (batch, time, feature)"""
         data = self.layer_norm(data)
         data = F.gelu(data)
+        data, _ = self.BiGRU(data)
         data = self.dropout(data)
-        data, _ = self.bi_gru(data)
 
         return data
 
@@ -98,18 +80,14 @@ class SpeechRecognitionModel(nn.Module):
     """Speech Recognition Model Inspired by DeepSpeech 2"""
 
     def __init__(
-        self,
-        n_cnn_layers: int,
-        n_rnn_layers: int,
-        rnn_dim: int,
-        n_class: int,
-        n_feats: int,
-        stride: int = 2,
-        dropout: float = 0.1,
+        self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
     ):
-        super().__init__()
-        n_feats //= 2
-        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+        super(SpeechRecognitionModel, self).__init__()
+        n_feats = n_feats // 2
+        self.cnn = nn.Conv2d(
+            1, 32, 3, stride=stride, padding=3 // 2
+        )  # cnn for extracting heirachal features
+
         # n residual cnn layers with filter size of 32
         self.rescnn_layers = nn.Sequential(
             *[
@@ -137,7 +115,7 @@ class SpeechRecognitionModel(nn.Module):
         )
 
     def forward(self, data):
-        """data (batch, channel, feature, time)"""
+        """x (batch, channel, feature, time)"""
         data = self.cnn(data)
         data = self.rescnn_layers(data)
         sizes = data.size()
```
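For context, a hedged usage sketch of the resulting class: the constructor signature and the (batch, channel, feature, time) input convention are taken from the diff above, but the hyperparameter values and input sizes below are illustrative assumptions, not values from this repository:

```python
import torch
from swr2_asr.model_deep_speech import SpeechRecognitionModel

# Illustrative hyperparameters (assumptions, not the repo's training config).
model = SpeechRecognitionModel(
    n_cnn_layers=3,
    n_rnn_layers=5,
    rnn_dim=512,
    n_class=29,    # e.g. a character alphabet plus a CTC blank
    n_feats=128,   # mel feature bins; the constructor halves this internally
)

# Input follows the docstring convention: (batch, channel, feature, time).
spectrogram = torch.randn(1, 1, 128, 200)
logits = model(spectrogram)
```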