From 33f09080aee10bddb4797a557d676ee1f7b8de31 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Sun, 3 Sep 2023 19:30:33 +0200 Subject: idk, hopefully this works --- swr2_asr/model_deep_speech.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) (limited to 'swr2_asr/model_deep_speech.py') diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index f00ebd4..dd07ff9 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,3 +1,4 @@ +"""Main definition of model""" import torch.nn.functional as F from torch import nn @@ -30,9 +31,7 @@ class ResidualCNN(nn.Module): ): super().__init__() - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) + self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) self.cnn2 = nn.Conv2d( out_channels, out_channels, @@ -110,9 +109,7 @@ class SpeechRecognitionModel(nn.Module): # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) + ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) for _ in range(n_cnn_layers) ] ) @@ -140,9 +137,7 @@ class SpeechRecognitionModel(nn.Module): data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) + data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) data = data.transpose(1, 2) # (batch, time, feature) data = self.fully_connected(data) data = self.birnn_layers(data) -- cgit v1.2.3