From 66a9b8712d7a0fb865397151b68595757db247d0 Mon Sep 17 00:00:00 2001 From: neka-nat Date: Sat, 11 May 2024 14:22:38 +0900 Subject: [PATCH] fix gmmtree --- probreg/cc/gmmtree.cc | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/probreg/cc/gmmtree.cc b/probreg/cc/gmmtree.cc index 077daaf..0cfd30b 100644 --- a/probreg/cc/gmmtree.cc +++ b/probreg/cc/gmmtree.cc @@ -44,7 +44,7 @@ Integer child(Integer j) { return (j + 1) * N_NODE; } Integer level(Integer l) { return N_NODE * (std::pow(N_NODE, l) - 1) / (N_NODE - 1); } void initializeNodes(NodeParamArray& nodes, const MatrixX3& points, Integer max_tree_level) { - const auto idxs = (points.rows() * Vector::Random(max_tree_level * N_NODE)).array().abs().cast(); + const auto idxs = (points.rows() * Vector::Random(std::pow(N_NODE, max_tree_level))).array().abs().cast(); const Integer lf_idx = level(max_tree_level - 1); for (Integer j = 0; j < std::pow(N_NODE, max_tree_level); ++j) { std::get<0>(nodes[lf_idx + j]) = 1.0 / N_NODE; diff --git a/pyproject.toml b/pyproject.toml index 1978a48..32ef351 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ scikit-learn = "^1.0" matplotlib = "^3.3.3" open3d = "0.18.0" dq3d = {version = "^0.3.6", optional = true} -cupy = {version = "^9.5.0", optional = true} +cupy = {version = "^11.0.0", optional = true} pyyaml = "^6.0" addict = "^2.4.0" pandas = "^2.0.0"