From c078ce6789c134aa05607903d3bf9e4be64df45d Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Mon, 11 Sep 2023 15:45:35 +0200
Subject: big change!

---
 swr2_asr/model_deep_speech.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

(limited to 'swr2_asr/model_deep_speech.py')

diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index dd07ff9..8ddbd99 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,8 +1,29 @@
-"""Main definition of model"""
+"""Main definition of the Deep speech 2 model by Baidu Research.
+
+Following definition by Assembly AI
+(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
+"""
+from typing import TypedDict
+
 import torch.nn.functional as F
 from torch import nn
 
 
+class HParams(TypedDict):
+    """Type for the hyperparameters of the model."""
+
+    n_cnn_layers: int
+    n_rnn_layers: int
+    rnn_dim: int
+    n_class: int
+    n_feats: int
+    stride: int
+    dropout: float
+    learning_rate: float
+    batch_size: int
+    epochs: int
+
+
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
@@ -60,7 +81,7 @@ class ResidualCNN(nn.Module):
 
 
 class BidirectionalGRU(nn.Module):
-    """BIdirectional GRU with Layer Normalization and Dropout"""
+    """Bidirectional GRU with Layer Normalization and Dropout"""
 
     def __init__(
         self,
-- 
cgit v1.2.3