Updated to accept NaN values in string columns: fix for issue #138 #185

Open · wants to merge 2 commits into main
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -91,6 +91,10 @@ line-length = 88
output-format = "full"
src = ["src", "tests", "examples"]

+[tool.ruff.lint.mccabe]
+# Flag errors (`C901`) whenever the complexity level exceeds 15.
+max-complexity = 15

[tool.ruff.lint]
# Extend what ruff is allowed to fix, even if it may break
# This is okay given we use it all the time and it ensures
@@ -256,6 +260,7 @@ convention = "google"

[tool.ruff.lint.pylint]
max-args = 10 # Changed from default of 5
+max-branches = 16

[tool.mypy]
python_version = "3.9"
31 changes: 26 additions & 5 deletions src/tabpfn/classifier.py
@@ -359,6 +359,7 @@ def __init__( # noqa: PLR0913
self.random_state = random_state
self.n_jobs = n_jobs
self.inference_config = inference_config
+self.placeholder = "__MISSING__"

# TODO: We can remove this from scikit-learn lower bound of 1.6
def _more_tags(self) -> dict[str, Any]:
@@ -448,11 +449,20 @@ def fit(self, X: XType, y: YType) -> Self:

# Will convert specified categorical indices to category dtype, as well
# as handle `np.object` arrays or otherwise `object` dtype pandas columns.
-X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)

+X_fixed = _fix_dtypes(
+    X,
+    cat_indices=self.categorical_features_indices,
+    placeholder=self.placeholder,
+)
+string_cols = X_fixed.select_dtypes(include=["string", "object"]).columns
# Ensure categories are ordinally encoded
ord_encoder = _get_ordinal_encoder()
-X = ord_encoder.fit_transform(X)  # type: ignore
+X = ord_encoder.fit_transform(X_fixed)
+string_indices = [X_fixed.columns.get_loc(col) for col in string_cols]
+mask = (X_fixed[string_cols] == self.placeholder).to_numpy()
+for i, col_idx in enumerate(string_indices):
+    X[:, col_idx] = np.where(mask[:, i], np.nan, X[:, col_idx])

assert isinstance(X, np.ndarray)
self.preprocessor_ = ord_encoder

@@ -529,8 +539,19 @@ def predict_proba(self, X: XType) -> np.ndarray:
check_is_fitted(self)

X = validate_X_predict(X, self)
-X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)
-X = self.preprocessor_.transform(X)
+
+X_fixed = _fix_dtypes(
+    X,
+    cat_indices=self.categorical_features_indices,
+    placeholder=self.placeholder,
+)
+string_cols = X_fixed.select_dtypes(include=["string", "object"]).columns
+
+X = self.preprocessor_.transform(X_fixed)
+string_indices = [X_fixed.columns.get_loc(col) for col in string_cols]
+mask = (X_fixed[string_cols] == self.placeholder).to_numpy()
+for i, col_idx in enumerate(string_indices):
+    X[:, col_idx] = np.where(mask[:, i], np.nan, X[:, col_idx])

outputs: list[torch.Tensor] = []

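Note: the fill → encode → restore pattern that `fit` and `predict_proba` now share can be sketched standalone as below. This is a minimal illustration using a plain scikit-learn `OrdinalEncoder` rather than the repo's `_get_ordinal_encoder()` helper, and invented data, so details of the real pipeline may differ:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

PLACEHOLDER = "__MISSING__"

X = pd.DataFrame({
    "color": ["red", np.nan, "blue"],  # string column containing a missing value
    "size": [1.0, 2.0, 3.0],
})

# 1) Fill missing strings so the encoder sees an ordinary category.
string_cols = X.select_dtypes(include=["string", "object"]).columns
X[string_cols] = X[string_cols].fillna(PLACEHOLDER)

# 2) Ordinally encode the string columns; the placeholder gets a normal code.
enc = OrdinalEncoder()
X_enc = X.copy()
X_enc[string_cols] = enc.fit_transform(X[string_cols])

# 3) Restore NaN where the placeholder stood, so downstream code sees a
#    true missing value instead of an arbitrary category code.
mask = (X[string_cols] == PLACEHOLDER).to_numpy()
for i, col in enumerate(string_cols):
    X_enc.loc[mask[:, i], col] = np.nan

print(X_enc)
#    color  size
# 0    2.0   1.0   ("red")
# 1    NaN   2.0   (was NaN, restored)
# 2    1.0   3.0   ("blue")
```

As an aside, scikit-learn ≥ 1.1 can also map missing values directly via `OrdinalEncoder(encoded_missing_value=np.nan)`; the placeholder round-trip above works regardless of the encoder version installed.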
2 changes: 1 addition & 1 deletion src/tabpfn/misc/debug_versions.py
@@ -657,7 +657,7 @@ def _get_deps_info():
try:
deps_info[modname] = get_version(modname) # Use renamed function
except PackageNotFoundError:
-deps_info[modname] = None
+deps_info[modname] = "None"
return deps_info


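Note: switching the fallback from `None` to the string `"None"` keeps `deps_info` uniformly str-valued. A hypothetical consumer that assumes string values (not the actual report code in this repo) illustrates the kind of failure this sidesteps:

```python
deps_info = {"torch": "2.2.0", "onnx": None}
try:
    max(len(v) for v in deps_info.values())
except TypeError as err:
    print(err)  # object of type 'NoneType' has no len()

deps_info["onnx"] = "None"
print(max(len(v) for v in deps_info.values()))  # 5 -- every value is now a str
```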
2 changes: 1 addition & 1 deletion src/tabpfn/model/multi_head_attention.py
@@ -545,7 +545,7 @@ def broadcast_kv_across_heads(
return kv.reshape(*kv.shape[:-3], nhead * share_kv_across_n_heads, d)

@staticmethod
-def compute_attention_heads( # noqa: C901, PLR0912
+def compute_attention_heads( # noqa: PLR0912
q: torch.Tensor | None,
k: torch.Tensor | None,
v: torch.Tensor | None,
Expand Down
2 changes: 1 addition & 1 deletion src/tabpfn/model/preprocessing.py
@@ -840,7 +840,7 @@ def __init__(
self.global_transformer_name = global_transformer_name
self.transformer_: Pipeline | ColumnTransformer | None = None

-def _set_transformer_and_cat_ix( # noqa: PLR0912
+def _set_transformer_and_cat_ix(
self,
n_samples: int,
n_features: int,
Expand Down
4 changes: 2 additions & 2 deletions src/tabpfn/model/transformer.py
@@ -104,7 +104,7 @@ class PerFeatureTransformer(nn.Module):
"""

# TODO: Feel like this could be simplified a lot from this part downwards
-def __init__( # noqa: C901, D417, PLR0913
+def __init__( # noqa: D417, PLR0913
self,
*,
encoder: nn.Module | None = None,
@@ -679,7 +679,7 @@ def _forward( # noqa: PLR0912, C901

return output_decoded

-def add_embeddings( # noqa: C901, PLR0912
+def add_embeddings(
self,
x: torch.Tensor,
y: torch.Tensor,
Expand Down
32 changes: 26 additions & 6 deletions src/tabpfn/regressor.py
@@ -366,6 +366,7 @@ def __init__( # noqa: PLR0913
self.random_state = random_state
self.n_jobs = n_jobs
self.inference_config = inference_config
+self.placeholder = "__MISSING__"

# TODO: We can remove this from scikit-learn lower bound of 1.6
def _more_tags(self) -> dict[str, Any]:
@@ -439,11 +440,19 @@ def fit(self, X: XType, y: YType) -> Self:

# Will convert specified categorical indices to category dtype, as well
# as handle `np.object` arrays or otherwise `object` dtype pandas columns.
-X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)

+X_fixed = _fix_dtypes(
+    X,
+    cat_indices=self.categorical_features_indices,
+    placeholder=self.placeholder,
+)
+string_cols = X_fixed.select_dtypes(include=["string", "object"]).columns
# Ensure categories are ordinally encoded
ord_encoder = _get_ordinal_encoder()
-X = ord_encoder.fit_transform(X)  # type: ignore
+X = ord_encoder.fit_transform(X_fixed)
+string_indices = [X_fixed.columns.get_loc(col) for col in string_cols]
+mask = (X_fixed[string_cols] == self.placeholder).to_numpy()
+for i, col_idx in enumerate(string_indices):
+    X[:, col_idx] = np.where(mask[:, i], np.nan, X[:, col_idx])
self.preprocessor_ = ord_encoder

self.inferred_categorical_indices_ = infer_categorical_features(
@@ -555,7 +564,7 @@ def predict(
) -> dict[str, np.ndarray | FullSupportBarDistribution]: ...

# FIXME: improve to not have noqa C901, PLR0912
-def predict( # noqa: C901, PLR0912
+def predict(
self,
X: XType,
*,
@@ -605,8 +614,19 @@ def predict( # noqa: C901, PLR0912
check_is_fitted(self)

X = validate_X_predict(X, self)
-X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)
-X = self.preprocessor_.transform(X)
+
+X_fixed = _fix_dtypes(
+    X,
+    cat_indices=self.categorical_features_indices,
+    placeholder=self.placeholder,
+)
+string_cols = X_fixed.select_dtypes(include=["string", "object"]).columns
+
+X = self.preprocessor_.transform(X_fixed)
+string_indices = [X_fixed.columns.get_loc(col) for col in string_cols]
+mask = (X_fixed[string_cols] == self.placeholder).to_numpy()
+for i, col_idx in enumerate(string_indices):
+    X[:, col_idx] = np.where(mask[:, i], np.nan, X[:, col_idx])

if quantiles is None:
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
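Note: end to end, the change is meant to let both estimators accept NaN inside string columns (issue #138). A usage sketch under that assumption, with invented data; it assumes a working TabPFN install (the first `fit` may download model weights), and constructor arguments are omitted:

```python
import numpy as np
import pandas as pd
from tabpfn import TabPFNRegressor

X = pd.DataFrame({
    "material": ["steel", np.nan, "wood", "steel"],  # string column with a NaN
    "length": [1.2, 3.4, 2.2, 0.7],
})
y = np.array([10.0, 12.5, 9.1, 4.3])

reg = TabPFNRegressor()
reg.fit(X, y)           # placeholder fill -> ordinal encode -> NaN restore runs inside
preds = reg.predict(X)  # the same masking logic runs again at predict time
```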
7 changes: 6 additions & 1 deletion src/tabpfn/utils.py
@@ -372,6 +372,7 @@ def _fix_dtypes(
X: pd.DataFrame | np.ndarray,
cat_indices: Sequence[int | str] | None,
numeric_dtype: Literal["float32", "float64"] = "float64",
+placeholder: str = "__MISSING__",
) -> pd.DataFrame:
if isinstance(X, pd.DataFrame):
# This will help us get better dtype inference later
@@ -429,10 +430,14 @@
#
if convert_dtype:
X = X.convert_dtypes()

integer_columns = X.select_dtypes(include=["number"]).columns

if len(integer_columns) > 0:
X[integer_columns] = X[integer_columns].astype(numeric_dtype)

+string_cols = X.select_dtypes(include=["string", "object"]).columns
+if len(string_cols) > 0:
+    X[string_cols] = X[string_cols].fillna(placeholder)
return X


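Note: in isolation, the new `_fix_dtypes` step is just a guarded `fillna` over string/object columns. A minimal sketch of that behavior with plain pandas and hypothetical data:

```python
import numpy as np
import pandas as pd

placeholder = "__MISSING__"
X = pd.DataFrame({"city": ["NYC", np.nan, "LA"], "n": [1, 2, 3]})

# Same guard as the diff above: only touch string/object columns, if any exist.
string_cols = X.select_dtypes(include=["string", "object"]).columns
if len(string_cols) > 0:
    X[string_cols] = X[string_cols].fillna(placeholder)

print(X["city"].tolist())  # ['NYC', '__MISSING__', 'LA']
```

One consequence of the design worth noting: a legitimate category literally equal to `"__MISSING__"` would be indistinguishable from a filled NaN, so the choice of sentinel string matters.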