Skip to content

Commit

Permalink
Merge pull request #225 from matthewwardrop/pandas_dict_recarray
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop authored Dec 3, 2024
2 parents 05bfa25 + c92cc83 commit 2c217dd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
18 changes: 17 additions & 1 deletion formulaic/materializers/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import itertools
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Set, Tuple, cast

import numpy
Expand All @@ -22,9 +23,24 @@

class PandasMaterializer(FormulaMaterializer):
REGISTER_NAME = "pandas"
REGISTER_INPUTS: Sequence[str] = ("pandas.core.frame.DataFrame", "pandas.DataFrame")
REGISTER_INPUTS: Sequence[str] = (
"pandas.core.frame.DataFrame",
"pandas.DataFrame",
"dict",
"numpy.rec.recarray",
)
REGISTER_OUTPUTS: Sequence[str] = ("pandas", "numpy", "sparse")

@override
def _init(self) -> None:
if isinstance(self.data, (dict, Mapping)):
if all(numpy.isscalar(v) for v in self.data.values()):
self.data = pandas.DataFrame(self.data, index=[0])
else:
self.data = pandas.DataFrame(self.data)
elif isinstance(self.data, numpy.rec.recarray):
self.data = pandas.DataFrame.from_records(self.data)

@override
def _is_categorical(self, values: Any) -> bool:
if isinstance(values, (pandas.Series, pandas.Categorical)):
Expand Down
2 changes: 2 additions & 0 deletions tests/materializers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class TestFormulaMaterializer:
def test_registrations(self):
assert sorted(FormulaMaterializer.REGISTERED_NAMES) == ["arrow", "pandas"]
assert sorted(FormulaMaterializer.REGISTERED_INPUTS) == [
"dict",
"numpy.rec.recarray",
"pandas.DataFrame",
"pandas.core.frame.DataFrame",
"pyarrow.lib.Table",
Expand Down
17 changes: 17 additions & 0 deletions tests/materializers/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ def data_with_nulls(self):
def materializer(self, data):
return PandasMaterializer(data)

def test_data_conversion(self):
df = PandasMaterializer({"a": [1, 2, 3]}).data
assert isinstance(df, pandas.DataFrame)
assert df.columns == ["a"]

df2 = PandasMaterializer({"a": 1}).data
assert isinstance(df2, pandas.DataFrame)
assert df2.columns == ["a"]
assert list(df2["a"]) == [1]

df3 = PandasMaterializer(
numpy.recarray((2,), dtype=[("x", int), ("y", float), ("z", int)])
).data
assert isinstance(df3, pandas.DataFrame)
assert list(df3.columns) == ["x", "y", "z"]
assert len(df3["x"]) == 2

@pytest.mark.parametrize("formula,tests", PANDAS_TESTS.items())
def test_get_model_matrix(self, materializer, formula, tests):
mm = materializer.get_model_matrix(formula, ensure_full_rank=True)
Expand Down

0 comments on commit 2c217dd

Please sign in to comment.