Skip to content

Commit

Permalink
Merge pull request #43 from msamsami/major-internal-refactor
Browse files Browse the repository at this point in the history
feat: refactor `GeneralNB` and `GaussianWNB`, improve validation, and enhance scikit-learn compatibility
  • Loading branch information
msamsami authored Jan 19, 2025
2 parents 2770e16 + 5710ea9 commit 45e174b
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 391 deletions.
4 changes: 0 additions & 4 deletions MANIFEST.in

This file was deleted.

15 changes: 2 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

<div align="center">

![Latest Release](https://img.shields.io/badge/release-v0.5.1-green)
![Latest Release](https://img.shields.io/badge/release-v0.6.0-green)
[![PyPI Version](https://img.shields.io/pypi/v/wnb)](https://pypi.org/project/wnb/)
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue)<br>
![GitHub Workflow Status (build)](https://github.com/msamsami/wnb/actions/workflows/build.yml/badge.svg)
Expand Down Expand Up @@ -102,6 +102,7 @@ Both Scikit-learn classifiers and WNB classifiers share these well-known methods
- `predict(X)`
- `predict_proba(X)`
- `predict_log_proba(X)`
- `predict_joint_log_proba(X)`
- `score(X, y)`
- `get_params()`
- `set_params(**params)`
Expand All @@ -123,18 +124,6 @@ These benchmarks highlight the potential of WNB classifiers to provide better pe

The scripts used to generate these benchmark results are available in the _tests/benchmarks/_ directory.

## Tests
To run the tests, make sure to clone the repository and install the development requirements in addition to base requirements:
```bash
pip install -r requirements.txt
pip install -r requirements-dev.txt
```

Then, run pytest:
```bash
pytest
```

## Support us 💡
You can support the project in the following ways:

Expand Down
21 changes: 21 additions & 0 deletions tests/test_gnb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.base import is_classifier
from sklearn.naive_bayes import BernoulliNB, CategoricalNB, GaussianNB
Expand Down Expand Up @@ -281,3 +282,23 @@ def test_gnb_var_smoothing_non_numeric():
clf = GeneralNB(distributions=[D.CATEGORICAL, D.CATEGORICAL], var_smoothing=1e-6)
clf.fit(X, y)
assert clf.epsilon_ == 0


def test_gnb_attrs():
    """Check that fitting GeneralNB exposes the expected learned attributes."""
    model = GeneralNB().fit(X, y)

    # Class-level statistics learned from two balanced classes of 3 samples.
    assert np.array_equal(model.class_count_, np.array([3, 3]))
    assert np.array_equal(model.class_prior_, np.array([0.5, 0.5]))
    assert np.array_equal(model.classes_, np.array([1, 2]))
    assert model.n_classes_ == 2
    assert model.epsilon_ > 0

    # Feature-related attributes; a plain ndarray input sets no feature names.
    assert model.n_features_in_ == 2
    assert not hasattr(model, "feature_names_in_")
    assert model.distributions_ == [D.NORMAL, D.NORMAL]
    assert len(model.likelihood_params_) == 2

    # Fitting on a DataFrame must record the column names instead.
    names = [f"x{i}" for i in range(X.shape[1])]
    model = GeneralNB().fit(pd.DataFrame(X, columns=names), y)
    assert np.array_equal(model.feature_names_in_, np.array(names))
23 changes: 23 additions & 0 deletions tests/test_gwnb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import numpy as np
import pandas as pd
import pytest
from sklearn.base import is_classifier
from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -194,3 +195,25 @@ def test_gwnb_no_cost_hist():
clf = GaussianWNB(max_iter=10)
clf.fit(X, y)
assert clf.cost_hist_ is None


def test_gwnb_attrs():
    """Check that fitting GaussianWNB exposes the expected learned attributes."""
    model = GaussianWNB().fit(X, y)

    # Class-level statistics learned from two balanced classes of 3 samples.
    assert np.array_equal(model.class_count_, np.array([3, 3]))
    assert np.array_equal(model.class_prior_, np.array([0.5, 0.5]))
    assert np.array_equal(model.classes_, np.array([1, 2]))
    assert model.n_classes_ == 2

    # Feature-related attributes; a plain ndarray input sets no feature names.
    assert model.n_features_in_ == 2
    assert not hasattr(model, "feature_names_in_")

    # Per-class/per-feature parameters and the learned weights.
    assert np.array_equal(model.error_weights_, np.array([[0, 1], [-1, 0]]))
    assert model.theta_.shape == (2, 2)
    assert model.std_.shape == (2, 2)
    assert model.var_.shape == (2, 2)
    assert model.coef_.shape == (2,)

    # Fitting on a DataFrame must record the column names instead.
    names = [f"x{i}" for i in range(X.shape[1])]
    model = GaussianWNB().fit(pd.DataFrame(X, columns=names), y)
    assert np.array_equal(model.feature_names_in_, np.array(names))
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion wnb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Python library for the implementations of general and weighted naive Bayes (WNB) classifiers.
"""

__version__ = "0.5.1"
__version__ = "0.6.0"
__author__ = "Mehdi Samsami"


Expand Down
12 changes: 11 additions & 1 deletion wnb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import sklearn
from packaging import version
from sklearn.utils import check_array
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_X_y as _check_X_y

__all__ = [
"SKLEARN_V1_6_OR_LATER",
"validate_data",
"check_X_y",
"_check_n_features",
"_check_feature_names",
]
Expand All @@ -24,12 +26,20 @@ def validate_data(*args, **kwargs):
kwargs["ensure_all_finite"] = kwargs.pop("force_all_finite")
return _validate_data(*args, **kwargs)

def check_X_y(*args, **kwargs):
    """Wrapper around sklearn's ``check_X_y`` for scikit-learn >= 1.6.

    Translates the deprecated ``force_all_finite`` keyword to its 1.6+
    replacement ``ensure_all_finite`` before delegating.
    """
    # Test membership, not truthiness: force_all_finite=False is a legal
    # (falsy) value and must also be renamed, otherwise sklearn >= 1.6
    # receives the deprecated keyword and warns/errors.
    if "force_all_finite" in kwargs:
        kwargs["ensure_all_finite"] = kwargs.pop("force_all_finite")
    return _check_X_y(*args, **kwargs)

else:

def validate_data(estimator, X, **kwargs: Any):
    """Fallback for scikit-learn < 1.6: validate ``X`` via ``check_array``.

    The ``reset`` keyword only applies to sklearn's own ``_validate_data``
    helper, so it is discarded before delegating.
    """
    if "reset" in kwargs:
        del kwargs["reset"]
    return check_array(X, estimator=estimator, **kwargs)

def check_X_y(*args, **kwargs):
    # scikit-learn < 1.6 still accepts `force_all_finite`, so no keyword
    # translation is needed; delegate directly.
    return _check_X_y(*args, **kwargs)

def _check_n_features(estimator, X, reset):
return estimator._check_n_features(X, reset=reset)

Expand Down
Loading

0 comments on commit 45e174b

Please sign in to comment.