Skip to content

Commit

Permalink
ENH: repartitioning with zoltan2
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian-Diaz committed Oct 20, 2024
1 parent 3a0e43e commit 3acabc4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
34 changes: 29 additions & 5 deletions examples/ann_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,27 +373,51 @@ int main(int argc, char* argv[])
ANNLayers(num_layers-1).distributed_outputs.update_host();

if(process_rank==0)
std::cout << "output values: \n";
std::cout << "output values grid: \n";
std::flush(std::cout);
MPI_Barrier(MPI_COMM_WORLD);

std::stringstream output_stream;
size_t local_output_size = ANNLayers(num_layers-1).distributed_outputs.submap_size();
for (size_t val=0; val < local_output_size; val++){
int global_index = ANNLayers(num_layers-1).distributed_outputs.getSubMapGlobalIndex(val);
int local_index = ANNLayers(num_layers-1).distributed_outputs.getMapLocalIndex(global_index);
std::cout << " " << ANNLayers(num_layers-1).distributed_outputs.host(local_index) << std::endl;
output_stream << " " << ANNLayers(num_layers-1).distributed_outputs.host(local_index);
if(val%10==0) output_stream << std::endl;
} // end for

std::cout << output_stream.str();
std::flush(std::cout);

//test repartition; assume a 10 by 10 grid of outputs from ANN
//assign coords to each grid point, find a partition of the grid, then repartition output layer using new map
TpetraMVArray<real_t> output_grid(100, 2); //array of 2D coordinates for 10 by 10 grid of points

//populate coords
FOR_ALL(i,0,output_grid.dims(0), {
output_grid(i, 0) = i/10;
output_grid(i, 1) = i%10;
}); // end parallel for
output_grid.repartition_vector();

MPI_Barrier(MPI_COMM_WORLD);
if(process_rank==0){
std::cout << std::endl;
std::cout << " Map before repartitioning" << std::endl;
}
std::flush(std::cout);
output_grid.pmap.print();

MPI_Barrier(MPI_COMM_WORLD);
output_grid.repartition_vector();
if(process_rank==0){
std::cout << std::endl;
std::cout << " Map after repartitioning" << std::endl;
}
output_grid.pmap.print();

if(process_rank==0){
std::cout << std::endl;
std::cout << " Grid components per rank after repartitioning" << std::endl;
}
output_grid.print();
} // end of kokkos scope

Kokkos::finalize();
Expand Down
17 changes: 17 additions & 0 deletions src/include/tpetra_wrapper_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ class TpetraMVArray {
// Method that update device view
void update_device();

//print vector data
void print() const;

// Deconstructor
virtual KOKKOS_INLINE_FUNCTION
~TpetraMVArray ();
Expand Down Expand Up @@ -851,6 +854,14 @@ void TpetraMVArray<T,Layout,ExecSpace,MemoryTraits>::perform_comms() {

}

template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
void TpetraMVArray<T,Layout,ExecSpace,MemoryTraits>::print() const {
std::ostream &out = std::cout;
Teuchos::RCP<Teuchos::FancyOStream> fos;
fos = Teuchos::fancyOStream(Teuchos::rcpFromRef(out));
tpetra_vector->describe(*fos,Teuchos::VERB_EXTREME);
}

template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
void TpetraMVArray<T,Layout,ExecSpace,MemoryTraits>::repartition_vector() {

Expand Down Expand Up @@ -943,6 +954,12 @@ void TpetraMVArray<T,Layout,ExecSpace,MemoryTraits>::repartition_vector() {
own_comms = false; //reset submap setup now that full map is different
dims_[0] = tpetra_pmap->getLocalNumElements();
length_ = (dims_[0] * dims_[1]);

//copy new partitioned vector into another one constructed with our managed dual view
this_array_ = TArray1D(this_array_.d_view.label(), dims_[0], dims_[1]);
Teuchos::RCP<MV> managed_tpetra_vector = Teuchos::rcp(new MV(tpetra_pmap, this_array_));
managed_tpetra_vector->assign(*tpetra_vector);
tpetra_vector = managed_tpetra_vector;
// // migrate density vector if this is a restart file read
// if (simparam.restart_file&&repartition_node_densities)
// {
Expand Down

0 comments on commit 3acabc4

Please sign in to comment.