Skip to content

Commit

Permalink
add splitting_ratios to weather-forecast/gravity-wave-drag/cnns
Browse files Browse the repository at this point in the history
  • Loading branch information
A669015 committed Feb 3, 2025
1 parent d24e3ed commit e10df5b
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 35 deletions.
2 changes: 1 addition & 1 deletion weather-forecast/ecrad-3d-correction/unets/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def prepare_data(self):
def setup(self, stage: Optional[str] = None):

# Define subsets.
tr, va, te = self.splitting_ratios
tr, va, _ = self.splitting_ratios
self.dataset = ThreeDCorrectionDataset(self.data_path)
length = len(self.dataset)
idx = list(range(length))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ data:
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
splitting_ratios: [0.6, 0.2, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ data:
init_args:
batch_size: 10
num_workers: 1
splitting_ratios: [0.6, 0.2]
splitting_ratios: [0.6, 0.2, 0.2]
shard_len: 100
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def create_data():
"test_5.h5",
"test_6.h5",
]
file_split = {}

with open(filepath, "r") as file:
params = safe_load(file)
Expand All @@ -47,22 +46,6 @@ def create_data():
f["/x"] = np.random.normal(0, 1, (row_feats, 191)).astype("float32")
f["/y"] = np.random.normal(0, 1, (row_feats, 126)).astype("float32")

temp_file_path = join(data_path, "filenames.yaml")
with open(temp_file_path, "w") as tmpfile:
dump(filenames, tmpfile)

tr, va = params["data"]["init_args"]["splitting_ratios"]
length = len(filenames)

for _ in filenames:
file_split["train"] = filenames[: int(tr * length)]
file_split["val"] = filenames[int(tr * length) : int((tr + va) * length)]
file_split["test"] = filenames[int((tr + va) * length) :]

temp_file_path = join(data_path, "filenames-split.yaml")
with open(temp_file_path, "w") as tmpfile:
dump(file_split, tmpfile)

else:
raise Exception(f"Remove manually {data_path}")

Expand Down
1 change: 1 addition & 0 deletions weather-forecast/gravity-wave-drag/cnns/configs/cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ data:
init_args:
batch_size: 256
num_workers: 0
splitting_ratios: [0.6, 0.2, 0.2]
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ data:
init_args:
batch_size: 256
num_workers: 0
splitting_ratios: [0.6, 0.2, 0.2]
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
Expand Down
1 change: 1 addition & 0 deletions weather-forecast/gravity-wave-drag/cnns/configs/mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ data:
init_args:
batch_size: 256
num_workers: 0
splitting_ratios: [0.6, 0.2, 0.2]
data_path: ${oc.decode:${oc.env:DATADIR, ./data}}

trainer:
Expand Down
56 changes: 41 additions & 15 deletions weather-forecast/gravity-wave-drag/cnns/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import random
from typing import List, Optional, Tuple

import h5py
import lightning as pl
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader, Dataset
from yaml import dump

RAW = "raw"
FILENAMES_SPLIT = "filenames-split.yaml"

class NOGWDDataset(torch.utils.data.Dataset):

class NOGWDDataset(Dataset):
"""
Creates the PyTorch Dataset for the Non-Orographic variant of the Gravity Wave Drag (GWD) UC.
Each raw HDF5 input file contains two datasets—x and y—of 2355840 rows each.
Expand Down Expand Up @@ -97,7 +104,7 @@ def raw_dir(self) -> str:
Returns:
(str): Raw data folder path.
"""
return os.path.join(self.root, "raw")
return os.path.join(self.root, RAW)

@property
def raw_filenames(self) -> List[str]:
Expand All @@ -106,7 +113,7 @@ def raw_filenames(self) -> List[str]:
Returns:
(List[str]): Raw data file names list.
"""
with open(os.path.join(self.root, "filenames-split.yaml"), "r") as stream:
with open(os.path.join(self.root, FILENAMES_SPLIT), "r") as stream:
filenames = yaml.safe_load(stream)
filenames = filenames[self.mode]
return filenames
Expand Down Expand Up @@ -139,7 +146,7 @@ def __init__(
batch_size: int,
num_workers: int,
data_path: str,
splitting_ratios: Tuple[float, float] = (0.8, 0.1),
splitting_ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1),
shard_len: int = 2355840,
) -> None:
"""Init the NOGWDDataModule class.
Expand All @@ -160,7 +167,24 @@ def __init__(

def prepare_data(self) -> None:
    """Create the train/val/test file split if it does not exist yet.

    Lists the ``*.h5`` files found in the raw data folder, shuffles them and
    writes the resulting ``train``/``val``/``test`` file name lists to the
    ``FILENAMES_SPLIT`` YAML file located in ``self.data_path``.

    If that file already exists it is left untouched and
    ``self.splitting_ratios`` is ignored.

    Raises:
        FileNotFoundError: If no ``*.h5`` file is found in the raw data folder.
    """
    filenames_split_path = os.path.join(self.data_path, FILENAMES_SPLIT)
    if os.path.exists(filenames_split_path):
        print(
            f"Warning: a {FILENAMES_SPLIT} file exists, the splitting ratios will be ignored."
        )
        return

    raw_dir = os.path.join(self.data_path, RAW)
    # glob returns names in arbitrary, filesystem-dependent order; sorting
    # first makes the shuffle reproducible when the RNG is seeded.
    files = sorted(glob.glob("*.h5", root_dir=raw_dir))
    if not files:
        # Writing an empty split would be silently reused on every later run.
        raise FileNotFoundError(f"No .h5 files found in {raw_dir}")

    # random.shuffle works in place and returns None, so its result must
    # not be assigned.
    random.shuffle(files)

    length = len(files)
    # The test ratio is implicit: the test subset takes whatever remains.
    tr, va, _ = self.splitting_ratios
    train_end = int(tr * length)
    val_end = int((tr + va) * length)

    file_split = {
        "train": files[:train_end],
        "val": files[train_end:val_end],
        "test": files[val_end:],
    }
    with open(filenames_split_path, "w") as filenames_split:
        dump(file_split, filenames_split)

def setup(self, stage: Optional[str] = None) -> None:
"""
Expand All @@ -175,39 +199,41 @@ def setup(self, stage: Optional[str] = None) -> None:
if stage == "test":
self.test = NOGWDDataset(self.data_path, "test", self.shard_len)

def train_dataloader(self) -> torch.utils.data.DataLoader:
def train_dataloader(self) -> DataLoader:
"""Return the train DataLoader.
Returns:
(torch.utils.data.DataLoader): Train DataLoader.
(DataLoader): Train DataLoader.
"""
return torch.utils.data.DataLoader(
return DataLoader(
self.train,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
num_workers=self.num_workers,
)

def val_dataloader(self) -> torch.utils.data.DataLoader:
def val_dataloader(self) -> DataLoader:
"""Return the val DataLoader.
Returns:
(torch.utils.data.DataLoader): Validation DataLoader.
(DataLoader): Validation DataLoader.
"""
return torch.utils.data.DataLoader(
return DataLoader(
self.val,
batch_size=self.batch_size,
drop_last=True,
num_workers=self.num_workers,
)

def test_dataloader(self) -> torch.utils.data.DataLoader:
def test_dataloader(self) -> DataLoader:
"""Return the test DataLoader.
Returns:
(torch.utils.data.DataLoader): Test DataLoader.
(DataLoader): Test DataLoader.
"""
return torch.utils.data.DataLoader(
self.test, batch_size=self.batch_size, num_workers=self.num_workers
return DataLoader(
self.test,
batch_size=self.batch_size,
num_workers=self.num_workers,
)

0 comments on commit e10df5b

Please sign in to comment.