Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add all-reduce for MPI, NCCL, and RCCL #8

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 66 additions & 34 deletions allgather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

#ifdef USE_CUDA
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#define bfloat16 nv_bfloat16
#elif USE_ROCM
#include <hip/hip_bfloat16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#define bfloat16 hip_bfloat16
#endif

#ifdef USE_NCCL
#include "nccl.h"
#elif defined(USE_RCCL)
#include "rccl.h"
#elif USE_RCCL
#include <rccl/rccl.h>
#endif

#define NUM_WARMUP_ITERATIONS 5
Expand All @@ -40,6 +44,16 @@
} \
} while(0)

#define HIP_CHECK(cmd) do { \
hipError_t e = cmd; \
if(e != hipSuccess) { \
printf("HIP error %s:%d: %s\n", \
__FILE__, __LINE__, hipGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while(0)

// NCCL_CHECK is used to validate RCCL functions as well
#define NCCL_CHECK(cmd) do { \
ncclResult_t e = cmd; \
if (e != ncclSuccess) { \
Expand All @@ -49,9 +63,14 @@
} \
} while(0)

void initializeData(nv_bfloat16 *data, int size) {
for (int i = 0; i < (size / sizeof(nv_bfloat16)); ++i) {
void initializeData(bfloat16 *data, int size) {
for (int i = 0; i < (size / sizeof(bfloat16)); ++i) {
#ifdef USE_CUDA
data[i] = __float2bfloat16((float)i);
#elif USE_ROCM
// ROCm doesn't have a float2bfloat16 method
data[i] = (bfloat16) ((float) i);
#endif
}
}

Expand Down Expand Up @@ -86,33 +105,44 @@ int main(int argc, char *argv[]) {
}

// Initialize GPU context
#if USE_CUDA
cudaGetDeviceCount(&num_gpus_per_node);
cudaSetDevice((my_rank % num_gpus_per_node));
#elif USE_ROCM
hipGetDeviceCount(&num_gpus_per_node);
hipSetDevice((my_rank % num_gpus_per_node));
#endif

int local_data_size = max_msg_size; // Size of local data
int global_data_size = local_data_size * num_gpus; // Size of global data

nv_bfloat16 *local_data = (nv_bfloat16*)malloc(local_data_size);
nv_bfloat16 *global_data = (nv_bfloat16*)malloc(global_data_size);
bfloat16 *local_data = (bfloat16*)malloc(local_data_size);
bfloat16 *global_data = (bfloat16*)malloc(global_data_size);

// Initialize local data
initializeData(local_data, local_data_size);

// Allocate memory on GPU
nv_bfloat16 *d_local_data, *d_global_data;
bfloat16 *d_local_data, *d_global_data;
#ifdef USE_CUDA
CUDA_CHECK(cudaMalloc(&d_local_data, local_data_size));
CUDA_CHECK(cudaMalloc(&d_global_data, global_data_size));

// Copy local data to GPU
CUDA_CHECK(cudaMemcpy(d_local_data, local_data, local_data_size, cudaMemcpyHostToDevice));

#elif USE_ROCM
HIP_CHECK(hipMalloc(&d_local_data, local_data_size));
HIP_CHECK(hipMalloc(&d_global_data, global_data_size));
HIP_CHECK(hipMemcpy(d_local_data, local_data, local_data_size, hipMemcpyHostToDevice));
#endif

#ifdef USE_MPI
// create 2-byte datatype (send raw, un-interpreted bytes)
MPI_Datatype mpi_type_bfloat16;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_bfloat16);
MPI_Type_commit(&mpi_type_bfloat16);

#elif USE_NCCL
#elif defined(USE_NCCL) || defined(USE_RCCL)
ncclUniqueId nccl_comm_id;
ncclComm_t nccl_comm;

Expand All @@ -125,13 +155,8 @@ int main(int argc, char *argv[]) {
MPI_CHECK(MPI_Bcast((void *)&nccl_comm_id, sizeof(nccl_comm_id), MPI_BYTE,
0, MPI_COMM_WORLD));

/* Create a new NCCL communicator */
/* Create a new NCCL/RCCL communicator */
NCCL_CHECK(ncclCommInitRank(&nccl_comm, num_pes, nccl_comm_id, my_rank));

#elif defined(USE_RCCL)
// TODO: fix later
rcclComm_t rccl_comm;
rcclCommInitRank(&comm, num_gpus, 0, rccl_root);
#endif

// Perform MPI_Iallgather, NCCL allgather, or RCCL allgather
Expand All @@ -148,20 +173,22 @@ int main(int argc, char *argv[]) {
fflush(NULL);

for (int msg_size = min_msg_size; msg_size <= max_msg_size; msg_size *= 2) {
msg_count = msg_size / sizeof(nv_bfloat16);
msg_count = msg_size / sizeof(bfloat16);
// warmup iterations
for (int i = 0; i < NUM_WARMUP_ITERATIONS; ++i) {
#ifdef USE_MPI
MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));

MPI_CHECK(MPI_Wait(&request, &status));
#elif defined(USE_NCCL)
#elif defined(USE_NCCL) || defined(USE_RCCL)
NCCL_CHECK(ncclAllGather((const void*)d_local_data, (void*)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
cudaDeviceSynchronize();
#elif defined(USE_RCCL)
// TODO: fix later
rcclAllReduce((const void*)d_local_data, (void*)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
#endif

#ifdef USE_CUDA
cudaDeviceSynchronize();
#elif USE_ROCM
hipDeviceSynchronize();
#endif
}

Expand All @@ -172,16 +199,18 @@ int main(int argc, char *argv[]) {
start_time = MPI_Wtime();
for (int i = 0; i < iterations; ++i) {
#ifdef USE_MPI
MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));

MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
MPI_CHECK(MPI_Wait(&request, &status));
#elif defined(USE_NCCL)
#elif defined(USE_NCCL) || defined(USE_RCCL)
NCCL_CHECK(ncclAllGather((const void*)d_local_data, (void*)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
cudaDeviceSynchronize();
#elif defined(USE_RCCL)
// TODO: fix later
rcclAllReduce((const void*)d_local_data, (void*)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
#endif

#ifdef USE_CUDA
cudaDeviceSynchronize();
#elif USE_ROCM
hipDeviceSynchronize();
#endif
}
MPI_Barrier(MPI_COMM_WORLD);
Expand All @@ -193,13 +222,16 @@ int main(int argc, char *argv[]) {
// Cleanup
free(local_data);
free(global_data);
#ifdef USE_CUDA
CUDA_CHECK(cudaFree(d_local_data));
CUDA_CHECK(cudaFree(d_global_data));
#elif USE_ROCM
HIP_CHECK(hipFree(d_local_data));
HIP_CHECK(hipFree(d_global_data));
#endif

#ifdef USE_NCCL
#ifdef defined(USE_NCCL) || defined(USE_RCCL)
ncclCommDestroy(nccl_comm);
#elif defined(USE_RCCL)
rcclCommDestroy(rccl_comm);
#endif

MPI_Finalize();
Expand Down
Loading