aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
diff options
context:
space:
mode:
authorPherkel2023-09-11 23:08:45 +0200
committerPherkel2023-09-11 23:08:45 +0200
commit96fee5f59f67187292ddf37db4660c5085fb66b5 (patch)
tree0c5d9b8520c1e3655a337fb9a3877adeedca6766 /swr2_asr/inference.py
parent6f5513140f153206cfa91df3077e67ce58043d35 (diff)
changed name to match pre-trained weights
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--swr2_asr/inference.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 6495a9a..3f6a44e 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -66,9 +66,9 @@ def main(config_path: str, file_path: str) -> None:
).to(device)
checkpoint = torch.load(inference_config["model_load_path"], map_location=device)
- print(checkpoint["model_state_dict"].keys())
- model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+ model.load_state_dict(checkpoint["model_state_dict"], strict=True)
model.eval()
+
waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member
if waveform.shape[0] != 1:
waveform = waveform[1]