SBRA-739 redefine the graph creation method #79

Draft · wants to merge 4 commits into main
@@ -23,7 +23,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
@@ -22,7 +22,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
@@ -21,7 +21,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
@@ -21,7 +21,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
  init_args:
     batch_size: 1
     num_workers: 0
2 changes: 1 addition & 1 deletion reactive-flows/cnf-combustion/gnns/configs/gat.yaml
@@ -23,7 +23,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
2 changes: 1 addition & 1 deletion reactive-flows/cnf-combustion/gnns/configs/gcn.yaml
@@ -22,7 +22,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
2 changes: 1 addition & 1 deletion reactive-flows/cnf-combustion/gnns/configs/gin.yaml
@@ -21,7 +21,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
2 changes: 1 addition & 1 deletion reactive-flows/cnf-combustion/gnns/configs/gunet.yaml
@@ -21,7 +21,7 @@ model:
     lr: .0001

 data:
-  class_path: data.LitCombustionDataModule
+  class_path: data.R2DataModule
   init_args:
     batch_size: 1
     num_workers: 0
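Note on the config hunks above: assuming these YAML files are consumed by Lightning's CLI, class_path selects the data-module class and init_args is passed to its constructor, so switching to the CNF variant added by this PR would be the same one-line change. A minimal sketch of the data block under that assumption (values other than class_path are copied from the hunks above):

data:
  class_path: data.CnfDataModule
  init_args:
    batch_size: 1
    num_workers: 0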
153 changes: 95 additions & 58 deletions reactive-flows/cnf-combustion/gnns/data.py
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import os
-from typing import List
+from typing import Dict, List, Tuple

 import h5py
 import lightning as pl
 import networkx as nx
 import numpy as np
 import torch
 import torch_geometric as pyg
 import yaml
+from torch import float as tfloat
+from torch import tensor


 class CombustionDataset(pyg.data.Dataset):
@@ -41,6 +43,7 @@ def __init__(self, root: str, y_normalizer: float = None) -> None:
             y_normalizer (str): normalizing value
         """
         self.y_normalizer = y_normalizer
+        self.graph_topology = None
         super().__init__(root)

     @property
@@ -70,14 +73,38 @@ def download(self) -> None:
             f"and move all files in file.tgz/DATA in {self.raw_dir}"
         )

+    def _get_data(self, idx: int) -> Dict[str, np.array]:
+        """Return the dict of the feat and sigma of the corresponding data file.
+
+        Returns:
+            (Dict[str, np.array]): the feat and sigma.
+        """
+        raise NotImplementedError
+
     def get(self, idx: int) -> pyg.data.Data:
         """Return the graph at the given index.

         Returns:
             (pyg.data.Data): Graph at the given index.
         """
-        data = torch.load(os.path.join(self.processed_dir, f"data-{idx}.pt"))
-        return data
+        pyg_data = copy.copy(self.graph_topology)
+        data = self._get_data(idx)
+        pyg_data.x = tensor(data["feat"].reshape(-1, 1), dtype=tfloat)
+        pyg_data.y = tensor(data["sigma"].reshape(-1, 1), dtype=tfloat)
+        return pyg_data
+
+    def create_graph_topo(self, grid_shape: Tuple[int, int, int]) -> None:
+        """Create the graph topology and store it in memory.
+
+        Args:
+            grid_shape (Tuple[int, int, int]): the shape of the grid for the
+                z, y and x sorted dimensions.
+        """
+        g0 = nx.grid_graph(dim=grid_shape)
+        self.graph_topology = pyg.utils.convert.from_networkx(g0)
+        coordinates = list(g0.nodes())
+        coordinates.reverse()
+        self.graph_topology.pos = tensor(np.stack(coordinates))

     def len(self) -> int:
         """Return the total length of the dataset
@@ -105,34 +132,27 @@ def process(self) -> None:
         Create a graph for each volume of data, and saves each graph in a separate file index by
         the order in the raw file names list.
         """
-        i = 0
-        for raw_path in self.raw_paths:
-            with h5py.File(raw_path, "r") as file:
-                feat = file["/c_filt"][:]
-
-                sigma = file["/c_grad_filt"][:]
-                if self.y_normalizer:
-                    sigma /= self.y_normalizer
-
-            x_size, y_size, z_size = feat.shape
-
-            grid_shape = (z_size, y_size, x_size)
-
-            g0 = nx.grid_graph(dim=grid_shape)
-            graph = pyg.utils.convert.from_networkx(g0)
-            undirected_index = graph.edge_index
-            coordinates = list(g0.nodes())
-            coordinates.reverse()
-
-            data = pyg.data.Data(
-                x=torch.tensor(feat.reshape(-1, 1), dtype=torch.float),
-                edge_index=undirected_index.clone().detach().type(torch.LongTensor),
-                pos=torch.tensor(np.stack(coordinates)),
-                y=torch.tensor(sigma.reshape(-1, 1), dtype=torch.float),
-            )
+        # Create graph from first file
+        with h5py.File(self.raw_paths[0], "r") as file:
+            feat = file["/c_filt"][:]
+        x_size, y_size, z_size = feat.shape
+        grid_shape = (z_size, y_size, x_size)
+        self.create_graph_topo(grid_shape)
+
+    def _get_data(self, idx: int) -> Dict[str, np.array]:
+        """Return the dict of the feat and sigma of the corresponding data file.
+
+        Returns:
+            (Dict[str, np.array]): the feat and sigma.
+        """
+        data = {}
+        with h5py.File(self.raw_paths[idx], "r") as file:
+            data["feat"] = file["/c_filt"][:]

-            torch.save(data, os.path.join(self.processed_dir, f"data-{i}.pt"))
-            i += 1
+            data["sigma"] = file["/c_grad_filt"][:]
+            if self.y_normalizer:
+                data["sigma"] /= self.y_normalizer
+        return data


 class CnfDataset(CombustionDataset):
@@ -152,33 +172,27 @@ def process(self) -> None:
         Create a graph for each volume of data, and saves each graph in a separate file index by
         the order in the raw file names list.
         """
-        i = 0
-        for raw_path in self.raw_paths:
-            with h5py.File(raw_path, "r") as file:
-                feat = file["/filt_8"][:]
-
-                sigma = file["/filt_grad_8"][:]
-                if self.y_normalizer is not None:
-                    sigma /= self.y_normalizer
-
-            x_size, y_size, z_size = feat.shape
-            grid_shape = (z_size, y_size, x_size)
-
-            g0 = nx.grid_graph(dim=grid_shape)
-            graph = pyg.utils.convert.from_networkx(g0)
-            undirected_index = graph.edge_index
-            coordinates = list(g0.nodes())
-            coordinates.reverse()
-
-            data = pyg.data.Data(
-                x=torch.tensor(feat.reshape(-1, 1), dtype=torch.float),
-                edge_index=undirected_index.type(torch.LongTensor),
-                pos=torch.tensor(np.stack(coordinates)),
-                y=torch.tensor(sigma.reshape(-1, 1), dtype=torch.float),
-            )
+        # Create graph from first file
+        with h5py.File(self.raw_paths[0], "r") as file:
+            feat = file["/filt_8"][:]
+        x_size, y_size, z_size = feat.shape
+        grid_shape = (z_size, y_size, x_size)
+        self.create_graph_topo(grid_shape)

-            torch.save(data, os.path.join(self.processed_dir, f"data-{i}.pt"))
-            i += 1
+    def _get_data(self, idx: int) -> Dict[str, np.array]:
+        """Return the dict of the feat and sigma of the corresponding data file.
+
+        Returns:
+            (Dict[str, np.array]): the feat and sigma.
+        """
+        data = {}
+        with h5py.File(self.raw_paths[idx], "r") as file:
+            data["feat"] = file["/filt_8"][:]
+
+            data["sigma"] = file["/filt_grad_8"][:]
+            if self.y_normalizer:
+                data["sigma"] /= self.y_normalizer
+        return data


 class LitCombustionDataModule(pl.LightningDataModule):
@@ -217,6 +231,11 @@ def __init__(
         self.test_dataset = None
         self.train_dataset = None

+    @property
+    def dataset_class(self) -> pyg.data.Dataset:
+        # Set here the Dataset class you want to use in the datamodule
+        return NotImplementedError
+
     def prepare_data(self) -> None:
         """Not used."""
         CombustionDataset(self.data_path, self.y_normalizer)
@@ -238,7 +257,9 @@ def setup(
         if self.source_raw_data_path:
             LinkRawData(self.source_raw_data_path, self.data_path)

-        dataset = R2Dataset(self.data_path, y_normalizer=self.y_normalizer).shuffle()
+        dataset = self.dataset_class(
+            self.data_path, y_normalizer=self.y_normalizer
+        ).shuffle()
         dataset_size = len(dataset)

         self.val_dataset = dataset[int(dataset_size * 0.9) :]
@@ -336,3 +357,19 @@ def rm_old_dataset(self):
                 os.rmdir(file_location)
             else:
                 pass
+
+
+class R2DataModule(LitCombustionDataModule):
+    """Data module using the R2Dataset."""
+
+    @property
+    def dataset_class(self) -> pyg.data.Dataset:
+        return R2Dataset
+
+
+class CnfDataModule(LitCombustionDataModule):
+    """Data module using the CnfDataset."""
+
+    @property
+    def dataset_class(self) -> pyg.data.Dataset:
+        return CnfDataset
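The redefined flow in data.py boils down to: create_graph_topo() builds the grid topology once (from the first raw HDF5 file, in process()), get() shallow-copies that cached topology and attaches the per-sample x and y returned by _get_data(), and each concrete data module only declares which Dataset class to use through dataset_class. A small standalone sketch of what create_graph_topo() constructs, using a made-up (2, 3, 4) grid shape (in the PR the shape comes from the first raw file):

import networkx as nx
import numpy as np
import torch_geometric as pyg
from torch import tensor

grid_shape = (2, 3, 4)                        # (z, y, x) ordering, as in create_graph_topo
g0 = nx.grid_graph(dim=grid_shape)            # lattice graph with one node per grid cell
topology = pyg.utils.convert.from_networkx(g0)
coordinates = list(g0.nodes())
coordinates.reverse()                         # reverse the node ordering, mirroring create_graph_topo
topology.pos = tensor(np.stack(coordinates))
print(topology.num_nodes, topology.edge_index.shape, topology.pos.shape)  # 24 nodes for a 2x3x4 grid

The shallow copy in get() keeps edge_index and pos shared with the cached topology, so only the per-sample x and y tensors are created when a graph is requested.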
4 changes: 2 additions & 2 deletions reactive-flows/cnf-combustion/gnns/models.py
@@ -126,9 +126,9 @@ def on_test_epoch_end(self) -> None:
         self.y_hats = self.all_gather(y_hats)

         # Reshape the outputs to the original grid shape plus the batch dimension
-        self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).detach().numpy()
+        self.ys = self.ys.squeeze().view((-1,) + self.grid_shape).detach().cpu().numpy()
         self.y_hats = (
-            self.y_hats.squeeze().view((-1,) + self.grid_shape).detach().numpy()
+            self.y_hats.squeeze().view((-1,) + self.grid_shape).detach().cpu().numpy()
         )

         plots_path = os.path.join(self.trainer.log_dir, "plots")
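Context for the .cpu() additions above: Tensor.numpy() only works on CPU tensors, so outputs gathered on a GPU have to be moved to host memory first or the call raises a TypeError. A minimal sketch, assuming a CUDA device is available:

import torch

t = torch.ones(3, device="cuda")
# t.numpy() would raise a TypeError here; move the tensor to host memory first
arr = t.detach().cpu().numpy()
print(arr)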
21 changes: 12 additions & 9 deletions reactive-flows/cnf-combustion/gnns/tests/test_data.py
@@ -22,8 +22,9 @@
 import numpy as np
 import torch
 import yaml
+from torch import LongTensor, Tensor

-from data import CnfDataset, LinkRawData, LitCombustionDataModule
+from data import CnfDataModule, CnfDataset, LinkRawData


 class TestData(unittest.TestCase):
@@ -98,10 +99,12 @@ def test_process(self):

             self.assertTrue(os.path.exists(os.path.join(tempdir, "data", "processed")))

-            # insert +2 to have transform and filter files
-            self.assertEqual(
-                len(os.listdir(os.path.join(tempdir, "data", "processed"))),
-                len(self.filenames) + 2,
+            # Check the pyg.data.Data object has edge_index and pos
+            self.assertTrue(
+                isinstance(data_test.graph_topology.edge_index, LongTensor),
             )
+            self.assertTrue(
+                isinstance(data_test.graph_topology.pos, Tensor),
+            )

     def test_get(self):
@@ -121,7 +124,7 @@ def test_setup(self):

             init_param = copy(self.init_param)
             init_param.update({"data_path": os.path.join(tempdir, "data")})
-            dataset_test = LitCombustionDataModule(**init_param)
+            dataset_test = CnfDataModule(**init_param)

             with self.assertRaises(ValueError) as context:
                 dataset_test.setup(stage=None)
@@ -150,7 +153,7 @@ def test_train_dataloader(self):

             init_param = copy(self.init_param)
             init_param.update({"data_path": os.path.join(tempdir, "data")})
-            dataset_test = LitCombustionDataModule(**init_param)
+            dataset_test = CnfDataModule(**init_param)

             with self.assertRaises(ValueError):
                 _ = dataset_test.setup(stage=None)
@@ -166,7 +169,7 @@ def test_val_dataloader(self):

             init_param = copy(self.init_param)
             init_param.update({"data_path": os.path.join(tempdir, "data")})
-            dataset_test = LitCombustionDataModule(**init_param)
+            dataset_test = CnfDataModule(**init_param)

             with self.assertRaises(ValueError):
                 _ = dataset_test.setup(stage=None)
@@ -182,7 +185,7 @@ def test_test_dataloader(self):

             init_param = copy(self.init_param)
             init_param.update({"data_path": os.path.join(tempdir, "data")})
-            dataset_test = LitCombustionDataModule(**init_param)
+            dataset_test = CnfDataModule(**init_param)

             with self.assertRaises(ValueError):
                 _ = dataset_test.setup(stage=None)