Skip to content

Commit

Permalink
Merge pull request #48 from msamsami/add-examples-doc-improve-error-w…
Browse files Browse the repository at this point in the history
…eights-and-prior

feat: add docstring examples, improve error weights and class prior initialization
  • Loading branch information
msamsami authored Feb 12, 2025
2 parents 348d21d + ba54105 commit 47b5422
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ or
# Columns not explicitly specified will default to Gaussian (normal) distribution
gnb = GeneralNB(
distributions=[
(D.CATEGORICAL, "col2"),
(D.CATEGORICAL, [1]),
(D.EXPONENTIAL, ["col3", "col4"]),
],
)
Expand Down
2 changes: 1 addition & 1 deletion examples/gnb_wine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
print("sklearn | GaussianNB >> score >>", gnb.score(X_test, y_test))

# Train and score wnb GeneralNB classifier with Log-normal likelihoods
gnb = GeneralNB(distributions=[D.LOGNORMAL] * X.shape[1])
gnb = GeneralNB(distributions=[(D.LOGNORMAL, range(X.shape[1]))])
gnb.fit(X_train, y_train)
print("wnb | GeneralNB >> score >>", gnb.score(X_test, y_test))
36 changes: 27 additions & 9 deletions wnb/gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,27 @@ class GeneralNB(_BaseNB):
likelihood_params_ : dict
A mapping from class labels to their fitted likelihood distributions.
Examples
--------
>>> import numpy as np
>>> X = np.array([[-1, 1], [-2, 1], [-3, 2], [1, 1], [2, 1], [3, 2]])
>>> Y = np.array([1, 1, 1, 2, 2, 2])
>>> from wnb import GeneralNB, Distribution as D
>>> clf = GeneralNB(distributions=[D.NORMAL, D.POISSON])
>>> clf.fit(X, Y)
GeneralNB(distributions=[<Distribution.NORMAL: 'Normal'>,
<Distribution.POISSON: 'Poisson'>])
>>> print(clf.predict([[-0.8, 1]]))
[1]
>>> X = np.array([[-1, 1, 1], [-2, 1, 1], [-3, 2, 2], [1, 1, 1], [2, 1, 1], [3, 2, 2]])
>>> Y = np.array([-1, -1, -1, 1, 1, 1])
>>> clf_2 = GeneralNB(distributions=[(D.NORMAL, [0, 2]), (D.POISSON, [1])])
>>> clf_2.fit(X, Y)
GeneralNB(distributions=[(<Distribution.NORMAL: 'Normal'>, [0, 2]),
(<Distribution.POISSON: 'Poisson'>, [1])])
>>> print(clf_2.predict([[-0.8, 1, 1]]))
[-1]
"""

if parameter_constraints := _get_parameter_constraints():
Expand Down Expand Up @@ -194,26 +215,23 @@ def _find_dist(
return Distribution.NORMAL

def _init_parameters(self) -> None:
# Set priors if not specified
if self.priors is None:
# Calculate empirical prior probabilities
self.class_prior_ = self.class_count_ / self.class_count_.sum()
else:
priors = np.asarray(self.priors)

# Check that the provided priors match the number of classes
if len(self.priors) != self.n_classes_:
if len(priors) != self.n_classes_:
raise ValueError("Number of priors must match the number of classes.")
# Check that the sum of priors is 1
if not np.isclose(self.priors.sum(), 1.0):
if not np.isclose(priors.sum(), 1.0):
raise ValueError("The sum of the priors should be 1.")
# Check that the priors are non-negative
if (self.priors < 0).any():
if (priors < 0).any():
raise ValueError("Priors must be non-negative.")

self.class_prior_ = self.priors

# Convert to NumPy array if input priors is in a list/tuple/set
if isinstance(self.class_prior_, (list, tuple, set)):
self.class_prior_ = np.array(list(self.class_prior_))
self.class_prior_ = priors

distributions_error_msg = "distributions parameter must be a sequence of distributions or a sequence of tuples of (distribution, column_key)"
if self.distributions is None:
Expand Down
42 changes: 31 additions & 11 deletions wnb/gwnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ class GaussianWNB(_BaseNB):
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
Examples
--------
>>> import numpy as np
>>> X = np.array([[-1, 1], [-2, 1], [-3, 2], [1, 1], [2, 1], [3, 2]])
>>> Y = np.array([1, 1, 1, 2, 2, 2])
>>> from wnb import GaussianWNB
>>> clf = GaussianWNB()
>>> clf.fit(X, Y)
GaussianWNB()
>>> print(clf.predict([[-0.8, 1]]))
[1]
>>> X = np.array([[1, 3], [-1, 2], [2, 1], [3, 0], [1, 0.5], [-2, 1], [2, -1], [0, 0]])
>>> Y = np.array([-1, -1, 1, 1, 1, 1, 1, 1])
>>> clf_2 = GaussianWNB(error_weights=[[0, 3], [-1, 0]], max_iter=20, step_size=0.1)
>>> clf_2.fit(X, Y)
GaussianWNB(error_weights=[[0, 3], [-1, 0]], max_iter=20, step_size=0.1)
>>> print(clf_2.predict([[-1, 1.75]]))
[-1]
"""

if parameter_constraints := _get_parameter_constraints():
Expand All @@ -127,7 +146,7 @@ def __init__(
self,
*,
priors: Optional[ArrayLike] = None,
error_weights: Optional[np.ndarray] = None,
error_weights: Optional[ArrayLike] = None,
max_iter: Int = 25,
step_size: Float = 1e-4,
penalty: str = "l2",
Expand Down Expand Up @@ -198,33 +217,34 @@ def _init_parameters(self) -> None:
# Calculate empirical prior probabilities
self.class_prior_ = self.class_count_ / self.class_count_.sum()
else:
priors = np.asarray(self.priors)

# Check that the provided priors match the number of classes
if len(self.priors) != self.n_classes_:
if len(priors) != self.n_classes_:
raise ValueError("Number of priors must match the number of classes.")
# Check that the sum of priors is 1
if not np.isclose(self.priors.sum(), 1.0):
if not np.isclose(priors.sum(), 1.0):
raise ValueError("The sum of the priors should be 1.")
# Check that the priors are non-negative
if (self.priors < 0).any():
if (priors < 0).any():
raise ValueError("Priors must be non-negative.")

self.class_prior_ = self.priors

# Convert to NumPy array if input priors is in a list/tuple/set
if isinstance(self.class_prior_, (list, tuple, set)):
self.class_prior_ = np.array(list(self.class_prior_))
self.class_prior_ = priors

if self.error_weights is None:
# Assign equal weight to the errors of both classes
self.error_weights_ = np.array([[0, 1], [-1, 0]])
else:
error_weights = np.asarray(self.error_weights)

# Ensure the size of error weights matrix matches number of classes
if self.error_weights.shape != (self.n_classes_, self.n_classes_):
if error_weights.shape != (self.n_classes_, self.n_classes_):
raise ValueError(
"The shape of error weights matrix does not match the number of classes, "
"must be (n_classes, n_classes)."
)
self.error_weights_ = self.error_weights

self.error_weights_ = error_weights

# Ensure regularization type is either 'l1' or 'l2'
if self.penalty not in ("l1", "l2"):
Expand Down

0 comments on commit 47b5422

Please sign in to comment.