diff --git a/probreg/cpd.py b/probreg/cpd.py index 60da613..306b779 100644 --- a/probreg/cpd.py +++ b/probreg/cpd.py @@ -284,7 +284,7 @@ def _maximization_step( sigma2 = max(sigma2, np.finfo(np.float32).eps) q = (tr_xp1x - 2 * tr_ab + tr_xpyb) / (2.0 * sigma2) q += dim * n_p * 0.5 * np.log(sigma2) - return MstepResult(tf.AffineTransformation(b, t), sigma2, q) + return MstepResult(tf.AffineTransformation(b, t, xp=xp), sigma2, q) class NonRigidCPD(CoherentPointDrift): @@ -316,7 +316,7 @@ def __init__( def set_source(self, source: np.ndarray) -> None: self._source = source - self._tf_obj = self._tf_type(None, self._source, self._beta) + self._tf_obj = self._tf_type(None, self._source, self._beta, xp=self.xp) def maximization_step( self, target: np.ndarray, estep_res: EstepResult, sigma2_p: Optional[float] = None diff --git a/probreg/transformation.py b/probreg/transformation.py index 088ec7b..1453777 100644 --- a/probreg/transformation.py +++ b/probreg/transformation.py @@ -22,7 +22,13 @@ def __init__(self, xp=np): def transform(self, points, array_type=o3.utility.Vector3dVector): if isinstance(points, array_type): - return array_type(self._transform(self.xp.asarray(points))) + _array_after_transform = self._transform(self.xp.asarray(points)) + + if self.xp.__name__ == "cupy": + _array_after_transform = _array_after_transform.get() + + return array_type(_array_after_transform) + return self.xp.c_[self._transform(points[:, :3]), points[:, 3:]] @abc.abstractmethod diff --git a/pyproject.toml b/pyproject.toml index 32ef351..efaf8db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,12 +10,12 @@ python = ">=3.9,<4.0" pybind11 = "^2.6.2" six = "^1.15.0" scipy = "^1.6.0" -transforms3d = "^0.3.1" +transforms3d = "^0.4.0" scikit-learn = "^1.0" matplotlib = "^3.3.3" open3d = "0.18.0" dq3d = {version = "^0.3.6", optional = true} -cupy = {version = "^11.0.0", optional = true} +cupy = {version = "^12.0.0", optional = true} pyyaml = "^6.0" addict = "^2.4.0" pandas = "^2.0.0"