"""Training script for the ASR model."""
import os
from typing import TypedDict

import click
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import greedy_decoder
from swr2_asr.utils.tokenizer import CharTokenizer

from swr2_asr.utils.loss_scores import cer, wer


class IterMeter(object):
    """keeps track of total iterations"""

    def __init__(self):
        self.val = 0

    def step(self):
        """step"""
        self.val += 1

    def get(self):
        """get steps"""
        return self.val


class TrainArgs(TypedDict):
    """Type for the arguments of the training function."""

    model: SpeechRecognitionModel
    device: torch.device  # pylint: disable=no-member
    train_loader: DataLoader
    criterion: nn.CTCLoss
    optimizer: optim.AdamW
    scheduler: optim.lr_scheduler.OneCycleLR
    epoch: int
    iter_meter: IterMeter


def train(train_args: TrainArgs) -> float:
    """Train the model for one epoch.

    Args:
        train_args: TrainArgs dict containing
            model: model
            device: device type
            train_loader: train dataloader
            criterion: loss function
            optimizer: optimizer
            scheduler: learning rate scheduler
            epoch: epoch number
            iter_meter: iteration meter

    Returns:
        avg_train_loss: average training loss for the epoch

    Shapes:
        spectrograms: (batch, time, feature)
        labels: (batch, label_length)
        model output: (batch, time, n_class)
    """
    # unpack train_args explicitly instead of relying on dict insertion order
    model = train_args["model"]
    device = train_args["device"]
    train_loader = train_args["train_loader"]
    criterion = train_args["criterion"]
    optimizer = train_args["optimizer"]
    scheduler = train_args["scheduler"]
    epoch = train_args["epoch"]
    iter_meter = train_args["iter_meter"]

    model.train()
    print(f"training batch {epoch}")
    train_losses = []
    for _data in tqdm(train_loader, desc="Training batches"):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)

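        # nn.CTCLoss expects log-probabilities shaped (time, batch, n_class) together
        # with the unpadded spectrogram and label lengths for each sample in the batch.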
        loss = criterion(output, labels, input_lengths, label_lengths)
        train_losses.append(loss.item())  # store a float, not the graph-attached tensor
        loss.backward()

        optimizer.step()
        scheduler.step()
        iter_meter.step()
    avg_train_loss = sum(train_losses) / len(train_losses)
    print(f"Train set: Average loss: {avg_train_loss:.2f}")
    return avg_train_loss


class TestArgs(TypedDict):
    """Type for the arguments of the test function."""

    model: SpeechRecognitionModel
    device: torch.device  # pylint: disable=no-member
    test_loader: DataLoader
    criterion: nn.CTCLoss
    tokenizer: CharTokenizer
    decoder: str


def test(test_args: TestArgs) -> tuple[float, float, float]:
    """Evaluate the model; returns (test_loss, avg_cer, avg_wer). See TestArgs for keys."""
    print("\nevaluating...")

    # unpack test_args explicitly instead of relying on dict insertion order
    model = test_args["model"]
    device = test_args["device"]
    test_loader = test_args["test_loader"]
    criterion = test_args["criterion"]
    tokenizer = test_args["tokenizer"]
    decoder_name = test_args["decoder"]

    # only the greedy decoder is implemented here
    if decoder_name != "greedy":
        raise ValueError(f"unknown decoder: {decoder_name}")
    decoder = greedy_decoder

    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with torch.no_grad():
        for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")):
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            test_loss += loss.item() / len(test_loader)

            decoded_preds, decoded_targets = decoder(
                output.transpose(0, 1), labels, label_lengths, tokenizer
            )
            if i == 0:
                print(f"decoded predictions for first batch: {decoded_preds}")
            for j, _ in enumerate(decoded_preds):
                test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
                test_wer.append(wer(decoded_targets[j], decoded_preds[j]))

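    # Note: these are per-utterance averages (each sample weighted equally), not
    # corpus-level error rates pooled over the total reference length.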
    avg_cer = sum(test_cer) / len(test_cer)
    avg_wer = sum(test_wer) / len(test_wer)

    print(
        f"Test set: Average loss: {test_loss:.4f}, "
        f"Average CER: {avg_cer:.4f}, "
        f"Average WER: {avg_wer:.4f}\n"
    )

    return test_loss, avg_cer, avg_wer


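# NOTE: CLI wiring sketched from the parameter list of main(); the option names
# mirror the function arguments, but the defaults below are placeholders and
# should be adjusted to the values actually used for training.
@click.command()
@click.option("--learning_rate", type=float, required=True, help="learning rate for the optimizer")
@click.option("--batch_size", type=int, required=True, help="batch size")
@click.option("--epochs", type=int, required=True, help="number of epochs to train")
@click.option("--dataset_path", type=str, required=True, help="path for the dataset")
@click.option("--language", type=str, required=True, help="language of the dataset")
@click.option("--limited_supervision", is_flag=True, default=False, help="use only limited supervision")
@click.option("--model_load_path", type=str, default=None, help="path to load a model from")
@click.option("--model_save_path", type=str, default=None, help="path to save the model to")
@click.option("--dataset_percentage", type=float, default=1.0, help="percentage of the dataset to use")
@click.option("--eval_every", type=int, default=1, help="evaluate every n epochs")
@click.option("--num_workers", type=int, default=4, help="number of workers for the dataloader")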
def main(
    learning_rate: float,
    batch_size: int,
    epochs: int,
    dataset_path: str,
    language: str,
    limited_supervision: bool,
    model_load_path: str,
    model_save_path: str,
    dataset_percentage: float,
    eval_every: int,
    num_workers: int,
):
    """Main function for training the model.

    Args:
        learning_rate: learning rate for the optimizer
        batch_size: batch size
        epochs: number of epochs to train
        dataset_path: path for the dataset
        language: language of the dataset
        limited_supervision: whether to use only limited supervision
        model_load_path: path to load a model from
        model_save_path: path to save the model to
        dataset_percentage: percentage of the dataset to use
        eval_every: evaluate every n epochs
        num_workers: number of workers for the dataloader
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member

    torch.manual_seed(7)

    if not os.path.isdir(dataset_path):
        os.makedirs(dataset_path)

    train_dataset = MLSDataset(
        dataset_path,
        language,
        Split.TRAIN,
        download=True,
        limited=limited_supervision,
        size=dataset_percentage,
    )
    valid_dataset = MLSDataset(
        dataset_path,
        language,
        Split.VALID,  # assumes the Split enum provides a VALID member for the dev set
        download=False,
        limited=False,
        size=dataset_percentage,
    )

    # TODO: initialize and possibly train tokenizer if none found
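    # Placeholder so the script runs end to end until the TODO above is addressed.
    # Assumption: CharTokenizer can be restored from a previously trained tokenizer
    # file; the loader name and the path below are illustrative, not confirmed API.
    tokenizer = CharTokenizer.from_file(os.path.join(dataset_path, "tokenizer_chars.txt"))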

    kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}

    hparams = HParams(
        n_cnn_layers=3,
        n_rnn_layers=5,
        rnn_dim=512,
        n_class=tokenizer.get_vocab_size(),
        n_feats=128,
        stride=2,
        dropout=0.1,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
    )

    train_data_processing = DataProcessing("train", tokenizer)
    valid_data_processing = DataProcessing("valid", tokenizer)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=train_data_processing,
        **kwargs,
    )
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=hparams["batch_size"],
        shuffle=False,
        collate_fn=valid_data_processing,
        **kwargs,
    )

    model = SpeechRecognitionModel(
        hparams["n_cnn_layers"],
        hparams["n_rnn_layers"],
        hparams["rnn_dim"],
        hparams["n_class"],
        hparams["n_feats"],
        hparams["stride"],
        hparams["dropout"],
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
    criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
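    # OneCycleLR is stepped once per batch in train(), so its step budget is
    # steps_per_epoch * epochs. Note that the scheduler state is not saved in the
    # checkpoints below, so resuming from model_load_path restarts the LR schedule.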
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=hparams["learning_rate"],
        steps_per_epoch=int(len(train_loader)),
        epochs=hparams["epochs"],
        anneal_strategy="linear",
    )
    prev_epoch = 0

    if model_load_path is not None:
        # map_location lets a checkpoint saved on GPU be restored on CPU as well
        checkpoint = torch.load(model_load_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        prev_epoch = checkpoint["epoch"]

    iter_meter = IterMeter()
    if model_save_path is not None and not os.path.isdir(os.path.dirname(model_save_path)):
        os.makedirs(os.path.dirname(model_save_path))
    for epoch in range(prev_epoch + 1, epochs + 1):
        train_args: TrainArgs = dict(
            model=model,
            device=device,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            iter_meter=iter_meter,
        )

        train_loss = train(train_args)

        test_loss, test_cer, test_wer = 0, 0, 0

        test_args: TestArgs = dict(
            model=model,
            device=device,
            test_loader=valid_loader,
            criterion=criterion,
            tokenizer=tokenizer,
            decoder="greedy",
        )

        if epoch % eval_every == 0:
            test_loss, test_cer, test_wer = test(test_args)

        if model_save_path is None:
            continue

        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_loss": train_loss,
                "test_loss": test_loss,
                "avg_cer": test_cer,
                "avg_wer": test_wer,
            },
            model_save_path + str(epoch),
        )


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter