path: root/swr2_asr
authorPherkel2023-08-19 11:06:33 +0200
committerPherkel2023-08-19 11:06:33 +0200
commit3880339761062a588f467c4ea891338838c533e9 (patch)
tree5638b0a224d9bb1d41a8b392efec30bbbd1b175b /swr2_asr
parent7c375276b44dc933b3311c87601e0ac6945f5be8 (diff)
parent9b4592caac90a41eb6ce18558588ef504c49f58e (diff)
Merge branch 'wave2vec2'
Diffstat (limited to 'swr2_asr')
-rw-r--r--  swr2_asr/inference_test.py   74
-rw-r--r--  swr2_asr/loss_scores.py     185
-rw-r--r--  swr2_asr/train.py           482
3 files changed, 734 insertions, 7 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
new file mode 100644
index 0000000..a6b0010
--- /dev/null
+++ b/swr2_asr/inference_test.py
@@ -0,0 +1,74 @@
+"""Training script for the ASR model."""
+from AudioLoader.speech.mls import MultilingualLibriSpeech
+import torch
+import torchaudio
+import torchaudio.functional as F
+
+
+class GreedyCTCDecoder(torch.nn.Module):
+    """Greedy CTC decoder: keep the most likely label per frame, collapse repeats, drop blanks."""
+
+    def __init__(self, labels, blank=0) -> None:
+ super().__init__()
+ self.labels = labels
+ self.blank = blank
+
+ def forward(self, emission: torch.Tensor) -> str:
+ """Given a sequence emission over labels, get the best path string
+ Args:
+ emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
+
+ Returns:
+ str: The resulting transcript
+ """
+ indices = torch.argmax(emission, dim=-1) # [num_seq,]
+ indices = torch.unique_consecutive(indices, dim=-1)
+ indices = [i for i in indices if i != self.blank]
+ return "".join([self.labels[i] for i in indices])
+
+
+def main() -> None:
+ """Main function."""
+ # choose between cuda, cpu and mps devices
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # device = "mps"
+ device = torch.device(device)
+
+ torch.random.manual_seed(42)
+
+ bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+
+ print(f"Sample rate (model): {bundle.sample_rate}")
+ print(f"Labels (model): {bundle.get_labels()}")
+
+ model = bundle.get_model().to(device)
+
+ print(model.__class__)
+
+    # run the whole pipeline on a single sample
+ dataset = MultilingualLibriSpeech(
+ "data", "mls_german_opus", split="train", download=True
+ )
+
+ print(dataset[0])
+
+ # load waveforms and sample rate from dataset
+ waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]
+
+ if sample_rate != bundle.sample_rate:
+ waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))
+
+    waveform = waveform.to(device)  # Tensor.to is not in-place; reassign
+
+    # run feature extraction once to check the pipeline; the features are not used further
+    with torch.inference_mode():
+        features, _ = model.extract_features(waveform)
+
+ with torch.inference_mode():
+ emission, _ = model(waveform)
+
+ decoder = GreedyCTCDecoder(labels=bundle.get_labels())
+ transcript = decoder(emission[0])
+
+ print(transcript)
+
+
+if __name__ == "__main__":
+ main()
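
A minimal usage sketch of the greedy best-path decoding implemented above, assuming
GreedyCTCDecoder is importable from swr2_asr.inference_test (and that its dependencies are
installed). The toy label set and emission tensor are invented for illustration; only the
decoding behaviour (collapse repeated indices, drop blanks) comes from the class itself.

    import torch
    from swr2_asr.inference_test import GreedyCTCDecoder

    labels = ("-", "h", "e", "l", "o")          # index 0 acts as the CTC blank
    decoder = GreedyCTCDecoder(labels=labels, blank=0)

    # 7 frames over 5 labels; the frame-wise argmax path is h h e l - l o
    path = [1, 1, 2, 3, 0, 3, 4]
    emission = torch.full((len(path), len(labels)), -10.0)
    for frame, index in enumerate(path):
        emission[frame, index] = 10.0

    print(decoder(emission))                    # -> "hello"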
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
new file mode 100644
index 0000000..977462d
--- /dev/null
+++ b/swr2_asr/loss_scores.py
@@ -0,0 +1,185 @@
+"""Word and character error rate metrics based on the Levenshtein distance."""
+import numpy as np
+
+
+def avg_wer(wer_scores, combined_ref_len):
+    """Average WER over a corpus: summed per-utterance scores divided by the combined reference length."""
+    return float(sum(wer_scores)) / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp):
+    """Levenshtein distance is a string metric for measuring the difference
+    between two sequences. Informally, the Levenshtein distance is defined as
+    the minimum number of single-character edits (substitutions, insertions or
+    deletions) required to change one word into the other. The edits extend
+    naturally to the word level when calculating the Levenshtein distance
+    between two sentences.
+    """
+ m = len(ref)
+ n = len(hyp)
+
+ # special case
+ if ref == hyp:
+ return 0
+ if m == 0:
+ return n
+ if n == 0:
+ return m
+
+ if m < n:
+ ref, hyp = hyp, ref
+ m, n = n, m
+
+ # use O(min(m, n)) space
+ distance = np.zeros((2, n + 1), dtype=np.int32)
+
+ # initialize distance matrix
+ for j in range(0, n + 1):
+ distance[0][j] = j
+
+ # calculate levenshtein distance
+ for i in range(1, m + 1):
+ prev_row_idx = (i - 1) % 2
+ cur_row_idx = i % 2
+ distance[cur_row_idx][0] = i
+ for j in range(1, n + 1):
+ if ref[i - 1] == hyp[j - 1]:
+ distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+ else:
+ s_num = distance[prev_row_idx][j - 1] + 1
+ i_num = distance[cur_row_idx][j - 1] + 1
+ d_num = distance[prev_row_idx][j] + 1
+ distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+ return distance[m % 2][n]
+
+
+def word_errors(
+ reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+):
+    """Compute the Levenshtein distance between a reference sequence and a
+    hypothesis sequence at word level.
+    :param reference: The reference sentence.
+    :type reference: str
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: str
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param delimiter: Delimiter used to split the input sentences.
+    :type delimiter: str
+    :return: Levenshtein distance and number of words in the reference sentence.
+    :rtype: tuple
+    """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ ref_words = reference.split(delimiter)
+ hyp_words = hypothesis.split(delimiter)
+
+ edit_distance = _levenshtein_distance(ref_words, hyp_words)
+ return float(edit_distance), len(ref_words)
+
+
+def char_errors(
+ reference: str,
+ hypothesis: str,
+ ignore_case: bool = False,
+ remove_space: bool = False,
+):
+    """Compute the Levenshtein distance between a reference sequence and a
+    hypothesis sequence at character level.
+    :param reference: The reference sentence.
+    :type reference: str
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: str
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters.
+    :type remove_space: bool
+    :return: Levenshtein distance and length of the reference sentence.
+    :rtype: tuple
+    """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ join_char = " "
+ if remove_space:
+ join_char = ""
+
+ reference = join_char.join(filter(None, reference.split(" ")))
+ hypothesis = join_char.join(filter(None, hypothesis.split(" ")))
+
+ edit_distance = _levenshtein_distance(reference, hypothesis)
+ return float(edit_distance), len(reference)
+
+
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
+    """Calculate word error rate (WER). WER compares reference and hypothesis
+    texts at word level. WER is defined as:
+    .. math::
+        WER = (Sw + Dw + Iw) / Nw
+    where
+    .. code-block:: text
+        Sw is the number of words substituted,
+        Dw is the number of words deleted,
+        Iw is the number of words inserted,
+        Nw is the number of words in the reference
+    The Levenshtein distance is used to calculate WER. Note that empty items
+    are removed when splitting sentences by the delimiter.
+    :param reference: The reference sentence.
+    :type reference: str
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: str
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param delimiter: Delimiter used to split the input sentences.
+    :type delimiter: str
+    :return: Word error rate.
+    :rtype: float
+    :raises ValueError: If the number of reference words is zero.
+    """
+ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
+
+ if ref_len == 0:
+ raise ValueError("Reference's word number should be greater than 0.")
+
+ wer = float(edit_distance) / ref_len
+ return wer
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+    """Calculate character error rate (CER). CER compares reference and
+    hypothesis texts at character level. CER is defined as:
+    .. math::
+        CER = (Sc + Dc + Ic) / Nc
+    where
+    .. code-block:: text
+        Sc is the number of characters substituted,
+        Dc is the number of characters deleted,
+        Ic is the number of characters inserted,
+        Nc is the number of characters in the reference
+    The Levenshtein distance is used to calculate CER. Chinese input should be
+    encoded as unicode. Note that leading and trailing space characters are
+    truncated and multiple consecutive space characters within a sentence are
+    replaced by a single space character.
+    :param reference: The reference sentence.
+    :type reference: str
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: str
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters.
+    :type remove_space: bool
+    :return: Character error rate.
+    :rtype: float
+    :raises ValueError: If the reference length is zero.
+    """
+ edit_distance, ref_len = char_errors(
+ reference, hypothesis, ignore_case, remove_space
+ )
+
+ if ref_len == 0:
+ raise ValueError("Length of reference should be greater than 0.")
+
+ cer = float(edit_distance) / ref_len
+ return cer
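
A short worked example for the two metrics above, assuming wer and cer are importable from
swr2_asr.loss_scores; the sentences are made up for illustration.

    from swr2_asr.loss_scores import cer, wer

    reference = "the cat sat"
    hypothesis = "the cat sad"

    # one substituted word out of three reference words
    print(wer(reference, hypothesis))  # 0.333...
    # one substituted character out of eleven reference characters (spaces count)
    print(cer(reference, hypothesis))  # 0.0909...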
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 7a9ffec..29f9372 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,16 +1,484 @@
"""Training script for the ASR model."""
-import os
from AudioLoader.speech.mls import MultilingualLibriSpeech
+import click
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torchaudio
+from .loss_scores import cer, wer
-def main() -> None:
- """Main function."""
- dataset = MultilingualLibriSpeech(
- "data", "mls_polish_opus", split="train", download=(not os.path.isdir("data"))
+class TextTransform:
+ """Maps characters to integers and vice versa"""
+
+ def __init__(self):
+ char_map_str = """
+ ' 0
+ <SPACE> 1
+ a 2
+ b 3
+ c 4
+ d 5
+ e 6
+ f 7
+ g 8
+ h 9
+ i 10
+ j 11
+ k 12
+ l 13
+ m 14
+ n 15
+ o 16
+ p 17
+ q 18
+ r 19
+ s 20
+ t 21
+ u 22
+ v 23
+ w 24
+ x 25
+ y 26
+ z 27
+ ä 28
+ ö 29
+ ü 30
+ ß 31
+ - 32
+ é 33
+ è 34
+ à 35
+ ù 36
+ ç 37
+ â 38
+ ê 39
+ î 40
+ ô 41
+ û 42
+ ë 43
+ ï 44
+ ü 45
+ """
+ self.char_map = {}
+ self.index_map = {}
+ for line in char_map_str.strip().split("\n"):
+ char, index = line.split()
+ self.char_map[char] = int(index)
+ self.index_map[int(index)] = char
+ self.index_map[1] = " "
+
+ def text_to_int(self, text):
+ """Use a character map and convert text to an integer sequence"""
+ int_sequence = []
+ for char in text:
+ if char == " ":
+ mapped_char = self.char_map["<SPACE>"]
+ else:
+ mapped_char = self.char_map[char]
+ int_sequence.append(mapped_char)
+ return int_sequence
+
+ def int_to_text(self, labels):
+        """Use a character map and convert integer labels to a text sequence"""
+ string = []
+ for i in labels:
+ string.append(self.index_map[i])
+ return "".join(string).replace("<SPACE>", " ")
+
+
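
A quick round-trip sketch for the character map above, assuming swr2_asr.train and its
dependencies are importable; the sample phrase is arbitrary.

    from swr2_asr.train import TextTransform

    text_transform = TextTransform()
    ids = text_transform.text_to_int("hallo welt")
    print(ids)                              # [9, 2, 13, 13, 16, 1, 24, 6, 13, 21]
    print(text_transform.int_to_text(ids))  # "hallo welt"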
+train_audio_transforms = nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
+
+valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
+
+text_transform = TextTransform()
+
+
+def data_processing(data, data_type="train"):
+ """Return the spectrograms, labels, and their lengths."""
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+ for sample in data:
+ if data_type == "train":
+ spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
+ elif data_type == "valid":
+ spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
+ else:
+ raise ValueError("data_type should be train or valid")
+ spectrograms.append(spec)
+ label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
+ labels.append(label)
+ input_lengths.append(spec.shape[0] // 2)
+ label_lengths.append(len(label))
+
+ spectrograms = (
+ nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+ .unsqueeze(1)
+ .transpose(2, 3)
)
+ labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return spectrograms, labels, input_lengths, label_lengths
+
+
+def greedy_decoder(
+ output, labels, label_lengths, blank_label=28, collapse_repeated=True
+):
+ """Greedily decode a sequence."""
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(
+ text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
+ )
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(text_transform.int_to_text(decode))
+ return decodes, targets
+
+
+class CNNLayerNorm(nn.Module):
+    """Layer normalization built for CNN inputs"""
+
+ def __init__(self, n_feats: int):
+ super(CNNLayerNorm, self).__init__()
+ self.layer_norm = nn.LayerNorm(n_feats)
+
+ def forward(self, data):
+        """data (batch, channel, feature, time)"""
+ data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
+ data = self.layer_norm(data)
+ return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+ """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel: int,
+ stride: int,
+ dropout: float,
+ n_feats: int,
+ ):
+ super(ResidualCNN, self).__init__()
+
+ self.cnn1 = nn.Conv2d(
+ in_channels, out_channels, kernel, stride, padding=kernel // 2
+ )
+ self.cnn2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel,
+ stride,
+ padding=kernel // 2,
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.layer_norm1 = CNNLayerNorm(n_feats)
+ self.layer_norm2 = CNNLayerNorm(n_feats)
+
+ def forward(self, data):
+        """data (batch, channel, feature, time)"""
+ residual = data # (batch, channel, feature, time)
+ data = self.layer_norm1(data)
+ data = F.gelu(data)
+ data = self.dropout1(data)
+ data = self.cnn1(data)
+ data = self.layer_norm2(data)
+ data = F.gelu(data)
+ data = self.dropout2(data)
+ data = self.cnn2(data)
+ data += residual
+ return data # (batch, channel, feature, time)
+
+
+class BidirectionalGRU(nn.Module):
+    """Bidirectional GRU with Layer Normalization and Dropout"""
+
+ def __init__(
+ self,
+ rnn_dim: int,
+ hidden_size: int,
+ dropout: float,
+ batch_first: bool,
+ ):
+ super(BidirectionalGRU, self).__init__()
+
+ self.bi_gru = nn.GRU(
+ input_size=rnn_dim,
+ hidden_size=hidden_size,
+ num_layers=1,
+ batch_first=batch_first,
+ bidirectional=True,
+ )
+ self.layer_norm = nn.LayerNorm(rnn_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, data):
+ """data (batch, time, feature)"""
+ data = self.layer_norm(data)
+ data = F.gelu(data)
+ data = self.dropout(data)
+ data, _ = self.bi_gru(data)
+ return data
+
+
+class SpeechRecognitionModel(nn.Module):
+ """Speech Recognition Model Inspired by DeepSpeech 2"""
+
+ def __init__(
+ self,
+ n_cnn_layers: int,
+ n_rnn_layers: int,
+ rnn_dim: int,
+ n_class: int,
+ n_feats: int,
+ stride: int = 2,
+ dropout: float = 0.1,
+ ):
+ super(SpeechRecognitionModel, self).__init__()
+ n_feats //= 2
+ self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+ # n residual cnn layers with filter size of 32
+ self.rescnn_layers = nn.Sequential(
+ *[
+ ResidualCNN(
+ 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
+ )
+ for _ in range(n_cnn_layers)
+ ]
+ )
+ self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
+ self.birnn_layers = nn.Sequential(
+ *[
+ BidirectionalGRU(
+ rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
+ hidden_size=rnn_dim,
+ dropout=dropout,
+ batch_first=i == 0,
+ )
+ for i in range(n_rnn_layers)
+ ]
+ )
+ self.classifier = nn.Sequential(
+ nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(rnn_dim, n_class),
+ )
+
+ def forward(self, data):
+ """data (batch, channel, feature, time)"""
+ data = self.cnn(data)
+ data = self.rescnn_layers(data)
+ sizes = data.size()
+ data = data.view(
+ sizes[0], sizes[1] * sizes[2], sizes[3]
+ ) # (batch, feature, time)
+ data = data.transpose(1, 2) # (batch, time, feature)
+ data = self.fully_connected(data)
+ data = self.birnn_layers(data)
+ data = self.classifier(data)
+ return data
+
+
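
A shape sanity check for the model above, sketched with the same hyperparameters that run()
uses further down; the batch size and number of spectrogram frames are arbitrary.

    import torch
    from swr2_asr.train import SpeechRecognitionModel

    model = SpeechRecognitionModel(
        n_cnn_layers=3, n_rnn_layers=5, rnn_dim=512,
        n_class=46, n_feats=128, stride=2, dropout=0.1,
    )
    # (batch, channel, n_mels, time), as produced by data_processing()
    dummy = torch.randn(4, 1, 128, 256)
    print(model(dummy).shape)  # torch.Size([4, 128, 46]) -> (batch, time, n_class)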
+class IterMeter(object):
+ """keeps track of total iterations"""
+
+ def __init__(self):
+ self.val = 0
+
+    def step(self):
+        """Increment the iteration counter."""
+        self.val += 1
+
+    def get(self):
+        """Return the current iteration count."""
+        return self.val
+
+
+def train(
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+):
+ """Train"""
+ model.train()
+ data_len = len(train_loader.dataset)
+ for batch_idx, _data in enumerate(train_loader):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+ optimizer.zero_grad()
+
+ output = model(spectrograms) # (batch, time, n_class)
+ output = F.log_softmax(output, dim=2)
+ output = output.transpose(0, 1) # (time, batch, n_class)
+
+ loss = criterion(output, labels, input_lengths, label_lengths)
+ loss.backward()
+
+ optimizer.step()
+ scheduler.step()
+ iter_meter.step()
+ if batch_idx % 100 == 0 or batch_idx == data_len:
+            print(
+                f"Train Epoch: {epoch} "
+                f"[{batch_idx * len(spectrograms)}/{data_len} "
+                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\t"
+                f"Loss: {loss.item():.6f}"
+            )
+
+
+def test(model, device, test_loader, criterion):
+ """Test"""
+ print("\nevaluating...")
+ model.eval()
+ test_loss = 0
+ test_cer, test_wer = [], []
+ with torch.no_grad():
+ for _data in test_loader:
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+ output = model(spectrograms) # (batch, time, n_class)
+ output = F.log_softmax(output, dim=2)
+ output = output.transpose(0, 1) # (time, batch, n_class)
+
+ loss = criterion(output, labels, input_lengths, label_lengths)
+ test_loss += loss.item() / len(test_loader)
+
+ decoded_preds, decoded_targets = greedy_decoder(
+ output.transpose(0, 1), labels, label_lengths
+ )
+ for j, pred in enumerate(decoded_preds):
+ test_cer.append(cer(decoded_targets[j], pred))
+ test_wer.append(wer(decoded_targets[j], pred))
+
+ avg_cer = sum(test_cer) / len(test_cer)
+ avg_wer = sum(test_wer) / len(test_wer)
+
+    print(
+        f"Test set: Average loss: {test_loss:.4f}, "
+        f"Average CER: {avg_cer:.4f}, Average WER: {avg_wer:.4f}\n"
+    )
+
+
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+ """Runs the training script."""
+ hparams = {
+ "n_cnn_layers": 3,
+ "n_rnn_layers": 5,
+ "rnn_dim": 512,
+ "n_class": 46,
+ "n_feats": 128,
+ "stride": 2,
+ "dropout": 0.1,
+ "learning_rate": learning_rate,
+ "batch_size": batch_size,
+ "epochs": epochs,
+ }
+
+ use_cuda = torch.cuda.is_available()
+ torch.manual_seed(42)
+ device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
+ # device = torch.device("mps")
+
+ train_dataset = MultilingualLibriSpeech(
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+ )
+ test_dataset = MultilingualLibriSpeech(
+ "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+ )
+
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=hparams["batch_size"],
+ shuffle=True,
+ collate_fn=lambda x: data_processing(x, "train"),
+ **kwargs,
+ )
+
+ test_loader = DataLoader(
+ test_dataset,
+ batch_size=hparams["batch_size"],
+ shuffle=True,
+        collate_fn=lambda x: data_processing(x, "valid"),  # no SpecAugment at evaluation time
+ **kwargs,
+ )
+
+ model = SpeechRecognitionModel(
+ hparams["n_cnn_layers"],
+ hparams["n_rnn_layers"],
+ hparams["rnn_dim"],
+ hparams["n_class"],
+ hparams["n_feats"],
+ hparams["stride"],
+ hparams["dropout"],
+ ).to(device)
+
+ print(
+ "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+ )
+
+ optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+ criterion = nn.CTCLoss(blank=28).to(device)
+
+ scheduler = optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=hparams["learning_rate"],
+ steps_per_epoch=int(len(train_loader)),
+ epochs=hparams["epochs"],
+ anneal_strategy="linear",
+ )
+
+ iter_meter = IterMeter()
+ for epoch in range(1, epochs + 1):
+ train(
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+ )
+ test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
- print(dataset[1])
+@click.command()
+@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--epochs", default=1, help="Number of epochs")
+def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+ """Runs the training script."""
+ run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
if __name__ == "__main__":
- main()
+ run(learning_rate=5e-4, batch_size=16, epochs=1)
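
For reference, a small self-contained sketch of the nn.CTCLoss calling convention that the
training and test loops above rely on: log-probabilities shaped (time, batch, n_class),
padded integer targets, and per-sample input/target lengths. The tensors are random
placeholders, not real model output.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    criterion = nn.CTCLoss(blank=28)

    batch, time, n_class = 2, 50, 46
    output = F.log_softmax(torch.randn(batch, time, n_class), dim=2)
    output = output.transpose(0, 1)              # (time, batch, n_class)

    labels = torch.randint(0, 28, (batch, 20))   # padded targets, blank index excluded
    input_lengths = [50, 40]                     # per-sample number of output frames
    label_lengths = [20, 15]                     # per-sample target lengths

    loss = criterion(output, labels, input_lengths, label_lengths)
    print(loss.item())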