"""Main definition of the Deep speech 2 model by Baidu Research.

Following definition by Assembly AI 
(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
"""
import torch.nn.functional as F
from torch import nn


class CNNLayerNorm(nn.Module):
    """Layer normalization built for cnns input"""

    def __init__(self, n_feats):
        super(CNNLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, data):
        """x (batch, channel, feature, time)"""
        data = data.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
        data = self.layer_norm(data)
        return data.transpose(2, 3).contiguous()  # (batch, channel, feature, time)

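# Shape sketch for CNNLayerNorm (illustrative; assumes `import torch` and a
# spectrogram-like tensor, not values prescribed by this module):
#
#     norm = CNNLayerNorm(n_feats=128)
#     x = torch.randn(8, 32, 128, 300)  # (batch, channel, feature, time)
#     norm(x).shape                     # shape-preserving: (8, 32, 128, 300)
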

class ResidualCNN(nn.Module):
    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
    except with layer norm instead of batch norm
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super(ResidualCNN, self).__init__()

        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, data):
        """data (batch, channel, feature, time)"""
        residual = data  # (batch, channel, feature, time)
        data = self.layer_norm1(data)
        data = F.gelu(data)
        data = self.dropout1(data)
        data = self.cnn1(data)
        data = self.layer_norm2(data)
        data = F.gelu(data)
        data = self.dropout2(data)
        data = self.cnn2(data)
        data += residual
        return data  # (batch, channel, feature, time)

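# Shape sketch for ResidualCNN (illustrative; the identity skip connection
# requires in_channels == out_channels and stride == 1):
#
#     block = ResidualCNN(32, 32, kernel=3, stride=1, dropout=0.1, n_feats=64)
#     x = torch.randn(8, 32, 64, 150)  # (batch, channel, feature, time)
#     block(x).shape                   # shape-preserving: (8, 32, 64, 150)
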

class BidirectionalGRU(nn.Module):
    """Bidirectional GRU layer"""

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()

        self.BiGRU = nn.GRU(  # pylint: disable=invalid-name
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """x (batch, time, feature)"""
        data = self.layer_norm(data)
        data = F.gelu(data)
        data, _ = self.BiGRU(data)
        data = self.dropout(data)
        return data

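# Shape sketch for BidirectionalGRU (illustrative; rnn_dim here is the input
# feature size and hidden_size the per-direction GRU width):
#
#     gru = BidirectionalGRU(rnn_dim=512, hidden_size=512, dropout=0.1, batch_first=True)
#     x = torch.randn(8, 100, 512)  # (batch, time, feature)
#     gru(x).shape                  # (8, 100, 1024): two directions concatenated
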

class SpeechRecognitionModel(nn.Module):
    """Speech Recognition Model Inspired by DeepSpeech 2"""

    def __init__(
        self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
    ):
        super(SpeechRecognitionModel, self).__init__()
        n_feats = n_feats // 2  # the initial conv (stride 2 by default) halves the feature dim
        self.cnn = nn.Conv2d(
            1, 32, 3, stride=stride, padding=3 // 2
        )  # cnn for extracting hierarchical features

        # n_cnn_layers residual CNN blocks, each with 32 channels
        self.rescnn_layers = nn.Sequential(
            *[
                ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
                for _ in range(n_cnn_layers)
            ]
        )
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        self.birnn_layers = nn.Sequential(
            *[
                BidirectionalGRU(
                    rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                    hidden_size=rnn_dim,
                    dropout=dropout,
                    # data stays (batch, time, feature) between stacked layers,
                    # so every GRU must treat the batch dimension as first
                    batch_first=True,
                )
                for i in range(n_rnn_layers)
            ]
        )
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class),
        )

    def forward(self, data):
        """x (batch, channel, feature, time)"""
        data = self.cnn(data)
        data = self.rescnn_layers(data)
        sizes = data.size()
        data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, channel * feature, time)
        data = data.transpose(1, 2)  # (batch, time, feature)
        data = self.fully_connected(data)
        data = self.birnn_layers(data)
        data = self.classifier(data)
        return data
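

# Hedged usage sketch: a quick shape-only smoke test. The hyperparameters and
# the dummy spectrogram size below are illustrative assumptions (loosely based
# on the referenced AssemblyAI post), not values required by this module.
if __name__ == "__main__":
    import torch

    model = SpeechRecognitionModel(
        n_cnn_layers=3,
        n_rnn_layers=5,
        rnn_dim=512,
        n_class=29,   # e.g. 26 letters + space + apostrophe + CTC blank
        n_feats=128,  # mel bins in the input spectrogram
    )
    # dummy batch of spectrograms: (batch, channel, feature, time)
    spectrograms = torch.randn(4, 1, 128, 200)
    logits = model(spectrograms)
    # (batch, time, n_class); apply log_softmax and transpose to
    # (time, batch, n_class) before feeding torch.nn.CTCLoss
    print(logits.shape)  # torch.Size([4, 100, 29])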