author     Pherkel                                   2023-09-18 21:52:33 +0200
committer  Pherkel                                   2023-09-18 21:52:33 +0200
commit     d9421cafa088babdf6b84fa18af28ba6671addce  (patch)
tree       dc92662ff0bc039b898ee4275dfc456b702d4c59
parent     eaec9a8a32fc791c9ea676843dd7cac6a2b48d9a  (diff)
linter changes
-rw-r--r--  swr2_asr/inference.py            2
-rw-r--r--  swr2_asr/model_deep_speech.py    8
-rw-r--r--  swr2_asr/utils/visualization.py  4
3 files changed, 8 insertions, 6 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 511aef1..d3543a9 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -109,7 +109,7 @@ def main(config_path: str, file_path: str, target_path: Union[str, None] = None)
     target = target.replace("!", "")
 
     print("---------")
-    print(f"Prediction:\n\{preds}")
+    print(f"Prediction:\n{preds}")
     print("---------")
     print(f"Target:\n{target}")
     print("---------")
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 73f5a81..bd557d8 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -11,7 +11,7 @@ class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
     def __init__(self, n_feats):
-        super(CNNLayerNorm, self).__init__()
+        super().__init__()
         self.layer_norm = nn.LayerNorm(n_feats)
 
     def forward(self, data):
@@ -27,7 +27,7 @@ class ResidualCNN(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
-        super(ResidualCNN, self).__init__()
+        super().__init__()
         self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
         self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
 
@@ -55,7 +55,7 @@ class BidirectionalGRU(nn.Module):
     """Bidirectional GRU layer"""
 
     def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
-        super(BidirectionalGRU, self).__init__()
+        super().__init__()
 
         self.BiGRU = nn.GRU(  # pylint: disable=invalid-name
             input_size=rnn_dim,
@@ -82,7 +82,7 @@ class SpeechRecognitionModel(nn.Module):
     def __init__(
         self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
     ):
-        super(SpeechRecognitionModel, self).__init__()
+        super().__init__()
         n_feats = n_feats // 2
         self.cnn = nn.Conv2d(
             1, 32, 3, stride=stride, padding=3 // 2
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index b288c5a..23956fd 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -14,7 +14,9 @@ def plot(path):
     epoch = 5
     while True:
         try:
-            current_state = torch.load(path + str(epoch), map_location=torch.device("cpu"))
+            current_state = torch.load(
+                path + str(epoch), map_location=torch.device("cpu")
+            )  # pylint: disable=no-member
         except FileNotFoundError:
             break
         train_losses.append((epoch, current_state["train_loss"].item()))
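
Background on the recurring super() edit above: in Python 3, the zero-argument call super().__init__() inside a method resolves to the same base-class initializer as the explicit super(ClassName, self).__init__(), which is why pylint's super-with-arguments check suggests the shorter spelling. A minimal, repo-independent sketch of the equivalence (the Base/Child names are illustrative and not part of this project):

    class Base:
        def __init__(self):
            print("Base.__init__ called")


    class Child(Base):
        def __init__(self):
            # Python 3 zero-argument form; behaves the same as the
            # explicit super(Child, self).__init__() used before this commit
            super().__init__()


    Child()  # prints "Base.__init__ called" with either spelling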