Skip to content

Commit

Permalink
Merge pull request #8 from msamsami/dist-support
Browse files Browse the repository at this point in the history
Add attributes to mixins for distribution support and type
  • Loading branch information
msamsami authored May 13, 2023
2 parents ae476fb + e793d39 commit bcdb827
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# WNB: General and weighted naive Bayes classifiers

![](https://img.shields.io/badge/version-v0.1.7-green)
![](https://img.shields.io/badge/version-v0.1.8-green)
![](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue)

<p>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='wnb',
version='0.1.7',
version='0.1.8',
description='Python library for the implementations of general and weighted naive Bayes (WNB) classifiers.',
keywords=['python', 'bayes', 'naivebayes', 'classifier', 'probabilistic'],
author='Mehdi Samsami',
Expand Down
2 changes: 1 addition & 1 deletion wnb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.7"
__version__ = "0.1.8"
__author__ = "Mehdi Samsami"

__all__ = [
Expand Down
30 changes: 29 additions & 1 deletion wnb/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import inspect
from functools import wraps
from numbers import Number
from typing import List, Tuple, Union
import warnings

import numpy as np

from ._enums import Distribution

__all__ = [
'ContinuousDistMixin',
'DiscreteDistMixin'
Expand Down Expand Up @@ -33,7 +37,8 @@ class DistMixin(metaclass=ABCMeta):
Mixin class for probability distributions in wnb.
"""

name = None
name: Union[str, Distribution] = None
_support: Union[List[float], Tuple[float, float]] = None

@classmethod
def from_data(cls, data):
Expand Down Expand Up @@ -84,6 +89,23 @@ def get_params(self) -> dict:
out[key] = value
return out

@property
def support(self) -> Union[List[float], Tuple[float, float]]:
    """Support of the probability distribution.

    A list denotes a limited number of discrete values; a tuple denotes a limited set/range of
    continuous values.
    """
    current_support = self._support
    return current_support

def _check_support(self, x):
if (isinstance(self.support, list) and x not in self.support) or \
(isinstance(self.support, tuple) and (x < self.support[0] or x > self.support[1])):
warnings.warn("Value doesn't lie within the support of the distribution", RuntimeWarning)
else:
pass

def __repr__(self) -> str:
return "".join([
"<",
Expand All @@ -101,6 +123,8 @@ class ContinuousDistMixin(DistMixin, metaclass=ABCMeta):
Mixin class for all continuous probability distributions in wnb.
"""

_type = "continuous"

def __init__(self, **kwargs):
"""Initializes an instance of the continuous probability distribution with given parameters.
Expand All @@ -120,6 +144,7 @@ def pdf(self, x: float) -> float:

@vectorize(signature="(),()->()")
def __call__(self, x: float) -> float:
    """Check `x` against the support (warning if it is outside), then return `pdf(x)`."""
    self._check_support(x)
    density = self.pdf(x)
    return density


Expand All @@ -128,6 +153,8 @@ class DiscreteDistMixin(DistMixin, metaclass=ABCMeta):
Mixin class for all discrete probability distributions in wnb.
"""

_type = "discrete"

def __init__(self, **kwargs):
"""Initializes an instance of the discrete probability distribution with given parameters.
Expand All @@ -147,4 +174,5 @@ def pmf(self, x: float) -> float:

@vectorize(signature="(),()->()")
def __call__(self, x: float) -> float:
    """Check `x` against the support (warning if it is outside), then return `pmf(x)`."""
    self._check_support(x)
    mass = self.pmf(x)
    return mass
21 changes: 11 additions & 10 deletions wnb/dist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any, Mapping, Sequence
import warnings

import numpy as np
from scipy.special import gamma
Expand All @@ -25,6 +24,7 @@

class NormalDist(ContinuousDistMixin):
name = D.NORMAL
_support = (-np.inf, np.inf)

def __init__(self, mu: float, sigma: float):
self.mu = mu
Expand All @@ -41,6 +41,7 @@ def pdf(self, x: float) -> float:

class LognormalDist(ContinuousDistMixin):
name = D.LOGNORMAL
_support = (0, np.inf)

def __init__(self, mu: float, sigma: float):
self.mu = mu
Expand All @@ -59,6 +60,7 @@ def pdf(self, x: float) -> float:

class ExponentialDist(ContinuousDistMixin):
name = D.EXPONENTIAL
_support = (0, np.inf)

def __init__(self, rate: float):
self.rate = rate
Expand All @@ -78,6 +80,7 @@ class UniformDist(ContinuousDistMixin):
def __init__(self, a: float, b: float):
self.a = a
self.b = b
self._support = (a, b)
super().__init__()

@classmethod
Expand All @@ -94,6 +97,7 @@ class ParetoDist(ContinuousDistMixin):
def __init__(self, x_m: float, alpha: float):
self.x_m = x_m
self.alpha = alpha
self._support = (self.x_m, np.inf)
super().__init__()

@classmethod
Expand All @@ -107,6 +111,7 @@ def pdf(self, x: float) -> float:

class GammaDist(ContinuousDistMixin):
name = D.GAMMA
_support = (0, np.inf)

def __init__(self, k: float, theta: float):
self.k = k
Expand All @@ -127,16 +132,14 @@ def pdf(self, x: float) -> float:

class BernoulliDist(DiscreteDistMixin):
name = D.BERNOULLI
_support = [0, 1]

def __init__(self, p: float):
self.p = p
super().__init__()

@classmethod
def from_data(cls, data):
if any(x not in [0, 1] for x in data):
warnings.warn("Bernoulli data points should be either 0 or 1")

return cls(p=(np.array(data) == 1).sum() / len(data))

def pmf(self, x: int) -> float:
Expand All @@ -148,6 +151,7 @@ class CategoricalDist(DiscreteDistMixin):

def __init__(self, prob: Mapping[Any, float]):
self.prob = prob
self._support = list(self.prob.keys())
super().__init__()

@classmethod
Expand All @@ -165,13 +169,11 @@ class MultinomialDist(DiscreteDistMixin):
def __init__(self, n: int, prob: Mapping[Any, float]):
self.n = n
self.prob = prob
self._support = [i for i in range(self.n+1)]
super().__init__()

@classmethod
def from_data(cls, data: Sequence[int]):
if any(not isinstance(x, int) or x < 0 for x in data):
warnings.warn("Multinomial data points should be integers greater than or equal to 0")

values, counts = np.unique(data, return_counts=True)
return cls(n=int(np.sum(values)), prob={v: c / len(data) for v, c in zip(values, counts)})

Expand All @@ -185,16 +187,14 @@ def pmf(self, x: Sequence[int]) -> float:

class GeometricDist(DiscreteDistMixin):
name = D.GEOMETRIC
_support = (1, np.inf)

def __init__(self, p: float):
self.p = p
super().__init__()

@classmethod
def from_data(cls, data):
if any(x < 1 for x in data):
warnings.warn("Geometric data points should be greater than or equal to 1")

return cls(p=len(data) / np.sum(data))

def pmf(self, x: int) -> float:
Expand All @@ -203,6 +203,7 @@ def pmf(self, x: int) -> float:

class PoissonDist(DiscreteDistMixin):
name = D.POISSON
_support = (0, np.inf)

def __init__(self, rate: float):
self.rate = rate
Expand Down

0 comments on commit bcdb827

Please sign in to comment.