author	Pherkel	2023-09-11 23:14:24 +0200
committer	Pherkel	2023-09-11 23:14:24 +0200
commit	3811dc68de2e2572b3656b8f4460553136eb11b4 (patch)
tree	c03fb6e7e90c3fe075af5c9ad7b6e728f8161f15
parent	96fee5f59f67187292ddf37db4660c5085fb66b5 (diff)
added cluster config
-rw-r--r--	config.cluster.yaml	34
-rwxr-xr-x	hpc_train.sh	2
2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/config.cluster.yaml b/config.cluster.yaml
new file mode 100644
index 0000000..a3def0e
--- /dev/null
+++ b/config.cluster.yaml
@@ -0,0 +1,34 @@
+model:
+ n_cnn_layers: 3
+ n_rnn_layers: 5
+ rnn_dim: 512
+ n_feats: 128 # number of mel features
+ stride: 2
+  dropout: 0.25 # recommended: around 0.4-0.6 for smaller datasets, 0.1 for very large datasets
+
+training:
+ learning_rate: 0.0005
+  batch_size: 64 # recommended: the maximum number that fits on the GPU (a batch size of 32 fits on a 12GB GPU)
+ epochs: 150
+ eval_every_n: 5 # evaluate every n epochs
+ num_workers: 8 # number of workers for dataloader
+  device: "cuda" # device to train on; "cpu" will be set automatically if no GPU is available
+
+dataset:
+ download: True
+ dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+  limited_supervision: False # set to True to use only the limited-supervision subset of the dataset
+ dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
+tokenizer:
+ tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+checkpoints:
+ model_load_path: "data/runs/epoch31" # path to load model from
+ model_save_path: "data/runs/epoch" # path to save model to
+
+inference:
+  model_load_path: ~ # path to load model from (~ = YAML null)
+  device: "cuda" # device to run inference on; "cpu" will be set automatically if no GPU is available
\ No newline at end of file
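
The new file is plain YAML, so the training entry point can load it with PyYAML. A minimal sketch of such a loader, assuming a PyYAML + PyTorch setup; load_config and the device fallback below are illustrative, not code from this commit:

# sketch: load config.cluster.yaml and resolve the device fields
# (load_config and the fallback logic are assumptions, not this repo's code)
import yaml
import torch

def load_config(path: str) -> dict:
    with open(path, "r") as f:
        config = yaml.safe_load(f)
    # the file requests "cuda", but the inline comments promise an
    # automatic fallback to "cpu" when no GPU is available
    for section in ("training", "inference"):
        if config[section].get("device") == "cuda" and not torch.cuda.is_available():
            config[section]["device"] = "cpu"
    return config

config = load_config("config.cluster.yaml")
print(config["model"]["n_feats"])  # -> 128

Note that YAML parses True/False and ~ natively, so dataset.download arrives as a Python bool and inference.model_load_path as None.
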
diff --git a/hpc_train.sh b/hpc_train.sh
index c7d1636..2280087 100755
--- a/hpc_train.sh
+++ b/hpc_train.sh
@@ -1,3 +1,3 @@
#!/bin/sh
-yes no | python -m swr2_asr.train --epochs=100 --batch_size=30 --dataset_path=/mnt/lustre/mladm/mfa252/data
+yes no | python -m swr2_asr.train --config_path config.cluster.yaml
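
yes no prints the word "no" in an endless stream, so the pipe answers any interactive y/n prompt from swr2_asr.train non-interactively; the only change here is that all hyperparameters now come from the config file instead of individual flags. A hedged sketch of the argparse side that would accept the new flag (the parser layout is an assumption; the actual swr2_asr.train interface is not shown in this commit):

# sketch: CLI entry point that takes a single --config_path flag
# (everything besides the --config_path name itself is assumed)
import argparse
import yaml

def main() -> None:
    parser = argparse.ArgumentParser(description="train the SWR2 ASR model")
    parser.add_argument("--config_path", type=str, default="config.yaml")
    args = parser.parse_args()
    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)
    # the values the old flags carried (epochs, batch_size, dataset path)
    # are now read from the config sections instead
    print(config["training"]["epochs"], config["training"]["batch_size"])

if __name__ == "__main__":
    main()
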