Skip to content

Commit

Permalink
improve test of flexicube to take any triangulation combination (#776)
Browse files Browse the repository at this point in the history
Signed-off-by: Clement Fuji Tsang <[email protected]>
  • Loading branch information
Caenorst authored Dec 13, 2023
1 parent d283806 commit c67198d
Showing 1 changed file with 87 additions and 185 deletions.
272 changes: 87 additions & 185 deletions tests/python/kaolin/non_commercial/flexicubes/test_flexicubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,188 +177,69 @@ def expected_grid(self, device):
return expected_x_nx3, expected_cube_fx8

@pytest.fixture(autouse=True)
def expected_qef_output(self, device):
expected_vertices = torch.tensor([[-0.5, -0.5, -0.5],
[-0.5, -0.5, 0.0],
[-0.5, -0.5, 0.5],
[-0.5, 0.0, -0.5],
[-0.5, 0.0, 0.0],
[-0.5, 0.0, 0.5],
[-0.5, 0.5, -0.5],
[-0.5, 0.5, 0.0],
[-0.5, 0.5, 0.5],
[0.0, -0.5, -0.5],
[0.0, -0.5, 0.0],
[0.0, -0.5, 0.5],
[0.0, 0.0, -0.5],
[0.0, 0.0, 0.5],
[0.0, 0.5, -0.5],
[0.0, 0.5, 0.0],
[0.0, 0.5, 0.5],
[0.5, -0.5, -0.5],
[0.5, -0.5, 0.0],
[0.5, -0.5, 0.5],
[0.5, 0.0, -0.5],
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.5],
[0.5, 0.5, -0.5],
[0.5, 0.5, 0.0],
[0.5, 0.5, 0.5]],
dtype=torch.float,
device=device)
def expected_qef_vertices(self, device):
return torch.tensor([[-0.5, -0.5, -0.5],
[-0.5, -0.5, 0.0],
[-0.5, -0.5, 0.5],
[-0.5, 0.0, -0.5],
[-0.5, 0.0, 0.0],
[-0.5, 0.0, 0.5],
[-0.5, 0.5, -0.5],
[-0.5, 0.5, 0.0],
[-0.5, 0.5, 0.5],
[0.0, -0.5, -0.5],
[0.0, -0.5, 0.0],
[0.0, -0.5, 0.5],
[0.0, 0.0, -0.5],
[0.0, 0.0, 0.5],
[0.0, 0.5, -0.5],
[0.0, 0.5, 0.0],
[0.0, 0.5, 0.5],
[0.5, -0.5, -0.5],
[0.5, -0.5, 0.0],
[0.5, -0.5, 0.5],
[0.5, 0.0, -0.5],
[0.5, 0.0, 0.0],
[0.5, 0.0, 0.5],
[0.5, 0.5, -0.5],
[0.5, 0.5, 0.0],
[0.5, 0.5, 0.5]],
dtype=torch.float,
device=device)

expected_faces_1 = torch.tensor([[3, 4, 0],
[0, 4, 1],
[4, 5, 1],
[1, 5, 2],
[6, 7, 3],
[3, 7, 4],
[7, 8, 5],
[7, 5, 4],
[9, 12, 0],
[0, 12, 3],
[9, 10, 0],
[0, 10, 1],
[10, 11, 1],
[1, 11, 2],
[11, 13, 2],
[2, 13, 5],
[12, 14, 3],
[3, 14, 6],
[13, 16, 8],
[13, 8, 5],
[14, 15, 6],
[6, 15, 7],
[15, 16, 7],
[7, 16, 8],
[17, 20, 12],
[17, 12, 9],
[17, 18, 10],
[17, 10, 9],
[20, 21, 17],
[17, 21, 18],
[18, 19, 10],
[10, 19, 11],
[19, 22, 13],
[19, 13, 11],
[21, 22, 18],
[18, 22, 19],
[20, 23, 12],
[12, 23, 14],
[23, 24, 20],
[20, 24, 21],
[22, 25, 13],
[13, 25, 16],
[24, 25, 22],
[24, 22, 21],
[23, 24, 15],
[23, 15, 14],
[24, 25, 16],
[24, 16, 15]],
dtype=torch.long,
device=device)
expected_faces_2 = torch.tensor([[3, 4, 0],
[0, 4, 1],
[4, 5, 1],
[1, 5, 2],
[6, 7, 3],
[3, 7, 4],
[7, 8, 5],
[7, 5, 4],
[9, 12, 0],
[0, 12, 3],
[9, 10, 0],
[0, 10, 1],
[10, 11, 1],
[1, 11, 2],
[11, 13, 2],
[2, 13, 5],
[12, 14, 3],
[3, 14, 6],
[13, 16, 5],
[5, 16, 8],
[14, 15, 6],
[6, 15, 7],
[15, 16, 8],
[15, 8, 7],
[17, 20, 9],
[9, 20, 12],
[17, 18, 9],
[9, 18, 10],
[20, 21, 17],
[17, 21, 18],
[18, 19, 10],
[10, 19, 11],
[19, 22, 13],
[19, 13, 11],
[21, 22, 18],
[18, 22, 19],
[20, 23, 12],
[12, 23, 14],
[23, 24, 20],
[20, 24, 21],
[22, 25, 13],
[13, 25, 16],
[24, 25, 22],
[24, 22, 21],
[23, 24, 15],
[23, 15, 14],
[24, 25, 15],
[15, 25, 16]],
dtype=torch.long,
device=device)
expected_faces_3 = torch.tensor([[ 3, 4, 0],
[ 0, 4, 1],
[ 4, 5, 1],
[ 1, 5, 2],
[ 6, 7, 3],
[ 3, 7, 4],
[ 7, 8, 5],
[ 7, 5, 4],
[ 9, 12, 0],
[ 0, 12, 3],
[ 9, 10, 0],
[ 0, 10, 1],
[10, 11, 1],
[ 1, 11, 2],
[11, 13, 2],
[ 2, 13, 5],
[12, 14, 6],
[12, 6, 3],
[13, 16, 8],
[13, 8, 5],
[14, 15, 6],
[ 6, 15, 7],
[15, 16, 7],
[ 7, 16, 8],
[17, 20, 12],
[17, 12, 9],
[17, 18, 10],
[17, 10, 9],
[20, 21, 18],
[20, 18, 17],
[18, 19, 10],
[10, 19, 11],
[19, 22, 13],
[19, 13, 11],
[21, 22, 18],
[18, 22, 19],
[20, 23, 12],
[12, 23, 14],
[23, 24, 20],
[20, 24, 21],
[22, 25, 13],
[13, 25, 16],
[24, 25, 22],
[24, 22, 21],
[23, 24, 15],
[23, 15, 14],
[24, 25, 15],
[15, 25, 16]],
dtype=torch.long,
device=device)

return expected_vertices, expected_faces_1, expected_faces_2, expected_faces_3
@pytest.fixture(autouse=True)
def expected_qef_possible_tri(self, device):
quad = torch.tensor([
[3, 4, 1, 0],
[4, 5, 2, 1],
[6, 7, 4, 3],
[7, 8, 5, 4],
[9, 12, 3, 0],
[9, 10, 1, 0],
[10, 11, 2, 1],
[11, 13, 5, 2],
[12, 14, 6, 3],
[13, 16, 8, 5],
[14, 15, 7, 6],
[15, 16, 8, 7],
[17, 20, 12, 9],
[17, 18, 10, 9],
[20, 21, 18, 17],
[18, 19, 11, 10],
[19, 22, 13, 11],
[21, 22, 19, 18],
[20, 23, 14, 12],
[23, 24, 21, 20],
[22, 25, 16, 13],
[24, 25, 22, 21],
[23, 24, 15, 14],
[24, 25, 16, 15]
], dtype=torch.long, device=device)
tri_00 = torch.sort(quad[:, [0, 1, 2]], dim=1)[0]
tri_01 = torch.sort(quad[:, [0, 2, 3]], dim=1)[0]
tri_10 = torch.sort(quad[:, [0, 1, 3]], dim=1)[0]
tri_11 = torch.sort(quad[:, [1, 2, 3]], dim=1)[0]
return tri_00, tri_01, tri_10, tri_11

def test_grid_construction(self, expected_grid, device):
fc = FlexiCubes(device)
Expand All @@ -382,12 +263,33 @@ def test_tetmesh_extraction(self, input_data, expected_tetmesh_output, device):
assert torch.allclose(output[0], expected_tetmesh_output[0], atol=1e-4)
assert torch.equal(output[1], expected_tetmesh_output[1])

def test_qef_extraction_grad_func(self, expected_qef_output, device):
def test_qef_extraction_grad_func(self, expected_qef_vertices,
expected_qef_possible_tri, device):
fc = FlexiCubes(device)
x_nx3, cube_fx8 = fc.construct_voxel_grid(3)
sdf_n = cube_sdf(x_nx3)
output = fc(x_nx3, sdf_n, cube_fx8, 3, grad_func=cube_sdf_gradient)

assert torch.allclose(output[0], expected_qef_output[0], atol=1e-4)
# in this example, both triangulations are correct
assert torch.equal(output[1], expected_qef_output[1]) or torch.equal(output[1], expected_qef_output[2]) or torch.equal(output[1], expected_qef_output[3])
assert torch.allclose(output[0], expected_qef_vertices, atol=1e-4)
# There are many triangulation possible
tri_00, tri_01, tri_10, tri_11 = expected_qef_possible_tri
sorted_tri_mesh = torch.sort(output[1], dim=1)[0]
has_tri_00 = torch.any(torch.all(
tri_00.reshape(1, -1, 3) == sorted_tri_mesh.reshape(-1, 1, 3), dim=-1), dim=0)
has_tri_01 = torch.any(torch.all(
tri_01.reshape(1, -1, 3) == sorted_tri_mesh.reshape(-1, 1, 3), dim=-1), dim=0)
has_tri_10 = torch.any(torch.all(
tri_10.reshape(1, -1, 3) == sorted_tri_mesh.reshape(-1, 1, 3), dim=-1), dim=0)
has_tri_11 = torch.any(torch.all(
tri_11.reshape(1, -1, 3) == sorted_tri_mesh.reshape(-1, 1, 3), dim=-1), dim=0)
has_tri_0 = torch.logical_and(has_tri_00, has_tri_01)
has_tri_1 = torch.logical_and(has_tri_10, has_tri_11)
has_tri = torch.logical_or(has_tri_0, has_tri_1)
has_all_tri = torch.all(has_tri)
reconstructed_mesh = torch.cat([
tri_00[has_tri_00], tri_01[has_tri_01],
tri_10[has_tri_10], tri_11[has_tri_11]
], dim=0)
assert reconstructed_mesh.shape[0] == tri_00.shape[0] * 2
assert torch.unique(sorted_tri_mesh, dim=0).shape == sorted_tri_mesh.shape
assert torch.all(torch.unique(reconstructed_mesh, dim=0) == torch.unique(sorted_tri_mesh, dim=0))

0 comments on commit c67198d

Please sign in to comment.