aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPherkel2023-09-18 21:52:33 +0200
committerPherkel2023-09-18 21:52:33 +0200
commitd9421cafa088babdf6b84fa18af28ba6671addce (patch)
treedc92662ff0bc039b898ee4275dfc456b702d4c59
parenteaec9a8a32fc791c9ea676843dd7cac6a2b48d9a (diff)
linter changes
-rw-r--r--swr2_asr/inference.py2
-rw-r--r--swr2_asr/model_deep_speech.py8
-rw-r--r--swr2_asr/utils/visualization.py4
3 files changed, 8 insertions, 6 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 511aef1..d3543a9 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -109,7 +109,7 @@ def main(config_path: str, file_path: str, target_path: Union[str, None] = None)
target = target.replace("!", "")
print("---------")
- print(f"Prediction:\n\{preds}")
+ print(f"Prediction:\n{preds}")
print("---------")
print(f"Target:\n{target}")
print("---------")
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 73f5a81..bd557d8 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -11,7 +11,7 @@ class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
def __init__(self, n_feats):
- super(CNNLayerNorm, self).__init__()
+ super().__init__()
self.layer_norm = nn.LayerNorm(n_feats)
def forward(self, data):
@@ -27,7 +27,7 @@ class ResidualCNN(nn.Module):
"""
def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
- super(ResidualCNN, self).__init__()
+ super().__init__()
self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
@@ -55,7 +55,7 @@ class BidirectionalGRU(nn.Module):
"""Bidirectional GRU layer"""
def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
- super(BidirectionalGRU, self).__init__()
+ super().__init__()
self.BiGRU = nn.GRU( # pylint: disable=invalid-name
input_size=rnn_dim,
@@ -82,7 +82,7 @@ class SpeechRecognitionModel(nn.Module):
def __init__(
self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
):
- super(SpeechRecognitionModel, self).__init__()
+ super().__init__()
n_feats = n_feats // 2
self.cnn = nn.Conv2d(
1, 32, 3, stride=stride, padding=3 // 2
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index b288c5a..23956fd 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -14,7 +14,9 @@ def plot(path):
epoch = 5
while True:
try:
- current_state = torch.load(path + str(epoch), map_location=torch.device("cpu"))
+ current_state = torch.load(
+ path + str(epoch), map_location=torch.device("cpu")
+ ) # pylint: disable=no-member
except FileNotFoundError:
break
train_losses.append((epoch, current_state["train_loss"].item()))