Skip to content

Commit

Permalink
major core refactor in gaussian wnb classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
msamsami committed Jan 19, 2025
1 parent 1bf3803 commit 00b36a1
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 225 deletions.
23 changes: 23 additions & 0 deletions tests/test_gwnb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import numpy as np
import pandas as pd
import pytest
from sklearn.base import is_classifier
from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -194,3 +195,25 @@ def test_gwnb_no_cost_hist():
clf = GaussianWNB(max_iter=10)
clf.fit(X, y)
assert clf.cost_hist_ is None


def test_gwnb_attrs():
    """
    Verify that a fitted GaussianWNB exposes the expected learned attributes.
    """
    model = GaussianWNB().fit(X, y)

    # Attributes whose exact values are pinned for the module-level data.
    for attr_name, expected in (
        ("class_count_", np.array([3, 3])),
        ("class_prior_", np.array([0.5, 0.5])),
        ("classes_", np.array([1, 2])),
    ):
        assert np.array_equal(getattr(model, attr_name), expected)

    assert model.n_classes_ == 2
    assert model.n_features_in_ == 2
    # X is passed as-is here (not wrapped in a DataFrame), so no feature
    # names are expected to have been recorded.
    assert not hasattr(model, "feature_names_in_")
    assert np.array_equal(model.error_weights_, np.array([[0, 1], [-1, 0]]))

    # Only the shapes are checked for the remaining fitted arrays.
    for attr_name in ("theta_", "std_", "var_"):
        assert getattr(model, attr_name).shape == (2, 2)
    assert model.coef_.shape == (2,)

    # Refit on a DataFrame with explicit column names and check they are kept.
    columns = [f"x{idx}" for idx in range(X.shape[1])]
    frame = pd.DataFrame(X, columns=columns)
    model = GaussianWNB().fit(frame, y)
    assert np.array_equal(model.feature_names_in_, np.array(columns))
10 changes: 3 additions & 7 deletions wnb/gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,10 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
return X, y

def _init_parameters(self) -> None:
self.class_prior_: np.ndarray

# Set priors if not specified
if self.priors is None:
self.class_prior_ = (
self.class_count_ / self.class_count_.sum()
) # Calculate empirical prior probabilities

# Calculate empirical prior probabilities
self.class_prior_ = self.class_count_ / self.class_count_.sum()
else:
# Check that the provided priors match the number of classes
if len(self.priors) != self.n_classes_:
Expand Down Expand Up @@ -213,7 +209,7 @@ def fit(self, X: MatrixLike, y: ArrayLike) -> Self:
X, y = self._check_X_y(X, y)

self.classes_, y_, self.class_count_ = np.unique(y, return_counts=True, return_inverse=True)
self.n_classes_: int = len(self.classes_)
self.n_classes_ = len(self.classes_)

self._init_parameters()

Expand Down
Loading

0 comments on commit 00b36a1

Please sign in to comment.