Merge pull request #45 from msamsami/new-distributions-param-type
feat: add support for specifying distributions as a list of tuples
msamsami authored Jan 31, 2025
2 parents 4db0269 + affc0ed commit 18709d4
Showing 9 changed files with 416 additions and 203 deletions.
31 changes: 19 additions & 12 deletions README.md
@@ -7,7 +7,7 @@

<div align="center">

![Latest Release](https://img.shields.io/badge/release-v0.6.0-green)
![Latest Release](https://img.shields.io/badge/release-v0.7.0-green)
[![PyPI Version](https://img.shields.io/pypi/v/wnb)](https://pypi.org/project/wnb/)
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue)<br>
![GitHub Workflow Status (build)](https://github.com/msamsami/wnb/actions/workflows/build.yml/badge.svg)
@@ -17,13 +17,15 @@
</div>

## Introduction
Naive Bayes is often recognized as one of the most popular classification algorithms in the machine learning community. This package takes naive Bayes to a higher level by providing its implementations in more general and weighted settings.
Naive Bayes is a widely used classification algorithm known for its simplicity and efficiency. This package takes naive Bayes to a higher level by providing more flexible and weighted variants, making it suitable for a broader range of applications.

### General naive Bayes
The issue with the well-known implementations of the naive Bayes algorithm (such as the ones in `sklearn.naive_bayes` module) is that they assume a single distribution for the likelihoods of all features. Such an implementation can limit those who need to develop naive Bayes models with different distributions for feature likelihood. And enters **WNB** library! It allows you to customize your naive Bayes model by specifying the likelihood distribution of each feature separately. You can choose from a range of continuous and discrete probability distributions to design your classifier.
Most standard implementations, such as those in `sklearn.naive_bayes`, assume a single distribution type for all feature likelihoods. This can be restrictive when dealing with mixed data types. **WNB** overcomes this limitation by allowing users to specify different probability distributions for each feature individually. You can select from a variety of continuous and discrete distributions, enabling greater customization and improved model performance.
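As a minimal sketch (the full walkthrough appears in the Getting started section below), a three-feature classifier might mix distributions like this:

```python
from wnb import GeneralNB, Distribution as D

# One likelihood distribution per feature: Gaussian, categorical, exponential
gnb = GeneralNB(distributions=[D.NORMAL, D.CATEGORICAL, D.EXPONENTIAL])
```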

### Weighted naive Bayes
Although naive Bayes has many advantages such as simplicity and interpretability, its conditional independence assumption rarely holds true in real-world applications. In order to alleviate its conditional independence assumption, many attribute weighting naive Bayes (WNB) approaches have been proposed. Most of the proposed methods involve computationally demanding optimization problems that do not allow for controlling the model's bias due to class imbalance. Minimum Log-likelihood Difference WNB (MLD-WNB) is a novel weighting approach that optimizes the weights according to the Bayes optimal decision rule and includes hyperparameters for controlling the model's bias. **WNB** library provides an efficient implementation of gaussian MLD-WNB.
While naive Bayes is simple and interpretable, its conditional independence assumption often fails in real-world scenarios. To address this, various attribute-weighted naive Bayes methods exist, but most are computationally expensive and lack mechanisms for handling class imbalance.

The **WNB** package provides an optimized implementation of *Minimum Log-likelihood Difference Weighted Naive Bayes* (MLD-WNB), a novel approach that optimizes feature weights using the Bayes optimal decision rule. It also introduces hyperparameters for controlling model bias, making it more robust for imbalanced classification.
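As a preview of the API (the same call is walked through step by step in the Getting started section below), the weight-learning procedure is configured at construction time via the iteration count, gradient step size, and regularization penalty:

```python
from wnb import GaussianWNB

# Configure the weight-learning procedure; values are illustrative
gwnb = GaussianWNB(max_iter=25, step_size=1e-2, penalty="l2")
```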

## Installation
This library is shipped as an all-in-one module implementation with minimal dependencies and requirements. Furthermore, it fully **adheres to the Scikit-learn API** ❤️.
@@ -42,7 +44,7 @@ uv add wnb
```

## Getting started ⚡️
Here, we show how you can use the library to train general and weighted naive Bayes classifiers.
Here, we show how you can use the library to train general (mixed) and weighted naive Bayes classifiers.

### General naive Bayes

@@ -53,14 +55,19 @@ A general naive Bayes model can be set up and used in four simple steps:
from wnb import GeneralNB, Distribution as D
```

2. Initialize a classifier and specify the likelihood distributions
2. Initialize a classifier with likelihood distributions specified
```python
gnb = GeneralNB(distributions=[D.NORMAL, D.CATEGORICAL, D.EXPONENTIAL, D.EXPONENTIAL])
```
or
```python
gnb = GeneralNB(distributions=[D.NORMAL, D.CATEGORICAL, D.EXPONENTIAL])
# Columns not explicitly specified will default to Gaussian (normal) distribution
gnb = GeneralNB(distributions=[(D.CATEGORICAL, "col2"), (D.EXPONENTIAL, ["col3", "col4"])])
```
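In the tuple format, each distribution is paired with one or more columns: names when `X` is a pandas DataFrame, positional indices when it is a plain array. A self-contained sketch with hypothetical toy data that the remaining steps can be run against:

```python
import pandas as pd

from wnb import GeneralNB, Distribution as D

# Hypothetical four-feature training set (names and values are invented)
X_train = pd.DataFrame(
    {
        "col1": [1.1, 0.9, 1.3, 0.7, 1.0, 1.2],  # not listed below -> defaults to Gaussian
        "col2": [0, 1, 0, 1, 0, 1],              # categorical codes
        "col3": [0.5, 2.0, 0.8, 1.5, 0.3, 2.4],  # non-negative -> exponential
        "col4": [0.2, 1.1, 0.6, 0.9, 0.4, 1.8],  # non-negative -> exponential
    }
)
y_train = [0, 0, 0, 1, 1, 1]

gnb = GeneralNB(distributions=[(D.CATEGORICAL, "col2"), (D.EXPONENTIAL, ["col3", "col4"])])

# With a plain NumPy array, refer to features by index instead, e.g.:
# GeneralNB(distributions=[(D.CATEGORICAL, 1), (D.EXPONENTIAL, [2, 3])])
```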

3. Fit the classifier to a training set (with three features)
3. Fit the classifier to a training set (with four features)
```python
gnb.fit(X, y)
gnb.fit(X_train, y_train)
```

4. Predict on test data
@@ -79,17 +86,17 @@ from wnb import GaussianWNB

2. Initialize a classifier
```python
wnb = GaussianWNB(max_iter=25, step_size=1e-2, penalty="l2")
gwnb = GaussianWNB(max_iter=25, step_size=1e-2, penalty="l2")
```

3. Fit the classifier to a training set
```python
wnb.fit(X, y)
gwnb.fit(X_train, y_train)
```

4. Predict on test data
```python
wnb.predict(x_test)
gwnb.predict(X_test)
```
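Putting the four steps together on toy data (invented here for illustration):

```python
import numpy as np

from wnb import GaussianWNB

# Two Gaussian blobs as a toy binary classification problem
rng = np.random.RandomState(0)
X_train = np.vstack([rng.normal(0.0, 1.0, (50, 3)), rng.normal(1.5, 1.0, (50, 3))])
y_train = np.array([0] * 50 + [1] * 50)
X_test = rng.normal(0.75, 1.0, (10, 3))

gwnb = GaussianWNB(max_iter=25, step_size=1e-2, penalty="l2")
gwnb.fit(X_train, y_train)
print(gwnb.predict(X_test))
```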

## Compatibility with Scikit-learn 🤝
5 changes: 3 additions & 2 deletions tests/conftest.py
@@ -1,12 +1,13 @@
import numpy as np
import pytest
from numpy.typing import NDArray


@pytest.fixture
def global_random_seed():
def global_random_seed() -> int:
    return np.random.randint(0, 1000)


@pytest.fixture
def random_uniform():
def random_uniform() -> NDArray[np.float64]:
    return np.random.uniform(0, 100, size=10000)
35 changes: 18 additions & 17 deletions tests/test_dist.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from numpy.typing import NDArray
from scipy import stats
from sklearn.utils._testing import assert_array_almost_equal

@@ -20,29 +21,29 @@
TDist,
UniformDist,
)
from wnb.stats.typing import DistributionLike

out_of_support_warn_msg = "Value doesn't lie within the support of the distribution"


def test_distributions_correct_name_attr():
@pytest.mark.parametrize("dist_name", AllDistributions.keys())
def test_distributions_correct_name_attr(dist_name):
    """
    Test if all defined distributions have correct `name` attributes.
    """
    for dist_name in AllDistributions.keys():
        assert isinstance(dist_name, (str, D))
    assert isinstance(dist_name, (str, D))


def test_distributions_correct_support_attr():
@pytest.mark.parametrize("dist", AllDistributions.values())
def test_distributions_correct_support_attr(dist: DistributionLike):
    """
    Test if all defined distributions have correct `_support` attributes.
    """
    for dist in AllDistributions.values():
        if dist.name in [D.UNIFORM, D.PARETO, D.CATEGORICAL]:
            assert dist._support is None
            continue
    if dist.name in [D.UNIFORM, D.PARETO, D.CATEGORICAL]:
        assert dist._support is None

    else:
        assert isinstance(dist._support, (list, tuple))

        if isinstance(dist._support, list):
            for x in dist._support:
                assert isinstance(x, (float, int))
@@ -77,7 +78,7 @@ def test_normal_with_epsilon(epsilon: float):
    assert_array_almost_equal(norm_2(X), norm_3(X), decimal=10)


def test_lognormal_pdf(random_uniform):
def test_lognormal_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `LognormalDist` returns the same result as pdf method of `scipy.stats.lognorm`.
"""
@@ -98,7 +99,7 @@ def test_lognormal_out_of_support_data():
        lognorm_wnb(-1)


def test_exponential_pdf(random_uniform):
def test_exponential_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `ExponentialDist` returns the same result as pdf method of `scipy.stats.expon`.
"""
@@ -119,7 +120,7 @@ def test_exponential_out_of_support_data():
        expon_wnb(-1)


def test_uniform_pdf(random_uniform):
def test_uniform_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `UniformDist` returns the same result as pdf method of `scipy.stats.uniform`.
"""
@@ -140,7 +141,7 @@ def test_uniform_out_of_support_data():
        uniform_wnb(3)


def test_pareto_pdf(random_uniform):
def test_pareto_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `ParetoDist` returns the same result as pdf method of `scipy.stats.pareto`.
"""
@@ -161,7 +162,7 @@ def test_pareto_out_of_support_data():
        pareto_wnb(-5)


def test_gamma_pdf(random_uniform):
def test_gamma_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `GammaDist` returns the same result as pdf method of `scipy.stats.gamma`.
"""
@@ -203,7 +204,7 @@ def test_beta_out_of_support_data():
        beta_wnb(1.01)


def test_chi2_pdf(random_uniform):
def test_chi2_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `ChiSquaredDist` returns the same result as pdf method of `scipy.stats.chi2`.
"""
@@ -224,7 +225,7 @@ def test_chi2_out_of_support_data():
        chi2_wnb(-5)


def test_t_pdf(random_uniform):
def test_t_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `TDist` returns the same result as pdf method of `scipy.stats.t`.
"""
@@ -234,7 +235,7 @@ def test_t_pdf(random_uniform):
    assert_array_almost_equal(t_wnb(X), t_scipy.pdf(X), decimal=10)


def test_rayleigh_pdf(random_uniform):
def test_rayleigh_pdf(random_uniform: NDArray[np.float64]):
"""
Test whether pdf method of `RayleighDist` returns the same result as pdf method of `scipy.stats.rayleigh`.
"""
134 changes: 120 additions & 14 deletions tests/test_gnb.py
@@ -1,6 +1,9 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
from numpy.typing import NDArray
from sklearn.base import is_classifier
from sklearn.naive_bayes import BernoulliNB, CategoricalNB, GaussianNB
from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal
@@ -14,7 +17,7 @@
y = np.array([1, 1, 1, 2, 2, 2])


def get_random_normal_x_binary_y(global_random_seed):
def get_random_normal_x_binary_y(global_random_seed: int) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
    # A bit more random tests
    rng = np.random.RandomState(global_random_seed)
    X1 = rng.normal(size=(10, 3))
@@ -69,7 +72,7 @@ def test_gnb_vs_sklearn_bernoulli():
    X_ = rng.randint(2, size=(150, 100))
    y_ = rng.randint(1, 5, size=(150,))

    clf1 = GeneralNB(distributions=[D.BERNOULLI for _ in range(100)])
    clf1 = GeneralNB(distributions=[D.BERNOULLI] * 100)
    clf1.fit(X_, y_)

    clf2 = BernoulliNB(alpha=1e-10, force_alpha=True)
@@ -135,7 +138,121 @@ def test_gnb_estimator():
    assert is_classifier(GeneralNB)


def test_gnb_prior(global_random_seed):
def test_gnb_dist_none():
    """
    Test whether the default distribution is set to Gaussian (normal) when no distributions are specified.
    """
    clf = GeneralNB().fit(X, y)
    assert clf.distributions_ == [D.NORMAL] * X.shape[1]


def test_gnb_dist_default_normal():
    """
    Test whether features without an explicitly specified distribution default to Gaussian (normal).
    """
    clf = GeneralNB(distributions=[(D.RAYLEIGH, [0])]).fit(X, y)
    assert clf.distributions_ == [D.RAYLEIGH, D.NORMAL]


def test_gnb_wrong_dist_length():
    """
    Test whether an error is raised if the number of distributions is different from the number of features.
    """
    clf = GeneralNB(distributions=[D.NORMAL] * (X.shape[1] + 1))
    msg = "Number of specified distributions must match the number of features"
    with pytest.raises(ValueError, match=msg):
        clf.fit(X, y)


@pytest.mark.parametrize(
    "clf",
    [
        GeneralNB(distributions="Normal"),
        GeneralNB(distributions={"Normal": 0}),
        GeneralNB(distributions=[(D.NORMAL, [0]), D.BERNOULLI]),
        GeneralNB(distributions=["Normal", (D.BERNOULLI, [1])]),
    ],
)
def test_gnb_wrong_dist_value(clf: GeneralNB):
    """
    Test whether an error is raised if an invalid value is provided for the distributions parameter.
    """
    msg_1 = "distributions parameter must be a sequence of distributions or a sequence of tuples"
    msg_2 = "The 'distributions' parameter of GeneralNB must be an array-like or None"
    with pytest.raises(ValueError, match=rf"{msg_1}|{msg_2}"):
        clf.fit(X, y)


class InvalidDistA:
    def __call__(self, *args, **kwargs):
        return 0.0


class InvalidDistB(InvalidDistA):
    @classmethod
    def from_data(cls, *args, **kwargs):
        return cls()


@pytest.mark.parametrize(
    "clf",
    [
        GeneralNB(distributions=["Normal", "Borel"]),
        GeneralNB(distributions=[(D.NORMAL, 0), ("Weibull", [1])]),
        GeneralNB(distributions=[D.NORMAL, InvalidDistA]),
        GeneralNB(distributions=[(InvalidDistA, 0), (InvalidDistB, 1)]),
    ],
)
def test_gnb_unsupported_dist(clf: GeneralNB):
    """
    Test whether an error is raised if an unsupported distribution is provided.
    """

    msg = r"Distribution .* is not supported"
    with pytest.raises(ValueError, match=msg):
        clf.fit(X, y)


@pytest.mark.parametrize("dist", [D.BERNOULLI, D.POISSON, D.CATEGORICAL])
def test_gnb_dist_identical(dist: D):
"""
Test whether GeneralNB returns the same outputs when identical distributions are specified in different formats.
"""

rng = np.random.RandomState(10)
X_ = rng.randint(2, size=(150, 100))
y_ = rng.randint(1, 5, size=(150,))

clf1 = GeneralNB(distributions=[dist] * 100)
clf1.fit(X_, y_)

clf2 = GeneralNB(distributions=[(dist, [i for i in range(100)])])
clf2.fit(X_, y_)

df_X = pd.DataFrame(X_, columns=[f"x{i}" for i in range(100)])
clf3 = GeneralNB(distributions=[(dist, [f"x{i}" for i in range(100)])])
clf3.fit(df_X, y_)

y_pred1 = clf1.predict(X_[10:15])
y_pred2 = clf2.predict(X_[10:15])
y_pred3 = clf3.predict(df_X.iloc[10:15])
assert np.array_equal(y_pred1, y_pred2)
assert np.array_equal(y_pred2, y_pred3)

y_pred_proba1 = clf1.predict_proba(X_[10:15])
y_pred_proba2 = clf2.predict_proba(X_[10:15])
y_pred_proba3 = clf3.predict_proba(df_X.iloc[10:15])
assert np.array_equal(y_pred_proba1, y_pred_proba2)
assert np.array_equal(y_pred_proba2, y_pred_proba3)

y_pred_log_proba1 = clf1.predict_log_proba(X_[10:15])
y_pred_log_proba2 = clf2.predict_log_proba(X_[10:15])
y_pred_log_proba3 = clf3.predict_log_proba(df_X.iloc[10:15])
assert np.array_equal(y_pred_log_proba1, y_pred_log_proba2)
assert np.array_equal(y_pred_log_proba2, y_pred_log_proba3)


def test_gnb_prior(global_random_seed: int):
    """
    Test whether class priors are properly set.
    """
@@ -239,17 +356,6 @@ def test_gnb_wrong_nb_dist():
        clf.fit(X, y)


def test_gnb_invalid_dist():
    """
    Test whether an error is raised if an invalid distribution is provided.
    """
    clf = GeneralNB(distributions=["Normal", "Borel"])

    msg = r"Distribution .* is not supported"
    with pytest.raises(ValueError, match=msg):
        clf.fit(X, y)


def test_gnb_var_smoothing():
    """
    Test whether var_smoothing parameter properly affects the variances of normal distributions.