about | summary | refs | log | tree | commit | diff
path: root/swr2_asr/model_deep_speech.py
diff options
context:
space:
mode:
author: Pherkel, 2023-09-11 15:45:35 +0200
committer: Pherkel, 2023-09-11 15:45:35 +0200
commit: c078ce6789c134aa05607903d3bf9e4be64df45d (patch)
tree: afff5a3dd3e19a1cf906c096c3938f8b70fa683d /swr2_asr/model_deep_speech.py
parent: effde1d9e71864a2c5bd8464db0958f5bf2d1733 (diff)
big change!
Diffstat (limited to 'swr2_asr/model_deep_speech.py')
-rw-r--r-- swr2_asr/model_deep_speech.py | 25
1 files changed, 23 insertions, 2 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index dd07ff9..8ddbd99 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -1,8 +1,29 @@
-"""Main definition of model"""
+"""Main definition of the Deep speech 2 model by Baidu Research.
+
+Following definition by Assembly AI
+(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
+"""
+from typing import TypedDict
+
import torch.nn.functional as F
from torch import nn
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
+
+
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
@@ -60,7 +81,7 @@ class ResidualCNN(nn.Module):
class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
+ """Bidirectional GRU with Layer Normalization and Dropout"""
def __init__(
self,