add fp8 comm op
GuangyaoZhang committed Aug 9, 2024
1 parent e4aadee commit 80f3882
Showing 2 changed files with 337 additions and 0 deletions.
extensions/csrc/communication/include/dummy.hpp (122 additions, 0 deletions)
@@ -0,0 +1,122 @@
#pragma once
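// ProcessGroupNCCL.hpp's contents are compiled out unless this macro is
// defined, so define it before including the c10d headers below.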
#define USE_C10D_NCCL

#include <pybind11/chrono.h>
#include <torch/python.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d {

class BackendDummy : public Backend {
public:
BackendDummy(const c10::intrusive_ptr<::c10d::Store>&, int rank, int size);

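  // Underlying NCCL process group that performs the actual communication
  // for the collectives implemented below.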
::c10d::ProcessGroupNCCL pg_nccl;

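  // FP8 (e4m3) quantize/dequantize helpers used by the collectives:
  // cast_to_fp8 writes the quantized tensor and its inverse scale;
  // cast_from_fp8 multiplies by the inverse scale and casts to `dtype`.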
void cast_to_fp8(at::Tensor& input_tensor, at::Tensor& output_tensor,
at::Tensor& scale_inv);
at::Tensor cast_from_fp8(at::Tensor input_tensor, at::Tensor scale_inv,
caffe2::TypeMeta dtype);

c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;

c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;

c10::intrusive_ptr<Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;

c10::intrusive_ptr<Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;

c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;

c10::intrusive_ptr<Work> _allgather_base(
at::Tensor& outputBuffer, at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;

c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;

c10::intrusive_ptr<Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;

c10::intrusive_ptr<Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;

c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputTensor, at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;

c10::intrusive_ptr<Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;

c10::intrusive_ptr<Work> send(std::vector<at::Tensor>& tensors, int dstRank,
int tag) override;

c10::intrusive_ptr<Work> recv(std::vector<at::Tensor>& tensors, int srcRank,
int tag) override;

c10::intrusive_ptr<Work> recvAnysource(std::vector<at::Tensor>& tensors,
int tag) override;

static c10::intrusive_ptr<Backend> createBackendDummy(
const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size,
const std::chrono::duration<float>& timeout);

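  // __attribute__((constructor)) runs this at shared-library load time, so
  // merely importing the extension registers the "dummy" backend with
  // torch.distributed before init_process_group is called.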
static void BackendDummyConstructor() __attribute__((constructor)) {
py::object module = py::module::import("torch.distributed");
py::object register_backend =
module.attr("Backend").attr("register_backend");
register_backend("dummy", py::cpp_function(createBackendDummy));
}
};

class WorkDummy : public Work {
friend class BackendDummy;

public:
WorkDummy(
OpType opType,
c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
: Work(-1, // rank, only used by recvAnySource, irrelevant in this demo
opType),
future_(std::move(future)) {}
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

private:
c10::intrusive_ptr<c10::ivalue::Future> future_;
};

} // namespace c10d
extensions/csrc/communication/src/dummy.cpp (215 additions, 0 deletions)
@@ -0,0 +1,215 @@
#include "dummy.hpp"

#include <ATen/ATen.h>
#include <torch/torch.h>

#include <iostream>

namespace c10d {

bool WorkDummy::isCompleted() { return true; }

bool WorkDummy::isSuccess() const { return true; }

bool WorkDummy::wait(std::chrono::milliseconds /* unused */) { return true; }

c10::intrusive_ptr<c10::ivalue::Future> WorkDummy::getFuture() {
return future_;
}

// The store/rank/size are forwarded to an internal ProcessGroupNCCL, which
// performs the actual NCCL communication for the FP8 collectives below.
BackendDummy::BackendDummy(const c10::intrusive_ptr<::c10d::Store>& store,
                           int rank, int size)
    : Backend(rank, size), pg_nccl(store, rank, size) {}

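// Quantize `input_tensor` to float8_e4m3fn. The scale 448 / amax(|input|)
// maps the tensor into e4m3's representable range (448 is the format's
// largest finite value); the inverse scale needed for dequantization is
// written to `scale_inv`. Worked example: amax = 8960 gives scale = 0.05,
// so 8960 quantizes to 448 and dequantizes back as 448 * 20 = 8960.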
void BackendDummy::cast_to_fp8(at::Tensor& input_tensor,
at::Tensor& output_tensor,
at::Tensor& scale_inv) {
at::Tensor tensor_max = input_tensor.abs().max();
at::Tensor tensor_max_new =
torch::where(tensor_max > 0, tensor_max, at::Scalar(1));
at::Tensor fp8_max = torch::scalar_tensor(at::Scalar(448.0));
at::Tensor scale = fp8_max.div(tensor_max_new);
output_tensor =
scale.mul(input_tensor.to(torch::kFloat32)).to(at::kFloat8_e4m3fn);
scale_inv = 1.0 / scale;
}

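// Dequantize: multiply by the inverse scale in FP32, then cast to `dtype`.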
at::Tensor BackendDummy::cast_from_fp8(at::Tensor input_tensor,
at::Tensor scale_inv,
caffe2::TypeMeta dtype) {
return scale_inv.mul(input_tensor.to(torch::kFloat32)).to(dtype);
}
// allgather is delegated unchanged to the underlying NCCL process group.
c10::intrusive_ptr<Work> BackendDummy::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors, const AllgatherOptions& opts) {
  return pg_nccl.allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> BackendDummy::_allgather_base(
    at::Tensor& outputBuffer, at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  return pg_nccl._allgather_base(outputBuffer, inputBuffer, opts);
}

// FP8 all-reduce built from NCCL primitives. Outline:
//   1. quantize the input to float8_e4m3fn and exchange even chunks with an
//      all-to-all (the exchange half of a reduce-scatter),
//   2. all-gather the per-rank inverse scales and locally sum the
//      dequantized chunks,
//   3. re-quantize the partial sum and all-gather it along with its
//      inverse scale,
//   4. dequantize every gathered chunk and reassemble the full result.
// The payload crosses the wire as FP8 (one byte per element), roughly
// halving communication volume relative to an FP16 all-reduce at the cost
// of extra casts. In total: one all-to-all and three all-gathers per call,
// each awaited with a 10 s timeout.
c10::intrusive_ptr<Work> BackendDummy::allreduce(
    std::vector<at::Tensor>& tensors, const AllreduceOptions& opts) {
  // Empty split-size vectors make alltoall_base fall back to even chunks,
  // so tensor.numel() must be divisible by the world size.
  std::vector<int64_t> tmp_size;
  auto tensor = tensors[0];
  int world_size = this->getSize();
  auto input_type = tensor.dtype();

at::Tensor flatten_tensor = tensor.flatten();

  // Quantize to FP8, then reinterpret the bytes as int8: the NCCL process
  // group does not accept FP8 tensors directly.
  at::Tensor fp8_tensor;
  at::Tensor scale_inv;
  cast_to_fp8(flatten_tensor, fp8_tensor, scale_inv);
  fp8_tensor = fp8_tensor.view(torch::kInt8);
  auto output_tensor = torch::empty_like(fp8_tensor);

  // Exchange even chunks: afterwards, rank r holds chunk r of every rank.
  pg_nccl.alltoall_base(output_tensor, fp8_tensor, tmp_size, tmp_size)
      ->wait(std::chrono::milliseconds(10000));

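  // Share every rank's inverse scale so each received chunk can be decoded.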
  at::Tensor scale_inv_list = torch::zeros(
      {world_size},
      at::TensorOptions().dtype(scale_inv.dtype()).device(scale_inv.device()));
  pg_nccl._allgather_base(scale_inv_list, scale_inv)
      ->wait(std::chrono::milliseconds(10000));

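  // Reduce step: dequantize each rank's chunk and accumulate in the input
  // dtype. Rank r now holds the fully reduced chunk r.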
  auto output_tensor_chunk = at::chunk(output_tensor, world_size);

  auto summed_output =
      torch::zeros_like(output_tensor_chunk[0]).to(input_type);

  for (int rank = 0; rank < world_size; ++rank) {
    summed_output +=
        cast_from_fp8(output_tensor_chunk[rank].view(at::kFloat8_e4m3fn),
                      scale_inv_list[rank], input_type);
  }

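  // Re-quantize the local partial sum so the final all-gather also moves
  // FP8 bytes instead of the full-precision dtype.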
  at::Tensor summed_output_fp8;
  at::Tensor summed_output_scale_inv;
  cast_to_fp8(summed_output, summed_output_fp8, summed_output_scale_inv);
  summed_output_fp8 = summed_output_fp8.view(torch::kInt8);

  auto summed_output_scale_inv_list = torch::zeros(
      {world_size},
      at::TensorOptions().dtype(scale_inv.dtype()).device(scale_inv.device()));
  // Gather into a flat buffer; chunking the original shape along dim 0
  // would break for tensors whose first dimension is not divisible by
  // world_size.
  auto summed_output_fp8_list =
      torch::empty({tensor.numel()}, tensor.options().dtype(torch::kInt8));

  pg_nccl
      ._allgather_base(summed_output_scale_inv_list, summed_output_scale_inv)
      ->wait(std::chrono::milliseconds(10000));
  pg_nccl._allgather_base(summed_output_fp8_list, summed_output_fp8)
      ->wait(std::chrono::milliseconds(10000));

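  // Decode every gathered partial sum and reassemble the full tensor.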
  auto summed_output_fp8_chunk = at::chunk(summed_output_fp8_list, world_size);
  std::vector<at::Tensor> output;
  for (int rank = 0; rank < world_size; ++rank) {
    output.push_back(
        cast_from_fp8(summed_output_fp8_chunk[rank].view(at::kFloat8_e4m3fn),
                      summed_output_scale_inv_list[rank], input_type));
  }

  tensors[0].copy_(at::cat(output).reshape(tensor.sizes()));

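  // All collectives above were waited on synchronously, so hand back an
  // already-completed Work holding the result.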
  auto future = c10::make_intrusive<c10::ivalue::Future>(
      c10::ListType::create(c10::TensorType::get()));
  future->markCompleted(c10::IValue(tensors));
  return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}

c10::intrusive_ptr<Work> BackendDummy::allreduce_coalesced(
std::vector<at::Tensor>& /* unused */,
const AllreduceCoalescedOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::alltoall(
std::vector<at::Tensor>& /* unused */,
std::vector<at::Tensor>& /* unused */,
const AllToAllOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::alltoall_base(
at::Tensor& outputTensor, at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::barrier(
const BarrierOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::broadcast(
std::vector<at::Tensor>& tensors, const BroadcastOptions& opts) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::gather(
std::vector<std::vector<at::Tensor>>& /* unused */,
std::vector<at::Tensor>& /* unused */, const GatherOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::reduce(
std::vector<at::Tensor>& /* unused */, const ReduceOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::reduce_scatter(
std::vector<at::Tensor>& /* unused */,
std::vector<std::vector<at::Tensor>>& /* unused */,
const ReduceScatterOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::scatter(
std::vector<at::Tensor>& /* unused */,
std::vector<std::vector<at::Tensor>>& /* unused */,
const ScatterOptions& /* unused */) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::send(std::vector<at::Tensor>& tensors,
int dstRank, int tag) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::recv(std::vector<at::Tensor>& tensors,
int srcRank, int tag) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Work> BackendDummy::recvAnysource(
std::vector<at::Tensor>& tensors, int tag) {
throw std::runtime_error("not supported");
}

c10::intrusive_ptr<Backend> BackendDummy::createBackendDummy(
    const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size,
    const std::chrono::duration<float>& /* timeout, unused */) {
return c10::make_intrusive<BackendDummy>(store, rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("createBackendDummy", &BackendDummy::createBackendDummy);
}

} // namespace c10d
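
For reference, a minimal usage sketch from Python. This is a sketch under stated assumptions, not part of the commit: it assumes the extension above is compiled into a module named fp8_comm (hypothetical name), e.g. via torch.utils.cpp_extension. Importing the module runs BackendDummyConstructor, which registers the backend with torch.distributed under the name "dummy"; because the backend forwards to ProcessGroupNCCL, tensors must live on CUDA devices.

# Minimal usage sketch; `fp8_comm` is a hypothetical module name.
# Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set (e.g. by torchrun).
import os

import torch
import torch.distributed as dist

import fp8_comm  # noqa: F401  (import side effect registers "dummy")

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(rank % torch.cuda.device_count())

dist.init_process_group(backend="dummy", rank=rank, world_size=world_size)

t = torch.randn(1024, device="cuda")  # numel divisible by world_size
dist.all_reduce(t)  # dispatches to BackendDummy::allreduce (FP8 path)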
