Skip to content

Commit

Permalink
remove custom trainer for weather-forecast/gravity-wave-drag/cnns
Browse files Browse the repository at this point in the history
  • Loading branch information
A669015 committed Jan 29, 2025
1 parent 4c276b9 commit 26d8915
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 273 deletions.
51 changes: 25 additions & 26 deletions weather-forecast/gravity-wave-drag/cnns/ci/configs/cnn_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

fit:
model:
class_path: models.LitCNN
init_args:
in_channels: 5
init_feat: 16
out_channels: 126
conv_size: 1
pool_size: 2
lr: .0001
model:
class_path: models.LitCNN
init_args:
in_channels: 5
init_feat: 16
out_channels: 126
conv_size: 1
pool_size: 2
lr: .0001

data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}
data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
max_epochs: 1
accelerator: "cpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, 3}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
trainer:
max_epochs: 1
accelerator: "cpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, 3}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
47 changes: 23 additions & 24 deletions weather-forecast/gravity-wave-drag/cnns/ci/configs/mlp_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

fit:
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001

data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

data:
class_path: data.NOGWDDataModule
trainer:
max_epochs: 1
accelerator: "cpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, 1}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
max_epochs: 1
accelerator: "cpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, 1}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create_data():
with open(filepath, "r") as file:
params = safe_load(file)

row_feats = params["fit"]["data"]["init_args"]["shard_len"]
row_feats = params["data"]["init_args"]["shard_len"]
if not exists(data_path):
makedirs(join(data_path, "raw"))
for file_h5 in filenames:
Expand All @@ -51,10 +51,10 @@ def create_data():
with open(temp_file_path, "w") as tmpfile:
dump(filenames, tmpfile)

tr, va = params["fit"]["data"]["init_args"]["splitting_ratios"]
tr, va = params["data"]["init_args"]["splitting_ratios"]
length = len(filenames)

for element in filenames:
for _ in filenames:
file_split["train"] = filenames[: int(tr * length)]
file_split["val"] = filenames[int(tr * length) : int((tr + va) * length)]
file_split["test"] = filenames[int((tr + va) * length) :]
Expand Down
47 changes: 23 additions & 24 deletions weather-forecast/gravity-wave-drag/cnns/configs/cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

fit:
model:
class_path: models.LitCNN
init_args:
in_channels: 5
init_feat: 16
out_channels: 126
conv_size: 1
pool_size: 2
lr: .0001
model:
class_path: models.LitCNN
init_args:
in_channels: 5
init_feat: 16
out_channels: 126
conv_size: 1
pool_size: 2
lr: .0001

data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}
data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
max_epochs: 10
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
trainer:
max_epochs: 10
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
43 changes: 21 additions & 22 deletions weather-forecast/gravity-wave-drag/cnns/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

fit:
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001

data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}
data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
max_epochs: 1
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
trainer:
max_epochs: 1
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
43 changes: 21 additions & 22 deletions weather-forecast/gravity-wave-drag/cnns/configs/mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

fit:
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001
model:
class_path: models.LitMLP
init_args:
in_channels: 191
hidden_channels: 256
out_channels: 126
lr: .0001

data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}
data:
class_path: data.NOGWDDataModule
init_args:
batch_size: 256
num_workers: 0
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
max_epochs: 5
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
trainer:
max_epochs: 5
accelerator: "gpu"
devices: ${oc.decode:${oc.env:SLURM_GPUS_ON_NODE, [0]}}
logger:
class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
init_args:
save_dir: ${oc.decode:${oc.env:LOGDIR, ./logs}}
89 changes: 0 additions & 89 deletions weather-forecast/gravity-wave-drag/cnns/tests/test_trainer.py

This file was deleted.

Loading

0 comments on commit 26d8915

Please sign in to comment.