Skip to content

Commit

Permalink
add color cpd
Browse files Browse the repository at this point in the history
  • Loading branch information
neka-nat committed May 10, 2024
1 parent 9c838b7 commit 3f49685
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 34 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This package implements several algorithms using stochastic models and provides
* Maximum likelihood when the target or source point cloud is observation data
* [Coherent Point Drift (2010)](https://arxiv.org/pdf/0905.2635.pdf)
* [Extended Coherent Point Drift (2016)](https://ieeexplore.ieee.org/abstract/document/7477719) (add correspondence priors to CPD)
* [Color Coherent Point Drift (2018)](https://arxiv.org/pdf/1802.01516)
* [FilterReg (CVPR2019)](https://arxiv.org/pdf/1811.10136.pdf)
* Variational Bayesian inference
* [Bayesian Coherent Point Drift (2020)](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8985307)
Expand Down
3 changes: 1 addition & 2 deletions probreg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from . import (bcpd, callbacks, cpd, filterreg, gmmtree, l2dist_regs, log,
math_utils, transformation)
from . import bcpd, callbacks, cpd, filterreg, gmmtree, l2dist_regs, log, math_utils, transformation
from .version import __version__
6 changes: 3 additions & 3 deletions probreg/bcpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def expectation_step(self, t_source, target, scale, alpha, sigma_mat, sigma2, w=
pmat = np.exp(-pmat / (2.0 * sigma2))
pmat /= (2.0 * np.pi * sigma2) ** (dim * 0.5)
pmat = pmat.T
pmat *= np.exp(-(scale ** 2) / (2 * sigma2) * np.diag(sigma_mat) * dim)
pmat *= np.exp(-(scale**2) / (2 * sigma2) * np.diag(sigma_mat) * dim)
pmat *= (1.0 - w) * alpha
den = w / target.shape[0] + np.sum(pmat, axis=1)
den[den == 0] = np.finfo(np.float32).eps
Expand Down Expand Up @@ -126,7 +126,7 @@ def _maximization_step(source, target, rigid_trans, estep_res, gmat_inv, lmd, k,
nu_d, nu, n_p, px, x_hat = estep_res
dim = source.shape[1]
m = source.shape[0]
s2s2 = rigid_trans.scale ** 2 / (sigma2_p ** 2)
s2s2 = rigid_trans.scale**2 / (sigma2_p**2)
sigma_mat_inv = lmd * gmat_inv + s2s2 * np.diag(nu)
sigma_mat = np.linalg.inv(sigma_mat_inv)
residual = rigid_trans.inverse().transform(x_hat) - source
Expand All @@ -152,7 +152,7 @@ def _maximization_step(source, target, rigid_trans, estep_res, gmat_inv, lmd, k,
s1 = np.dot(target.ravel(), np.kron(nu_d, np.ones(dim)) * target.ravel())
s2 = np.dot(px.ravel(), y_hat.ravel())
s3 = np.dot(y_hat.ravel(), np.kron(nu, np.ones(dim)) * y_hat.ravel())
sigma2 = (s1 - 2.0 * s2 + s3) / (n_p * dim) + scale ** 2 * sigma2_m
sigma2 = (s1 - 2.0 * s2 + s3) / (n_p * dim) + scale**2 * sigma2_m
return MstepResult(tf.CombinedTransformation(rot, t, scale, v_hat), u_hat, sigma_mat, alpha, sigma2)


Expand Down
4 changes: 2 additions & 2 deletions probreg/cost_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __call__(self, theta: np.ndarray, *args):
def compute_l2_dist(
mu_source: np.ndarray, phi_source: np.ndarray, mu_target: np.ndarray, phi_target: np.ndarray, sigma: float
):
z = np.power(2.0 * np.pi * sigma ** 2, mu_source.shape[1] * 0.5)
z = np.power(2.0 * np.pi * sigma**2, mu_source.shape[1] * 0.5)
gtrans = gt.GaussTransform(mu_target, np.sqrt(2.0) * sigma)
phi_j_e = gtrans.compute(mu_source, phi_target / z)
phi_mu_j_e = gtrans.compute(mu_source, phi_target * mu_target.T / z).T
g = (phi_source * phi_j_e * mu_source.T - phi_source * phi_mu_j_e.T).T / (2.0 * sigma ** 2)
g = (phi_source * phi_j_e * mu_source.T - phi_source * phi_mu_j_e.T).T / (2.0 * sigma**2)
return -np.dot(phi_source, phi_j_e), g


Expand Down
94 changes: 71 additions & 23 deletions probreg/cpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ class CoherentPointDrift:
Args:
source (numpy.ndarray, optional): Source point cloud data.
use_color (bool, optional): Use color information (if available).
use_cuda (bool, optional): Use CUDA.
"""

def __init__(self, source: Optional[np.ndarray] = None, use_cuda: bool = False) -> None:
def __init__(self, source: Optional[np.ndarray] = None, use_color: bool = False, use_cuda: bool = False) -> None:
self._source = source
self._tf_type = None
self._callbacks = []
self._use_color = use_color
if use_cuda:
import cupy as cp
from cupyx.scipy.spatial import distance as cupy_distance
Expand All @@ -68,11 +70,8 @@ def set_callbacks(self, callbacks: List[Callable]) -> None:
def _initialize(self, target: np.ndarray) -> MstepResult:
return MstepResult(None, None, None)

def expectation_step(self, t_source: np.ndarray, target: np.ndarray, sigma2: float, w: float = 0.0) -> EstepResult:
"""Expectation step for CPD"""
assert t_source.ndim == 2 and target.ndim == 2, "source and target must have 2 dimensions."
def _compute_pmat(self, t_source: np.ndarray, target: np.ndarray, sigma2: float, w: float) -> np.ndarray:
pmat = self.distance_module.cdist(t_source, target, "sqeuclidean")
# pmat = self.xp.stack([self.xp.sum(self.xp.square(target - ts), axis=1) for ts in t_source])
pmat = self.xp.exp(-pmat / (2.0 * sigma2))

c = (2.0 * np.pi * sigma2) ** (t_source.shape[1] * 0.5)
Expand All @@ -81,7 +80,23 @@ def expectation_step(self, t_source: np.ndarray, target: np.ndarray, sigma2: flo
den[den == 0] = self.xp.finfo(np.float32).eps
den += c

pmat = self.xp.divide(pmat, den)
return self.xp.divide(pmat, den)

def expectation_step(
self,
t_source: np.ndarray,
target: np.ndarray,
sigma2: float,
sigma2_c: float,
w: float = 0.0,
w_c: float = 0.0,
) -> EstepResult:
"""Expectation step for CPD"""
assert t_source.ndim == 2 and target.ndim == 2, "source and target must have 2 dimensions."
pmat = self._compute_pmat(t_source[:, :3], target[:, :3], sigma2, w)
if self._use_color:
pmat_c = self._compute_pmat(t_source[:, 3:], target[:, 3:], sigma2_c, w_c)
pmat = self.xp.multiply(pmat, pmat_c)
pt1 = self.xp.sum(pmat, axis=0)
p1 = self.xp.sum(pmat, axis=1)
px = self.xp.dot(pmat, target)
Expand All @@ -90,7 +105,7 @@ def expectation_step(self, t_source: np.ndarray, target: np.ndarray, sigma2: flo
def maximization_step(
self, target: np.ndarray, estep_res: EstepResult, sigma2_p: Optional[float] = None
) -> Optional[MstepResult]:
return self._maximization_step(self._source, target, estep_res, sigma2_p, xp=self.xp)
return self._maximization_step(self._source[:, :3], target[:, :3], estep_res, sigma2_p, xp=self.xp)

@staticmethod
@abc.abstractmethod
Expand All @@ -103,13 +118,18 @@ def _maximization_step(
) -> Optional[MstepResult]:
return None

def registration(self, target: np.ndarray, w: float = 0.0, maxiter: int = 50, tol: float = 0.001) -> MstepResult:
def registration(
self, target: np.ndarray, w: float = 0.0, w_c: float = 0.0, maxiter: int = 50, tol: float = 0.001
) -> MstepResult:
assert not self._tf_type is None, "transformation type is None."
res = self._initialize(target)
res = self._initialize(target[:, :3])
sigma2_c = 0.0
if self._use_color:
sigma2_c = self._squared_kernel_sum(self._source[:, 3:], target[:, 3:])
q = res.q
for i in range(maxiter):
t_source = res.transformation.transform(self._source)
estep_res = self.expectation_step(t_source, target, res.sigma2, w)
estep_res = self.expectation_step(t_source, target, res.sigma2, sigma2_c, w, w_c)
res = self.maximization_step(target, estep_res, res.sigma2)
for c in self._callbacks:
c(res.transformation)
Expand All @@ -127,6 +147,7 @@ class RigidCPD(CoherentPointDrift):
source (numpy.ndarray, optional): Source point cloud data.
update_scale (bool, optional): If this flag is True, compute the scale parameter.
tf_init_params (dict, optional): Parameters to initialize transformation.
use_color (bool, optional): Use color information (if available).
use_cuda (bool, optional): Use CUDA.
"""

Expand All @@ -135,9 +156,10 @@ def __init__(
source: Optional[np.ndarray] = None,
update_scale: bool = True,
tf_init_params: Dict = {},
use_color: bool = False,
use_cuda: bool = False,
) -> None:
super(RigidCPD, self).__init__(source, use_cuda)
super(RigidCPD, self).__init__(source, use_color, use_cuda)
self._tf_type = tf.RigidTransformation
self._update_scale = update_scale
self._tf_init_params = tf_init_params
Expand Down Expand Up @@ -187,7 +209,7 @@ def _maximization_step(
else:
sigma2 = (tr_xp1x + tr_yp1y - scale * tr_atr) / (n_p * dim)
sigma2 = max(sigma2, np.finfo(np.float32).eps)
q = (tr_xp1x - 2.0 * scale * tr_atr + (scale ** 2) * tr_yp1y) / (2.0 * sigma2)
q = (tr_xp1x - 2.0 * scale * tr_atr + (scale**2) * tr_yp1y) / (2.0 * sigma2)
q += dim * n_p * 0.5 * np.log(sigma2)
return MstepResult(tf.RigidTransformation(rot, t, scale, xp=xp), sigma2, q)

Expand All @@ -198,11 +220,18 @@ class AffineCPD(CoherentPointDrift):
Args:
source (numpy.ndarray, optional): Source point cloud data.
tf_init_params (dict, optional): Parameters to initialize transformation.
use_color (bool, optional): Use color information (if available).
use_cuda (bool, optional): Use CUDA.
"""

def __init__(self, source: Optional[np.ndarray] = None, tf_init_params: Dict = {}, use_cuda: bool = False) -> None:
super(AffineCPD, self).__init__(source, use_cuda)
def __init__(
self,
source: Optional[np.ndarray] = None,
tf_init_params: Dict = {},
use_color: bool = False,
use_cuda: bool = False,
) -> None:
super(AffineCPD, self).__init__(source, use_color, use_cuda)
self._tf_type = tf.AffineTransformation
self._tf_init_params = tf_init_params

Expand Down Expand Up @@ -251,13 +280,19 @@ class NonRigidCPD(CoherentPointDrift):
source (numpy.ndarray, optional): Source point cloud data.
beta (float, optional): Parameter of RBF kernel.
lmd (float, optional): Parameter for regularization term.
use_color (bool, optional): Use color information (if available).
use_cuda (bool, optional): Use CUDA.
"""

def __init__(
self, source: Optional[np.ndarray] = None, beta: float = 2.0, lmd: float = 2.0, use_cuda: bool = False
self,
source: Optional[np.ndarray] = None,
beta: float = 2.0,
lmd: float = 2.0,
use_color: bool = False,
use_cuda: bool = False,
) -> None:
super(NonRigidCPD, self).__init__(source, use_cuda)
super(NonRigidCPD, self).__init__(source, use_color, use_cuda)
self._tf_type = tf.NonRigidTransformation
self._beta = beta
self._lmd = lmd
Expand Down Expand Up @@ -316,6 +351,7 @@ class ConstrainedNonRigidCPD(CoherentPointDrift):
alpha (float): Degree of reliability of priors.
Approximately between 1e-8 (highly reliable) and 1 (highly unreliable)
use_cuda (bool, optional): Use CUDA.
use_color (bool, optional): Use color information (if available).
idx_source (numpy.ndarray of ints, optional): Indices in source matrix
for which a correspondence is known
idx_target (numpy.ndarray of ints, optional): Indices in target matrix
Expand All @@ -328,11 +364,12 @@ def __init__(
beta: float = 2.0,
lmd: float = 2.0,
alpha: float = 1e-8,
use_color: bool = False,
use_cuda: bool = False,
idx_source: Optional[np.ndarray] = None,
idx_target: Optional[np.ndarray] = None,
):
super(ConstrainedNonRigidCPD, self).__init__(source, use_cuda)
super(ConstrainedNonRigidCPD, self).__init__(source, use_color, use_cuda)
self._tf_type = tf.NonRigidTransformation
self._beta = beta
self._lmd = lmd
Expand Down Expand Up @@ -409,9 +446,11 @@ def registration_cpd(
target: Union[np.ndarray, o3.geometry.PointCloud],
tf_type_name: str = "rigid",
w: float = 0.0,
w_c: float = 0.0,
maxiter: int = 50,
tol: float = 0.001,
callbacks: List[Callable] = [],
use_color: bool = False,
use_cuda: bool = False,
**kwargs: Any,
) -> MstepResult:
Expand All @@ -422,10 +461,12 @@ def registration_cpd(
target (numpy.ndarray): Target point cloud data.
tf_type_name (str, optional): Transformation type('rigid', 'affine', 'nonrigid', 'nonrigid_constrained')
w (float, optional): Weight of the uniform distribution, 0 <= `w` < 1.
w_c (float, optional): Weight of the color uniform distribution, 0 <= `w_c` < 1.
maxiter (int, optional): Maximum number of iterations of the EM algorithm.
tol (float, optional): Tolerance for termination.
callbacks (:obj:`list` of :obj:`function`, optional): Called after each iteration.
`callback(probreg.Transformation)`
use_color (bool, optional): Use color information (if available).
use_cuda (bool, optional): Use CUDA.
Keyword Args:
Expand All @@ -441,16 +482,23 @@ def registration_cpd(
import cupy as cp

xp = cp
cv = lambda x: xp.asarray(x.points if isinstance(x, o3.geometry.PointCloud) else x)
if use_color:
cv = (
lambda x: xp.c_[xp.asarray(x.points), xp.asarray(x.colors)]
if isinstance(x, o3.geometry.PointCloud)
else xp.asanyarray(x)[:, :6]
)
else:
cv = lambda x: xp.asarray(x.points if isinstance(x, o3.geometry.PointCloud) else x)[:, :3]
if tf_type_name == "rigid":
cpd = RigidCPD(cv(source), use_cuda=use_cuda, **kwargs)
cpd = RigidCPD(cv(source), use_color=use_color, use_cuda=use_cuda, **kwargs)
elif tf_type_name == "affine":
cpd = AffineCPD(cv(source), use_cuda=use_cuda, **kwargs)
cpd = AffineCPD(cv(source), use_color=use_color, use_cuda=use_cuda, **kwargs)
elif tf_type_name == "nonrigid":
cpd = NonRigidCPD(cv(source), use_cuda=use_cuda, **kwargs)
cpd = NonRigidCPD(cv(source), use_color=use_color, use_cuda=use_cuda, **kwargs)
elif tf_type_name == "nonrigid_constrained":
cpd = ConstrainedNonRigidCPD(cv(source), use_cuda=use_cuda, **kwargs)
cpd = ConstrainedNonRigidCPD(cv(source), use_color=use_color, use_cuda=use_cuda, **kwargs)
else:
raise ValueError("Unknown transformation type %s" % tf_type_name)
cpd.set_callbacks(callbacks)
return cpd.registration(cv(target), w, maxiter, tol)
return cpd.registration(cv(target), w, w_c, maxiter, tol)
2 changes: 1 addition & 1 deletion probreg/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def init(self):

def compute(self, data: np.ndarray):
self._clf.fit(data)
z = np.power(2.0 * np.pi * self._sigma ** 2, self._dim * 0.5)
z = np.power(2.0 * np.pi * self._sigma**2, self._dim * 0.5)
return self._clf.support_vectors_, self._clf.dual_coef_[0] * z

def annealing(self):
Expand Down
2 changes: 1 addition & 1 deletion probreg/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, xp=np):
def transform(self, points, array_type=o3.utility.Vector3dVector):
if isinstance(points, array_type):
return array_type(self._transform(self.xp.asarray(points)))
return self._transform(points)
return self.xp.c_[self._transform(points[:, :3]), points[:, 3:]]

@abc.abstractmethod
def _transform(self, points):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ pyyaml = "^6.0"
addict = "^2.4.0"
pandas = "^2.0.0"

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
Sphinx = "^3.4.3"
flake8 = "^3.8.4"
sphinx-rtd-theme = "^0.5.1"
twine = "^3.3.0"
setuptools = "^52.0.0"
isort = "^5.9.3"
black = "^21.9b0"
black = "22.3.0"

[build-system]
requires = ["poetry-core>=1.0.0", "setuptools", "pybind11"]
Expand Down

0 comments on commit 3f49685

Please sign in to comment.