diff --git a/extensions/csrc/communication/include/dummy.hpp b/extensions/csrc/communication/include/dummy.hpp
new file mode 100644
index 000000000000..2cb994d556b8
--- /dev/null
+++ b/extensions/csrc/communication/include/dummy.hpp
@@ -0,0 +1,122 @@
+#pragma once
+#define USE_C10D_NCCL
+
+#include <pybind11/chrono.h>
+#include <torch/python.h>
+
+#include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/Types.hpp>
+#include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
+
+namespace c10d {
+
+class BackendDummy : public Backend {
+ public:
+  BackendDummy(const c10::intrusive_ptr<::c10d::Store>&, int rank, int size);
+
+  ::c10d::ProcessGroupNCCL pg_nccl;
+
+  void cast_to_fp8(at::Tensor& input_tensor, at::Tensor& output_tensor,
+                   at::Tensor& scale_inv);
+  at::Tensor cast_from_fp8(at::Tensor input_tensor, at::Tensor scale_inv,
+                           caffe2::TypeMeta dtype);
+
+  c10::intrusive_ptr<Work> broadcast(
+      std::vector<at::Tensor>& data,
+      const BroadcastOptions& opts = BroadcastOptions()) override;
+
+  c10::intrusive_ptr<Work> allreduce(
+      std::vector<at::Tensor>& tensors,
+      const AllreduceOptions& opts = AllreduceOptions()) override;
+
+  c10::intrusive_ptr<Work> allreduce_coalesced(
+      std::vector<at::Tensor>& tensors,
+      const AllreduceCoalescedOptions& opts =
+          AllreduceCoalescedOptions()) override;
+
+  c10::intrusive_ptr<Work> reduce(
+      std::vector<at::Tensor>& tensors,
+      const ReduceOptions& opts = ReduceOptions()) override;
+
+  c10::intrusive_ptr<Work> allgather(
+      std::vector<std::vector<at::Tensor>>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
+      const AllgatherOptions& opts = AllgatherOptions()) override;
+
+  c10::intrusive_ptr<Work> _allgather_base(
+      at::Tensor& outputBuffer, at::Tensor& inputBuffer,
+      const AllgatherOptions& opts = AllgatherOptions()) override;
+
+  c10::intrusive_ptr<Work> barrier(
+      const BarrierOptions& opts = BarrierOptions()) override;
+
+  c10::intrusive_ptr<Work> gather(
+      std::vector<std::vector<at::Tensor>>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
+      const GatherOptions& opts = GatherOptions()) override;
+
+  c10::intrusive_ptr<Work> scatter(
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<std::vector<at::Tensor>>& inputTensors,
+      const ScatterOptions& opts = ScatterOptions()) override;
+
+  c10::intrusive_ptr<Work> reduce_scatter(
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<std::vector<at::Tensor>>& inputTensors,
+      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
+
+  c10::intrusive_ptr<Work> alltoall_base(
+      at::Tensor& outputTensor, at::Tensor& inputTensor,
+      std::vector<int64_t>& outputSplitSizes,
+      std::vector<int64_t>& inputSplitSizes,
+      const AllToAllOptions& opts = AllToAllOptions()) override;
+
+  c10::intrusive_ptr<Work> alltoall(
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
+      const AllToAllOptions& opts = AllToAllOptions()) override;
+
+  c10::intrusive_ptr<Work> send(std::vector<at::Tensor>& tensors, int dstRank,
+                                int tag) override;
+
+  c10::intrusive_ptr<Work> recv(std::vector<at::Tensor>& tensors, int srcRank,
+                                int tag) override;
+
+  c10::intrusive_ptr<Work> recvAnysource(std::vector<at::Tensor>& tensors,
+                                         int tag) override;
+
+  static c10::intrusive_ptr<Backend> createBackendDummy(
+      const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size,
+      const std::chrono::duration<float>& timeout);
+
+  // Runs when the shared library is loaded: registers "dummy" with
+  // torch.distributed so init_process_group(backend="dummy") can find it.
+  static void BackendDummyConstructor() __attribute__((constructor)) {
+    py::object module = py::module::import("torch.distributed");
+    py::object register_backend =
+        module.attr("Backend").attr("register_backend");
+    register_backend("dummy", py::cpp_function(createBackendDummy));
+  }
+};
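+
+// WorkDummy wraps a future that is already marked completed: isCompleted()
+// and wait() return immediately, so collectives issued through this backend
+// appear synchronous to the caller.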
+class WorkDummy : public Work {
+  friend class BackendDummy;
+
+ public:
+  WorkDummy(
+      OpType opType,
+      c10::intrusive_ptr<c10::ivalue::Future> future)  // future of the output
+      : Work(-1,  // rank, only used by recvAnySource, irrelevant in this demo
+             opType),
+        future_(std::move(future)) {}
+  bool isCompleted() override;
+  bool isSuccess() const override;
+  bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
+  virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
+
+ private:
+  c10::intrusive_ptr<c10::ivalue::Future> future_;
+};
+
+}  // namespace c10d
diff --git a/extensions/csrc/communication/src/dummy.cpp b/extensions/csrc/communication/src/dummy.cpp
new file mode 100644
index 000000000000..6c9fc2d4419a
--- /dev/null
+++ b/extensions/csrc/communication/src/dummy.cpp
@@ -0,0 +1,215 @@
+#include "dummy.hpp"
+
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include <chrono>
+
+namespace c10d {
+
+bool WorkDummy::isCompleted() { return true; }
+
+bool WorkDummy::isSuccess() const { return true; }
+
+bool WorkDummy::wait(std::chrono::milliseconds /* unused */) { return true; }
+
+c10::intrusive_ptr<c10::ivalue::Future> WorkDummy::getFuture() {
+  return future_;
+}
+
+// If necessary, pass store/rank/size to the ctor and exchange connection
+// information here. The backend owns a real NCCL process group and delegates
+// the actual transport to it.
+BackendDummy::BackendDummy(const c10::intrusive_ptr<::c10d::Store>& store,
+                           int rank, int size)
+    : Backend(rank, size), pg_nccl(store, rank, size) {}
+
+// Per-tensor quantization to float8_e4m3fn: scale the input so that its
+// absolute maximum maps to 448 (the largest finite e4m3 value) and return
+// the inverse scale needed to dequantize.
+void BackendDummy::cast_to_fp8(at::Tensor& input_tensor,
+                               at::Tensor& output_tensor,
+                               at::Tensor& scale_inv) {
+  at::Tensor tensor_max = input_tensor.abs().max();
+  // Guard against all-zero tensors, which would otherwise yield an
+  // infinite scale.
+  at::Tensor tensor_max_new =
+      torch::where(tensor_max > 0, tensor_max, at::Scalar(1));
+  at::Tensor fp8_max = torch::scalar_tensor(at::Scalar(448.0));
+  at::Tensor scale = fp8_max.div(tensor_max_new);
+  output_tensor =
+      scale.mul(input_tensor.to(torch::kFloat32)).to(at::kFloat8_e4m3fn);
+  scale_inv = 1.0 / scale;
+}
+
+at::Tensor BackendDummy::cast_from_fp8(at::Tensor input_tensor,
+                                       at::Tensor scale_inv,
+                                       caffe2::TypeMeta dtype) {
+  return scale_inv.mul(input_tensor.to(torch::kFloat32)).to(dtype);
+}
+
+// allgather and _allgather_base are delegated unchanged to the wrapped NCCL
+// process group.
+c10::intrusive_ptr<Work> BackendDummy::allgather(
+    std::vector<std::vector<at::Tensor>>& outputTensors,
+    std::vector<at::Tensor>& inputTensors, const AllgatherOptions& opts) {
+  return pg_nccl.allgather(outputTensors, inputTensors, opts);
+}
+
+c10::intrusive_ptr<Work> BackendDummy::_allgather_base(
+    at::Tensor& outputBuffer, at::Tensor& inputBuffer,
+    const AllgatherOptions& opts) {
+  return pg_nccl._allgather_base(outputBuffer, inputBuffer, opts);
+}
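+
+// FP8 allreduce built from NCCL primitives:
+//   1. flatten the input and quantize it to float8_e4m3fn with a per-tensor
+//      scale (cast_to_fp8 above);
+//   2. all-to-all the FP8 chunks and allgather the scales, so each rank
+//      holds one chunk of every peer's tensor;
+//   3. dequantize and sum the chunks locally;
+//   4. re-quantize the partial sums and allgather both chunks and scales,
+//      so every rank can reconstruct the full reduced tensor.
+// Note: the all-to-all with empty split sizes assumes the flattened tensor
+// divides evenly across ranks.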
+c10::intrusive_ptr<Work> BackendDummy::allreduce(
+    std::vector<at::Tensor>& tensors, const AllreduceOptions& opts) {
+  std::vector<int64_t> tmp_size;
+  auto tensor = tensors[0];
+  int world_size = this->getSize();
+  auto input_type = tensor.dtype();
+
+  at::Tensor flatten_tensor = tensor.flatten();
+
+  // Quantize the local tensor and exchange one chunk with every rank.
+  at::Tensor fp8_tensor;
+  at::Tensor scale;
+  cast_to_fp8(flatten_tensor, fp8_tensor, scale);
+  fp8_tensor = fp8_tensor.view(torch::kInt8);
+  auto output_tensor = torch::empty_like(fp8_tensor);
+
+  pg_nccl.alltoall_base(output_tensor, fp8_tensor, tmp_size, tmp_size)
+      ->wait(std::chrono::milliseconds(10000));
+
+  at::Tensor scale_list = torch::zeros(
+      {world_size},
+      at::TensorOptions().dtype(scale.dtype()).device(scale.device()));
+  pg_nccl._allgather_base(scale_list, scale)
+      ->wait(std::chrono::milliseconds(10000));
+
+  // Dequantize each peer's chunk with its own scale and reduce locally.
+  auto output_tensor_chunk = at::chunk(output_tensor, world_size);
+  auto summed_output =
+      torch::zeros_like(output_tensor_chunk[0]).to(input_type);
+  for (int rank = 0; rank < world_size; ++rank) {
+    summed_output +=
+        cast_from_fp8(output_tensor_chunk[rank].view(at::kFloat8_e4m3fn),
+                      scale_list[rank], input_type);
+  }
+
+  // Re-quantize the partial sum and gather all chunks and scales back.
+  at::Tensor summed_output_fp8;
+  at::Tensor summed_output_scale;
+  cast_to_fp8(summed_output, summed_output_fp8, summed_output_scale);
+  summed_output_fp8 = summed_output_fp8.view(torch::kInt8);
+
+  auto summed_output_scale_list = torch::zeros(
+      {world_size},
+      at::TensorOptions().dtype(scale.dtype()).device(scale.device()));
+  // Gather into a flat buffer so chunking along dim 0 matches the
+  // flattened layout regardless of the input tensor's shape.
+  auto summed_output_fp8_list =
+      torch::empty_like(flatten_tensor).to(torch::kInt8);
+
+  pg_nccl._allgather_base(summed_output_scale_list, summed_output_scale)
+      ->wait(std::chrono::milliseconds(10000));
+  pg_nccl._allgather_base(summed_output_fp8_list, summed_output_fp8)
+      ->wait(std::chrono::milliseconds(10000));
+
+  auto summed_output_fp8_chunk =
+      at::chunk(summed_output_fp8_list, world_size);
+  std::vector<at::Tensor> output;
+  for (int rank = 0; rank < world_size; ++rank) {
+    output.push_back(
+        cast_from_fp8(summed_output_fp8_chunk[rank].view(at::kFloat8_e4m3fn),
+                      summed_output_scale_list[rank], input_type));
+  }
+
+  tensors[0].copy_(at::cat(output).reshape(tensor.sizes()));
+
+  auto future = c10::make_intrusive<c10::ivalue::Future>(
+      c10::ListType::create(c10::TensorType::get()));
+  future->markCompleted(c10::IValue(tensors));
+  return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
+}
+
+c10::intrusive_ptr<Work> BackendDummy::allreduce_coalesced(
+    std::vector<at::Tensor>& /* unused */,
+    const AllreduceCoalescedOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::alltoall(
+    std::vector<at::Tensor>& /* unused */,
+    std::vector<at::Tensor>& /* unused */,
+    const AllToAllOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::alltoall_base(
+    at::Tensor& outputTensor, at::Tensor& inputTensor,
+    std::vector<int64_t>& outputSplitSizes,
+    std::vector<int64_t>& inputSplitSizes,
+    const AllToAllOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::barrier(
+    const BarrierOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::broadcast(
+    std::vector<at::Tensor>& tensors, const BroadcastOptions& opts) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::gather(
+    std::vector<std::vector<at::Tensor>>& /* unused */,
+    std::vector<at::Tensor>& /* unused */, const GatherOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::reduce(
+    std::vector<at::Tensor>& /* unused */, const ReduceOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::reduce_scatter(
+    std::vector<at::Tensor>& /* unused */,
+    std::vector<std::vector<at::Tensor>>& /* unused */,
+    const ReduceScatterOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::scatter(
+    std::vector<at::Tensor>& /* unused */,
+    std::vector<std::vector<at::Tensor>>& /* unused */,
+    const ScatterOptions& /* unused */) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::send(std::vector<at::Tensor>& tensors,
+                                            int dstRank, int tag) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::recv(std::vector<at::Tensor>& tensors,
+                                            int srcRank, int tag) {
+  throw std::runtime_error("not supported");
+}
+
+c10::intrusive_ptr<Work> BackendDummy::recvAnysource(
+    std::vector<at::Tensor>& tensors, int tag) {
+  throw std::runtime_error("not supported");
+}
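+
+// Factory handed to torch.distributed via register_backend (see
+// BackendDummyConstructor in dummy.hpp); called when a process group is
+// created with backend="dummy".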
+c10::intrusive_ptr<Backend> BackendDummy::createBackendDummy(
+    const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size,
+    const std::chrono::duration<float>& /* timeout, unused */) {
+  return c10::make_intrusive<BackendDummy>(store, rank, size);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("createBackendDummy", &BackendDummy::createBackendDummy);
+}
+
+}  // namespace c10d
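
For context (not part of the patch): once the extension is compiled, the backend is driven from Python like any built-in backend. A minimal sketch, assuming the compiled module is named dummy_collectives and two CUDA-capable ranks are launched with torchrun:

    import torch
    import torch.distributed as dist

    import dummy_collectives  # hypothetical module name; loading the shared
                              # library runs BackendDummyConstructor, which
                              # registers the "dummy" backend

    dist.init_process_group(backend="dummy")

    rank = dist.get_rank()
    x = torch.full((8,), float(rank + 1), device=f"cuda:{rank}")
    dist.all_reduce(x)  # dispatched to BackendDummy::allreduce (FP8 path)
    print(x)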