diff options
author | Pherkel | 2023-09-11 15:45:35 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 15:45:35 +0200 |
commit | c078ce6789c134aa05607903d3bf9e4be64df45d (patch) | |
tree | afff5a3dd3e19a1cf906c096c3938f8b70fa683d /swr2_asr/model_deep_speech.py | |
parent | effde1d9e71864a2c5bd8464db0958f5bf2d1733 (diff) |
big change!
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r-- | swr2_asr/model_deep_speech.py | 25 |
1 files changed, 23 insertions, 2 deletions
```diff
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index dd07ff9..8ddbd99 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,8 +1,29 @@
-"""Main definition of model"""
+"""Main definition of the Deep speech 2 model by Baidu Research.
+
+Following definition by Assembly AI
+(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
+"""
+from typing import TypedDict
+
 import torch.nn.functional as F
 from torch import nn
 
 
+class HParams(TypedDict):
+    """Type for the hyperparameters of the model."""
+
+    n_cnn_layers: int
+    n_rnn_layers: int
+    rnn_dim: int
+    n_class: int
+    n_feats: int
+    stride: int
+    dropout: float
+    learning_rate: float
+    batch_size: int
+    epochs: int
+
+
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
@@ -60,7 +81,7 @@ class ResidualCNN(nn.Module):
 
 
 class BidirectionalGRU(nn.Module):
-    """BIdirectional GRU with Layer Normalization and Dropout"""
+    """Bidirectional GRU with Layer Normalization and Dropout"""
 
     def __init__(
         self,
```