Skip to content

Commit

Permalink
fix array type
Browse files Browse the repository at this point in the history
  • Loading branch information
neka-nat committed May 11, 2024
1 parent 66a9b87 commit c8f6ad5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions probreg/cpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _maximization_step(
sigma2 = max(sigma2, np.finfo(np.float32).eps)
q = (tr_xp1x - 2 * tr_ab + tr_xpyb) / (2.0 * sigma2)
q += dim * n_p * 0.5 * np.log(sigma2)
return MstepResult(tf.AffineTransformation(b, t), sigma2, q)
return MstepResult(tf.AffineTransformation(b, t, xp=xp), sigma2, q)


class NonRigidCPD(CoherentPointDrift):
Expand Down Expand Up @@ -316,7 +316,7 @@ def __init__(

def set_source(self, source: np.ndarray) -> None:
self._source = source
self._tf_obj = self._tf_type(None, self._source, self._beta)
self._tf_obj = self._tf_type(None, self._source, self._beta, xp=self.xp)

def maximization_step(
self, target: np.ndarray, estep_res: EstepResult, sigma2_p: Optional[float] = None
Expand Down
8 changes: 7 additions & 1 deletion probreg/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ def __init__(self, xp=np):

def transform(self, points, array_type=o3.utility.Vector3dVector):
if isinstance(points, array_type):
return array_type(self._transform(self.xp.asarray(points)))
_array_after_transform = self._transform(self.xp.asarray(points))

if self.xp.__name__ == "cupy":
_array_after_transform = _array_after_transform.get()

return array_type(_array_after_transform)

return self.xp.c_[self._transform(points[:, :3]), points[:, 3:]]

@abc.abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ python = ">=3.9,<4.0"
pybind11 = "^2.6.2"
six = "^1.15.0"
scipy = "^1.6.0"
transforms3d = "^0.3.1"
transforms3d = "^0.4.0"
scikit-learn = "^1.0"
matplotlib = "^3.3.3"
open3d = "0.18.0"
dq3d = {version = "^0.3.6", optional = true}
cupy = {version = "^11.0.0", optional = true}
cupy = {version = "^12.0.0", optional = true}
pyyaml = "^6.0"
addict = "^2.4.0"
pandas = "^2.0.0"
Expand Down

0 comments on commit c8f6ad5

Please sign in to comment.