Skip to content

Commit

Permalink
fix using cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
neka-nat committed May 11, 2024
1 parent c8f6ad5 commit 9c414d9
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
2 changes: 0 additions & 2 deletions examples/cpd_affine3d_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
else:
cp = np
to_cpu = lambda x: x
import open3d as o3
from probreg import cpd
from probreg import callbacks
import utils
import time

Expand Down
1 change: 0 additions & 1 deletion examples/cpd_nonrigid3d_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
to_cpu = lambda x: x
import open3d as o3
from probreg import cpd
from probreg import callbacks
import utils
import time

Expand Down
4 changes: 1 addition & 3 deletions examples/cpd_rigid_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
else:
cp = np
to_cpu = lambda x: x
import open3d as o3
import transforms3d as trans
import transforms3d as t3d
from probreg import cpd
from probreg import callbacks
import utils
import time

Expand Down
11 changes: 9 additions & 2 deletions probreg/cpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
"""


class DistModule:
def __init__(self, xp):
self.xp = xp

def cdist(self, x1, x2, metric):
return self.xp.stack([self.xp.sum(self.xp.square(x2 - ts), axis=1) for ts in x1])


@six.add_metaclass(abc.ABCMeta)
class CoherentPointDrift:
"""Coherent Point Drift algorithm.
Expand All @@ -50,12 +58,11 @@ def __init__(self, source: Optional[np.ndarray] = None, use_color: bool = False,
self._use_color = use_color
if use_cuda:
import cupy as cp
from cupyx.scipy.spatial import distance as cupy_distance

from . import cupy_utils

self.xp = cp
self.distance_module = cupy_distance
self.distance_module = DistModule(cp)
self.cupy_utils = cupy_utils
self._squared_kernel_sum = cupy_utils.squared_kernel_sum
else:
Expand Down

0 comments on commit 9c414d9

Please sign in to comment.