Skip to content

Commit

Permalink
Merge pull request #407 from juntyr/float64
Browse files Browse the repository at this point in the history
Add values_dtype backend option to load values at full precision
  • Loading branch information
iainrussell authored Dec 18, 2024
2 parents 11acc76 + 21ec8e8 commit bda2944
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 4 deletions.
10 changes: 7 additions & 3 deletions cfgrib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,11 @@ class OnDiskArray:
)
missing_value: float
geo_ndim: int = attr.attrib(default=1, repr=False)
dtype = np.dtype("float32")
dtype: np.dtype = attr.attrib(default=messages.DEFAULT_VALUES_DTYPE, repr=False)

def build_array(self) -> np.ndarray:
"""Helper method used to test __getitem__"""
array = np.full(self.shape, fill_value=np.nan, dtype="float32")
array = np.full(self.shape, fill_value=np.nan, dtype=self.dtype)
for header_indexes, message_ids in self.field_id_index.items():
# NOTE: fill a single field as found in the message
message = self.index.get_field(message_ids[0]) # type: ignore
Expand All @@ -363,7 +363,7 @@ def __getitem__(self, item):
header_item_list = expand_item(item[: -self.geo_ndim], self.shape)
header_item = [{ix: i for i, ix in enumerate(it)} for it in header_item_list]
array_field_shape = tuple(len(i) for i in header_item_list) + self.shape[-self.geo_ndim :]
array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")
array_field = np.full(array_field_shape, fill_value=np.nan, dtype=self.dtype)
for header_indexes, message_ids in self.field_id_index.items():
try:
array_field_indexes = [it[ix] for it, ix in zip(header_item, header_indexes)]
Expand Down Expand Up @@ -497,6 +497,7 @@ def build_variable_components(
extra_coords: T.Dict[str, str] = {},
coords_as_attributes: T.Dict[str, str] = {},
cache_geo_coords: bool = True,
values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
) -> T.Tuple[T.Dict[str, int], Variable, T.Dict[str, Variable]]:
data_var_attrs = enforce_unique_attributes(index, DATA_ATTRIBUTES_KEYS, filter_by_keys)
grid_type_keys = GRID_TYPE_MAP.get(index.getone("gridType"), [])
Expand Down Expand Up @@ -601,6 +602,7 @@ def build_variable_components(
field_id_index=offsets,
missing_value=missing_value,
geo_ndim=len(geo_dims),
dtype=values_dtype,
)

if "time" in coord_vars and "step" in coord_vars:
Expand Down Expand Up @@ -673,6 +675,7 @@ def build_dataset_components(
extra_coords: T.Dict[str, str] = {},
coords_as_attributes: T.Dict[str, str] = {},
cache_geo_coords: bool = True,
values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
) -> T.Tuple[T.Dict[str, int], T.Dict[str, Variable], T.Dict[str, T.Any], T.Dict[str, T.Any]]:
dimensions = {} # type: T.Dict[str, int]
variables = {} # type: T.Dict[str, Variable]
Expand Down Expand Up @@ -700,6 +703,7 @@ def build_dataset_components(
extra_coords=extra_coords,
coords_as_attributes=coords_as_attributes,
cache_geo_coords=cache_geo_coords,
values_dtype=values_dtype,
)
except DatasetBuildError as ex:
# NOTE: When a variable has more than one value for an attribute we need to raise all
Expand Down
1 change: 1 addition & 0 deletions cfgrib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def multi_enabled(file: T.IO[bytes]) -> T.Iterator[None]:
}

DEFAULT_INDEXPATH = "{path}.{short_hash}.idx"
DEFAULT_VALUES_DTYPE = np.dtype("float32")

OffsetType = T.Union[int, T.Tuple[int, int]]

Expand Down
2 changes: 2 additions & 0 deletions cfgrib/xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def open_dataset(
extra_coords: T.Dict[str, str] = {},
coords_as_attributes: T.Dict[str, str] = {},
cache_geo_coords: bool = True,
values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
) -> xr.Dataset:
store = CfGribDataStore(
filename_or_obj,
Expand All @@ -122,6 +123,7 @@ def open_dataset(
extra_coords=extra_coords,
coords_as_attributes=coords_as_attributes,
cache_geo_coords=cache_geo_coords,
values_dtype=values_dtype,
)
with xr.core.utils.close_on_error(store):
vars, attrs = store.load() # type: ignore
Expand Down
14 changes: 14 additions & 0 deletions tests/test_30_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,17 @@ def test_missing_field_values() -> None:
t2 = res.variables["t2m"]
assert np.isclose(np.nanmean(t2.data[0, :, :]), 268.375)
assert np.isclose(np.nanmean(t2.data[1, :, :]), 270.716)


def test_default_values_dtype() -> None:
    """Check the dtypes produced when ``values_dtype`` is not passed.

    The data variable ``t2m`` is expected to be ``float32``, while the
    ``latitude`` and ``longitude`` coordinates are expected to be ``float64``.
    """
    variables = dataset.open_file(TEST_DATA_MISSING_VALS).variables
    expected_values_dtype = np.dtype("float32")
    expected_coord_dtype = np.dtype("float64")

    assert variables["t2m"].data.dtype == expected_values_dtype
    # The coordinate dtypes are checked separately from the data values.
    for coord_name in ("latitude", "longitude"):
        assert variables[coord_name].data.dtype == expected_coord_dtype


def test_float64_values_dtype() -> None:
    """Check the dtypes produced when ``values_dtype`` is ``float64``.

    The data variable ``t2m`` and the ``latitude`` / ``longitude``
    coordinates are all expected to be ``float64``.
    """
    float64 = np.dtype("float64")
    variables = dataset.open_file(TEST_DATA_MISSING_VALS, values_dtype=float64).variables

    for var_name in ("t2m", "latitude", "longitude"):
        assert variables[var_name].data.dtype == float64
16 changes: 15 additions & 1 deletion tests/test_50_xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,18 @@ def test_xr_open_dataset_coords_to_attributes() -> None:
assert "depthBelowLandLayer" not in ds.coords

assert "GRIB_surface" in ds["t2m"].attrs
assert "GRIB_depthBelowLandLayer" in ds["stl1"].attrs
assert "GRIB_depthBelowLandLayer" in ds["stl1"].attrs


def test_xr_open_dataset_default_values_dtype() -> None:
    """Check dtypes from ``xr.open_dataset`` when ``values_dtype`` is not passed.

    ``t2m`` is expected to be ``float32``; the ``latitude`` and ``longitude``
    coordinates are expected to be ``float64``.
    """
    opened = xr.open_dataset(TEST_DATA_MISSING_VALS, engine="cfgrib")
    expected_values_dtype = np.dtype("float32")
    expected_coord_dtype = np.dtype("float64")

    assert opened["t2m"].data.dtype == expected_values_dtype
    # The coordinate dtypes are checked separately from the data values.
    for coord_name in ("latitude", "longitude"):
        assert opened[coord_name].data.dtype == expected_coord_dtype


def test_xr_open_dataset_float64_values_dtype() -> None:
    """Check dtypes from ``xr.open_dataset`` when ``values_dtype`` is ``float64``.

    ``t2m`` and the ``latitude`` / ``longitude`` coordinates are all expected
    to be ``float64``.
    """
    float64 = np.dtype("float64")
    opened = xr.open_dataset(TEST_DATA_MISSING_VALS, engine="cfgrib", values_dtype=float64)

    for var_name in ("t2m", "latitude", "longitude"):
        assert opened[var_name].data.dtype == float64

0 comments on commit bda2944

Please sign in to comment.