diff options
author | Pherkel | 2023-09-03 19:30:33 +0200 |
---|---|---|
committer | Pherkel | 2023-09-03 19:30:33 +0200 |
commit | 33f09080aee10bddb4797a557d676ee1f7b8de31 (patch) | |
tree | 1a35c13c2c91f84f542fe9ed5552bcbe60b437c3 /swr2_asr/model_deep_speech.py | |
parent | f3d2ea9a16944434a08e662c5ecfd6ba50e5ea89 (diff) |
idk, hopefully this works
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r-- | swr2_asr/model_deep_speech.py | 13 |
1 file changed, 4 insertions, 9 deletions
```diff
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index f00ebd4..dd07ff9 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,3 +1,4 @@
+"""Main definition of model"""
 import torch.nn.functional as F
 from torch import nn
 
@@ -30,9 +31,7 @@ class ResidualCNN(nn.Module):
     ):
         super().__init__()
 
-        self.cnn1 = nn.Conv2d(
-            in_channels, out_channels, kernel, stride, padding=kernel // 2
-        )
+        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
         self.cnn2 = nn.Conv2d(
             out_channels,
             out_channels,
@@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module):
         # n residual cnn layers with filter size of 32
         self.rescnn_layers = nn.Sequential(
             *[
-                ResidualCNN(
-                    32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
-                )
+                ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
                 for _ in range(n_cnn_layers)
             ]
         )
@@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module):
         data = self.cnn(data)
         data = self.rescnn_layers(data)
         sizes = data.size()
-        data = data.view(
-            sizes[0], sizes[1] * sizes[2], sizes[3]
-        )  # (batch, feature, time)
+        data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
         data = data.transpose(1, 2)  # (batch, time, feature)
         data = self.fully_connected(data)
         data = self.birnn_layers(data)
```