diff options
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r-- | swr2_asr/model_deep_speech.py | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index dd07ff9..8ddbd99 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,8 +1,29 @@ -"""Main definition of model""" +"""Main definition of the Deep speech 2 model by Baidu Research. + +Following definition by Assembly AI +(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) +""" +from typing import TypedDict + import torch.nn.functional as F from torch import nn +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" + + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int + + class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" @@ -60,7 +81,7 @@ class ResidualCNN(nn.Module): class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" + """Bidirectional GRU with Layer Normalization and Dropout""" def __init__( self, |