SBRA-731 Clean trainer (#76)
A669015 authored Feb 3, 2025
1 parent e3ab804 commit 91db174
Showing 17 changed files with 179 additions and 697 deletions.
reactive-flows/cnf-combustion/gnns/models.py (12 additions, 4 deletions)
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+import os
 from typing import List, Tuple
 
 import lightning as pl
@@ -124,14 +126,20 @@ def on_test_epoch_end(self) -> None:
         self.y_hats = self.all_gather(y_hats)
 
         # Reshape the outputs to the original grid shape plus the batch dimension
-        self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).cpu().numpy()
-        self.y_hats = self.y_hats.squeeze().view((-1,) + self.grid_shape).cpu().numpy()
+        self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).detach().numpy()
+        self.y_hats = (
+            self.y_hats.squeeze().view((-1,) + self.grid_shape).detach().numpy()
+        )
 
+        plots_path = os.path.join(self.trainer.log_dir, "plots")
+        if self.trainer.is_global_zero:
+            if not os.path.exists(plots_path):
+                os.makedirs(plots_path, exist_ok=True)
 
         self.plotter = plotters.Plotter(
-            self.model.__class__.__name__, self.trainer.plots_path, self.grid_shape
+            self.model.__class__.__name__, plots_path, self.grid_shape
         )
-        self.plotter.cross_section(self.plotter.zslice, self.ys, self.y_hats)
+        self.plotter.cross_section((self.ys.shape[1] // 2), self.ys, self.y_hats)
         self.plotter.dispersion_plot(self.ys, self.y_hats)
         self.plotter.histo(self.ys, self.y_hats)
         self.plotter.histo2d(self.ys, self.y_hats)
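Note on this hunk: Tensor.numpy() only works on a CPU tensor that is detached from the autograd graph, so the switch from .cpu().numpy() to .detach().numpy() assumes the gathered tensors already live on the CPU. The rank-zero guard plus exist_ok=True keeps distributed runs from racing on the plots directory, and the hard-coded z-slice gives way to the middle of the grid. A minimal standalone sketch of the resulting pattern (the helper names and shapes are illustrative, not part of the commit):

    import os

    import torch


    def gathered_to_numpy(t: torch.Tensor, grid_shape: tuple):
        """Reshape to (batch,) + grid_shape and drop the autograd graph.

        Assumes t is already on the CPU, as in the hunk above.
        """
        return t.squeeze().view((-1,) + grid_shape).detach().numpy()


    def make_plots_dir(log_dir: str, is_global_zero: bool) -> str:
        """Build <log_dir>/plots once, on the rank-zero process only."""
        plots_path = os.path.join(log_dir, "plots")
        if is_global_zero:
            os.makedirs(plots_path, exist_ok=True)  # idempotent across reruns
        return plots_path


    ys = gathered_to_numpy(torch.zeros(2, 4, 4, 4), (4, 4, 4))
    print(ys.shape, ys.shape[1] // 2)  # (2, 4, 4, 4), mid-slice index 2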
reactive-flows/cnf-combustion/gnns/plotters.py (0 additions, 2 deletions)
@@ -34,12 +34,10 @@ def __init__(
         model_type: torch.Tensor,
         plots_path: str,
         grid_shape: torch.Tensor,
-        zslice: int = 16,
     ) -> None:
         """Init the Plotter."""
         self.model_type = model_type
         self.grid_shape = grid_shape
-        self.zslice = zslice
         self.plots_path = plots_path
 
         self.label_target = r"$\overline{\Sigma}_{target}$"
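What this buys at the call site: the slice index is no longer frozen into the Plotter at construction, so cross_section adapts to any grid resolution. A runnable sketch with a stub standing in for plotters.Plotter (only the signatures mirror the diff; the stub body is illustrative):

    import numpy as np


    class Plotter:
        """Stub with the post-diff signature: no zslice argument."""

        def __init__(self, model_type: str, plots_path: str, grid_shape: tuple) -> None:
            self.model_type = model_type
            self.plots_path = plots_path
            self.grid_shape = grid_shape

        def cross_section(self, zslice: int, ys: np.ndarray, y_hats: np.ndarray) -> None:
            print(f"slice {zslice} of arrays shaped {ys.shape}")


    ys = np.zeros((2, 32, 32, 32))  # (batch,) + grid_shape
    plotter = Plotter("LitGCN", "/tmp/plots", (32, 32, 32))
    # The caller now picks the slice, here the middle of the first grid axis,
    # instead of relying on the removed zslice=16 default.
    plotter.cross_section(ys.shape[1] // 2, ys, np.zeros_like(ys))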
reactive-flows/cnf-combustion/gnns/tests/test_models.py (47 additions, 10 deletions)
@@ -23,8 +23,9 @@
 import torch_geometric as pyg
 import torch_optimizer as optim
 import yaml
+from lightning.pytorch.trainer import Trainer
 
-import models
+from models import LitGAT, LitGCN, LitGIN, LitGraphUNet
 
 
 class TestModel(unittest.TestCase):
@@ -34,7 +35,7 @@ def setUp(self) -> None:
         """Define default parameters."""
         self.filenames = ["DNS1_00116000.h5", "DNS1_00117000.h5", "DNS1_00118000.h5"]
 
-        self.initParam = {
+        self.init_param = {
             "in_channels": 1,
             "hidden_channels": 32,
             "out_channels": 1,
@@ -90,61 +91,97 @@ def test_forward(self):
             file_path = os.path.join(tempdir, "data", "raw", self.filenames[0])
             data_test = self.create_graph(file_path)
 
-            test_gcn = models.LitGCN(**self.initParam)
+            test_gcn = LitGCN(**self.init_param)
             test_forward = test_gcn.forward(data_test.x, data_test.edge_index)
 
         self.assertTrue(isinstance(test_forward, torch.Tensor))
 
     def test_common_step(self):
         """Test the "_common_step" method returns a 3 length tuple."""
+        gin_init_param = {
+            "in_channels": 1,
+            "hidden_channels": 32,
+            "out_channels": 1,
+            "num_layers": 4,
+            "dropout": 0.5,
+            "lr": 0.0001,
+        }
+
         with tempfile.TemporaryDirectory() as tempdir:
             self.create_env(tempdir)
             file_path = os.path.join(tempdir, "data", "raw", self.filenames[0])
             data_test = self.create_graph(file_path)
 
-            test_gcn = models.LitGCN(**self.initParam)
+            test_gin = LitGIN(**gin_init_param)
             batch = pyg.data.Batch.from_data_list([data_test, data_test])
 
-            loss = test_gcn._common_step(batch=batch, batch_idx=1, stage="train")
+            loss = test_gin._common_step(batch=batch, batch_idx=1, stage="train")
 
         self.assertEqual(len(loss), 3)
 
     def test_training_step(self):
         """Test the "training_step" method returns a Tensor."""
+        gat_init_param = {
+            "in_channels": 1,
+            "hidden_channels": 32,
+            "out_channels": 1,
+            "num_layers": 4,
+            "dropout": 0.5,
+            "heads": 8,
+            "jk": "last",
+            "lr": 0.0001,
+        }
+
         with tempfile.TemporaryDirectory() as tempdir:
             self.create_env(tempdir)
             file_path = os.path.join(tempdir, "data", "raw", self.filenames[0])
             data_test = self.create_graph(file_path)
 
-            test_gcn = models.LitGCN(**self.initParam)
+            test_gat = LitGAT(**gat_init_param)
             batch = pyg.data.Batch.from_data_list([data_test, data_test])
 
-            loss = test_gcn.training_step(batch=batch, batch_idx=1)
+            loss = test_gat.training_step(batch=batch, batch_idx=1)
         self.assertTrue(isinstance(loss, torch.Tensor))
 
     def test_test_step(self):
         """Test the "test_step" method returns a tuple of same size Tensors."""
+        gunet_init_param = {
+            "in_channels": 1,
+            "hidden_channels": 32,
+            "out_channels": 1,
+            "depth": 4,
+            "pool_ratios": 0.5,
+            "lr": 0.0001,
+        }
+
         with tempfile.TemporaryDirectory() as tempdir:
             self.create_env(tempdir)
             file_path = os.path.join(tempdir, "data", "raw", self.filenames[0])
             data_test = self.create_graph(file_path)
 
-            test_gcn = models.LitGCN(**self.initParam)
+            test_gunet = LitGraphUNet(**gunet_init_param)
             batch = pyg.data.Batch.from_data_list([data_test, data_test])
 
-            out_tuple = test_gcn.test_step(batch=batch, batch_idx=1)
+            out_tuple = test_gunet.test_step(batch=batch, batch_idx=1)
 
         self.assertEqual(len(out_tuple), 2)
         self.assertEqual(out_tuple[0].size(), out_tuple[1].size())
 
+        test_gunet._trainer = Trainer()
+        test_gunet.on_test_epoch_end()
+
+        self.assertTrue(
+            os.path.exists(os.path.join(test_gunet.trainer.log_dir, "plots"))
+        )
+
     def test_configure_optimizers(self):
         """Test the "configure_optimizers" method returns an optim.Optimizer."""
         with tempfile.TemporaryDirectory() as tempdir:
             self.create_env(tempdir)
             file_path = os.path.join(tempdir, "data", "raw", self.filenames[0])
             _ = self.create_graph(file_path)
 
-            test_gcn = models.LitGCN(**self.initParam)
+            test_gcn = LitGCN(**self.init_param)
             op = test_gcn.configure_optimizers()
 
         self.assertIsInstance(op, optim.Optimizer)
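The new assertions at the end of test_test_step absorb the coverage of the deleted tests/test_trainer.py: rather than running a full fit/test cycle, the test attaches a bare Trainer so that hooks which read self.trainer (here log_dir and is_global_zero) can run in isolation. A minimal sketch of the pattern with a toy module (ToyModule is hypothetical; assigning the private _trainer attribute is the same trick the diff itself uses, so it leans on Lightning internals staying put):

    import os

    import lightning as pl
    from lightning.pytorch.trainer import Trainer


    class ToyModule(pl.LightningModule):
        """Illustrative module whose epoch-end hook writes plots like models.py."""

        def on_test_epoch_end(self) -> None:
            plots_path = os.path.join(self.trainer.log_dir, "plots")
            if self.trainer.is_global_zero:
                os.makedirs(plots_path, exist_ok=True)


    module = ToyModule()
    module._trainer = Trainer()  # makes module.trainer resolve without fit()/test()
    module.on_test_epoch_end()
    assert os.path.exists(os.path.join(module.trainer.log_dir, "plots"))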
reactive-flows/cnf-combustion/gnns/tests/test_trainer.py (0 additions, 79 deletions)

This file was deleted.

reactive-flows/cnf-combustion/gnns/trainer.py (1 addition, 81 deletions)
@@ -12,89 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import logging
-import os
-from typing import Iterable, List, Optional, Union
-
-import torch
-from lightning.pytorch.accelerators import Accelerator
-from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.cli import LightningCLI
-from lightning.pytorch.loggers import Logger
-from lightning.pytorch.trainer import Trainer
-from lightning.pytorch.utilities import rank_zero_only
-
-
-class CLITrainer(Trainer):
-    """
-    Modified PyTorch Lightning Trainer that automatically tests, logs, and writes artifacts by
-    the end of training.
-    """
-
-    def __init__(
-        self,
-        accelerator: Union[str, Accelerator, None],
-        devices: Union[List[int], str, int, None],
-        max_epochs: int,
-        logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
-        # TODO: delete.
-        # For some reason, those two are mandatory in current version of Lightning.
-        fast_dev_run: Union[int, bool] = False,
-        callbacks: Union[List[Callback], Callback, None] = None,
-    ) -> None:
-        """Init the Trainer.
-        Args:
-            accelerator (Union[str, pl.accelerators.Accelerator, None]): Type of accelerator to use
-                for training.
-            devices: (Union[List[int], str, int, None]): Devices explicit names to use for training.
-            max_epochs (int): Maximum number of epochs if no early stopping logic is implemented.
-        """
-        self._accelerator = accelerator
-        self._devices = devices
-        self._max_epochs = max_epochs
-
-        super().__init__(
-            logger=logger,
-            accelerator=self._accelerator,
-            devices=self._devices,
-            max_epochs=self._max_epochs,
-            num_sanity_val_steps=0,
-        )
-        self.artifacts_path = os.path.join(self.log_dir, "artifacts")
-        self.plots_path = os.path.join(self.log_dir, "plots")
-        if self.is_global_zero:
-            os.makedirs(self.artifacts_path, exist_ok=True)
-            os.makedirs(self.plots_path, exist_ok=True)
-
-    @rank_zero_only
-    def save(self, results):
-        """Save the results of the training and the learned model."""
-        result_file = os.path.join(self.artifacts_path, "results.json")
-        with open(result_file, "w") as f:
-            json.dump(results, f)
-
-        torch.save(self.model, os.path.join(self.artifacts_path, "model.pth"))
-        logging.info(
-            f"Torch model saved in {os.path.join(self.artifacts_path, 'model.pth')}"
-        )
-
-    def test(self, **kwargs) -> None:
-        """Use superclass test results, but additionally, saves raw results as a JSON file,
-        and stores the model weights for future use in inference mode.
-        Returns:
-            None
-        """
-        results = super().test(**kwargs)[0]
-
-        self.save(results)
-
 
 if __name__ == "__main__":
-    cli = LightningCLI(
-        trainer_class=CLITrainer, run=False, parser_kwargs={"parser_mode": "omegaconf"}
-    )
+    cli = LightningCLI(run=False, parser_kwargs={"parser_mode": "omegaconf"})
     cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)
     cli.trainer.test(model=cli.model, datamodule=cli.datamodule)
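With CLITrainer gone, the entry point no longer writes results.json or model.pth. If that behaviour is ever wanted back, a Callback is the idiomatic Lightning home for it rather than a Trainer subclass. A hedged sketch (ArtifactSaver and the file layout are assumptions carried over from the deleted code, not something this commit adds):

    import json
    import os

    import torch
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.utilities import rank_zero_only


    class ArtifactSaver(Callback):
        """Persist test metrics and weights, mirroring the deleted CLITrainer.save()."""

        @rank_zero_only
        def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
            artifacts_path = os.path.join(trainer.log_dir, "artifacts")
            os.makedirs(artifacts_path, exist_ok=True)

            # callback_metrics holds the aggregated results of trainer.test().
            results = {k: float(v) for k, v in trainer.callback_metrics.items()}
            with open(os.path.join(artifacts_path, "results.json"), "w") as f:
                json.dump(results, f)

            torch.save(pl_module.state_dict(), os.path.join(artifacts_path, "model.pth"))

Such a callback could be wired in from the YAML config via the trainer.callbacks section, without reintroducing a custom trainer class.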