Skip to content

Commit

Permalink
Fix model replication issue
Browse files Browse the repository at this point in the history
Signed-off-by: acivgin1 <[email protected]>
  • Loading branch information
acivgin1 committed Aug 24, 2020
1 parent ec193ff commit 4c4e009
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions kaolin/models/PointNet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,10 @@ def __init__(self,

self.num_points_out = num_points_out
self.pointnet_layer_dims_list = pointnet_layer_dims_list
self.sub_modules = nn.ModuleList()
self.layers = []
self.grouper_modules = nn.ModuleList()
self.pointnet_modules = nn.ModuleList()
self.num_samples_list = []

self.pointnet_in_channels = pointnet_in_features + \
(3 if use_xyz_feature else 0)

Expand Down Expand Up @@ -590,10 +592,9 @@ def __init__(self,
)

# Register sub-modules
self.sub_modules.append(grouper)
self.sub_modules.append(pointnet)

self.layers.append((grouper, pointnet, num_samples))
self.grouper_modules.append(grouper)
self.pointnet_modules.append(pointnet)
self.num_samples_list.append(num_samples)

def forward(self, xyz, features=None):
"""
Expand Down Expand Up @@ -625,7 +626,7 @@ def forward(self, xyz, features=None):
new_xyz = new_xyz.transpose(1, 2).contiguous()

new_features_list = []
for grouper, pointnet, num_samples in self.layers:
for grouper, pointnet, num_samples in zip(self.grouper_modules, self.pointnet_modules, self.num_samples_list):
new_features = grouper(xyz, new_xyz, features)
# shape = (batch_size, num_points_out, self.pointnet_in_channels, num_samples)
# if num_points_out is None:
Expand Down

0 comments on commit 4c4e009

Please sign in to comment.