diff --git a/runtime/include/brt/backends/cuda/device/utils/op_kernel_impl_helpers.h b/runtime/include/brt/backends/cuda/device/utils/op_kernel_impl_helpers.h index e70737376..84f4065bb 100644 --- a/runtime/include/brt/backends/cuda/device/utils/op_kernel_impl_helpers.h +++ b/runtime/include/brt/backends/cuda/device/utils/op_kernel_impl_helpers.h @@ -143,6 +143,7 @@ using CurandOpKernelIfaceTraits = * struct ConcreateOpImpl { * ConcreateOpImpl(const OpAccessor&); * void Execute(args..., cudaStream_t); + * optional; * }; * using ConcreteOp = CudaOpKernel; */ @@ -153,6 +154,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CudaOpKernel, * struct ConcreateOpImpl { * ConcreateOpImpl(const OpAccessor&); * void Execute(args..., cublasHandle_t, cudaStream_t); + * optional; * }; * using ConcreteOp = CublasOpKernel; */ @@ -163,6 +165,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CublasOpKernel, * struct ConcreateOpImpl { * ConcreateOpImpl(const OpAccessor&); * void Execute(args..., cudnnHandle_t, cudaStream_t); + * optional; * }; * using ConcreteOp = CudnnOpKernel; */ @@ -173,6 +176,7 @@ BRT_DEF_OP_KERNEL_WRPPER(CudnnOpKernel, * struct ConcreateOpImpl { * ConcreateOpImpl(const OpAccessor&); * void Execute(args..., void* workspace, cudaStream_t); + * optional; * size_t GetWorkspaceSize(const ExecutionContext &); * }; * using ConcreteOp = CudaOpKernelWithWorkspace; diff --git a/runtime/include/brt/core/framework/op_kernel_impl_base.h b/runtime/include/brt/core/framework/op_kernel_impl_base.h index 9d21dc20a..77064a089 100644 --- a/runtime/include/brt/core/framework/op_kernel_impl_base.h +++ b/runtime/include/brt/core/framework/op_kernel_impl_base.h @@ -22,6 +22,7 @@ #include "brt/core/context/work_queue.h" #include "brt/core/framework/op_accessor.h" #include "brt/core/framework/op_kernel.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" namespace brt { @@ -169,6 +170,10 @@ template struct OpKernelIfaceTraitsBase { template common::Status static inline Run(Impl *impl, const ExecutionContext &ctx) { + auto status = 
impl->ProloguePerExecute(ctx); + if (!status.IsOK()) { + return status; + } return impl->Execute(Arguments::Get(impl, ctx)...); } @@ -187,10 +192,55 @@ template struct OpKernelIfaceTraitsBase { template struct NaiveOpKernelIfaceTraits : public OpKernelIfaceTraitsBase { + + template struct TrueHelper : std::true_type {}; + + template + struct HasProloguePerExecuteTraits { + template + static auto CheckProloguePerExecute(int) + -> TrueHelper().ProloguePerExecute( + std::declval()...))>; + + template + static auto CheckProloguePerExecute(...) -> std::false_type; + + public: + enum { + value = + decltype(CheckProloguePerExecute(0))::value + }; + }; + template struct ImplMixin : public ImplBase { public: - explicit ImplMixin(const OpKernelInfo &info) - : ImplBase(info), info_(info) {} + explicit ImplMixin(const OpKernelInfo &info) : ImplBase(info), info_(info) { + // initialize `io_contain_dynamic_shape` + io_contain_dynamic_shape = false; + OpAccessor accessor(info); + size_t num_args = accessor.GetNumArgs(); + for (size_t i = 0; i < accessor.GetNumArgs(); ++i) { + auto shape = accessor.GetArgShape(i); + if (mlir::ShapedType::isDynamicShape(shape)) { + io_contain_dynamic_shape = true; + } + } + for (size_t i = 0; i < accessor.GetNumResults(); ++i) { + auto shape = accessor.GetArgShape(i + num_args); + if (mlir::ShapedType::isDynamicShape(shape)) { + io_contain_dynamic_shape = true; + } + } + } + + common::Status ProloguePerExecute(const ExecutionContext &ctx) { + if constexpr (HasProloguePerExecuteTraits::value) { + if (io_contain_dynamic_shape) { + ImplBase::ProloguePerExecute(GetOpAccessor(ctx)); + } + } + return Status::OK(); + } OpAccessor GetOpAccessor(const ExecutionContext &ctx) const { return OpAccessor(info_, ctx.exec_frame); @@ -198,6 +248,7 @@ struct NaiveOpKernelIfaceTraits : public OpKernelIfaceTraitsBase { private: const OpKernelInfo &info_; + bool io_contain_dynamic_shape; }; }; diff --git a/runtime/include/brt/core/framework/op_kernel_info.h
b/runtime/include/brt/core/framework/op_kernel_info.h index d6c2a14b5..3a21968a8 100644 --- a/runtime/include/brt/core/framework/op_kernel_info.h +++ b/runtime/include/brt/core/framework/op_kernel_info.h @@ -143,13 +143,16 @@ class OpKernelInfo { // Utilities -// Get Tensor as uniuqe Index, from the ith argument of OpKernelInfo +// Get Tensor as unique Index, from the ith argument of OpKernelInfo size_t GetTensorIndexFromOpArgIndex(const OpKernelInfo &, unsigned int i); -// Get Tensor as uniuqe Index, from MLIR Value +// Get Tensor as unique Index, from MLIR Value size_t GetTensorIndexFromMLIRValue(const OpKernelInfo &, mlir::Value val); -// Get Scalar as uniuqe Index, from MLIR Value +// Get Scalar as unique Index, from the ith argument of OpKernelInfo +size_t GetScalarIndexFromOpArgIndex(const OpKernelInfo &, unsigned int i); + +// Get Scalar as unique Index, from MLIR Value size_t GetScalarIndexFromMLIRValue(const OpKernelInfo &, mlir::Value val); // Get Rank of MLIR Value, of ith argument of OpKernelInfo diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 39c77d5a3..b81d98791 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -45,6 +45,8 @@ using namespace mlir; #define BLOCK_SIZE_Z_ATTR "BlockSize.z" #define ARG_RANKS_ATTR "arg_ranks" #define CALL_CONVENTION_ATTR "call_convention" +#define DYNAMIC_CONFIG "__byteir_dynamic_config__" +#define KERNEL_LAUNCH_CONFIG_NUM 6 namespace brt { namespace cuda { @@ -123,42 +125,50 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) impl_->call_convention = "all"; // static assignment for config // TODO extend to support dynamic - if (!info.GetOperation()->hasAttrOfType(GRID_SIZE_X_ATTR)) { - BRT_THROW_EX(std::runtime_error, "no GridSize.x attr"); + bool dynamic_config_flag = false; + if (info.GetOperation()->hasAttr(DYNAMIC_CONFIG)) { + 
dynamic_config_flag = true; } + int gx, gy, gz, bx, by, bz; + gx = gy = gz = bx = by = bz = 1; + if (!dynamic_config_flag) { + if (!info.GetOperation()->hasAttrOfType(GRID_SIZE_X_ATTR)) { + BRT_THROW_EX(std::runtime_error, "no GridSize.x attr"); + } - if (!info.GetOperation()->hasAttrOfType(BLOCK_SIZE_X_ATTR)) { - BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr"); - } + if (!info.GetOperation()->hasAttrOfType(BLOCK_SIZE_X_ATTR)) { + BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr"); + } - int gx = static_cast(info.GetOperation() - ->getAttrOfType(GRID_SIZE_X_ATTR) - .getInt()), - gy = 1, gz = 1; - if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Y_ATTR)) { - gy = static_cast(info.GetOperation() - ->getAttrOfType(GRID_SIZE_Y_ATTR) - .getInt()); - } - if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Z_ATTR)) { - gz = static_cast(info.GetOperation() - ->getAttrOfType(GRID_SIZE_Z_ATTR) - .getInt()); - } + gx = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_X_ATTR) + .getInt()), + gy = 1, gz = 1; + if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Y_ATTR)) { + gy = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_Y_ATTR) + .getInt()); + } + if (info.GetOperation()->hasAttrOfType(GRID_SIZE_Z_ATTR)) { + gz = static_cast(info.GetOperation() + ->getAttrOfType(GRID_SIZE_Z_ATTR) + .getInt()); + } - int bx = static_cast(info.GetOperation() - ->getAttrOfType(BLOCK_SIZE_X_ATTR) - .getInt()), - by = 1, bz = 1; - if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Y_ATTR)) { - by = static_cast(info.GetOperation() - ->getAttrOfType(BLOCK_SIZE_Y_ATTR) - .getInt()); - } - if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Z_ATTR)) { - bz = static_cast(info.GetOperation() - ->getAttrOfType(BLOCK_SIZE_Z_ATTR) - .getInt()); + bx = static_cast(info.GetOperation() + ->getAttrOfType(BLOCK_SIZE_X_ATTR) + .getInt()), + by = 1, bz = 1; + if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Y_ATTR)) { + by = static_cast(info.GetOperation() + 
->getAttrOfType(BLOCK_SIZE_Y_ATTR) + .getInt()); + } + if (info.GetOperation()->hasAttrOfType(BLOCK_SIZE_Z_ATTR)) { + bz = static_cast(info.GetOperation() + ->getAttrOfType(BLOCK_SIZE_Z_ATTR) + .getInt()); + } } std::vector ranks; @@ -172,6 +182,10 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) } auto num_arg = GetOpArgNum(info_); + // filter launch config in inputs + // TODO: make `shared_size` be a input operand in compiler. + if (dynamic_config_flag) + num_arg -= KERNEL_LAUNCH_CONFIG_NUM; impl_->grid = dim3(gx, gy, gz); impl_->block = dim3(bx, by, bz); impl_->shared_size = 0; @@ -198,6 +212,20 @@ common::Status PTXOpKernel::RunImpl(const ExecutionContext &ctx) { std::vector args; std::vector descs; args.reserve(impl_->arg_reserve_size); + bool dynamic_config_flag = false; + if (info_.GetOperation()->hasAttr(DYNAMIC_CONFIG)) { + dynamic_config_flag = true; + auto num_arg = GetOpArgNum(info_); + std::vector launch_config; + launch_config.reserve(KERNEL_LAUNCH_CONFIG_NUM); + for (size_t i = num_arg - KERNEL_LAUNCH_CONFIG_NUM; i < num_arg; ++i) { + size_t idx = GetScalarIndexFromOpArgIndex(info_, i); + launch_config.emplace_back(ctx.exec_frame->GetScalar(idx)); + } + impl_->grid = dim3(launch_config[0], launch_config[1], launch_config[2]); + impl_->block = dim3(launch_config[3], launch_config[4], launch_config[5]); + } + args.push_back(&(impl_->grid)); args.push_back(&(impl_->block)); args.push_back(&(impl_->shared_size)); @@ -205,13 +233,13 @@ common::Status PTXOpKernel::RunImpl(const ExecutionContext &ctx) { descs.reserve(impl_->tensor_ids.size()); for (size_t i = 0; i < impl_->tensor_ids.size(); ++i) { descs.emplace_back(ctx.exec_frame->GetAsyncValueRef(impl_->tensor_ids[i]), - impl_->tensor_ranks[i]); + ctx.exec_frame->GetShapeRef(impl_->tensor_ids[i])); if (impl_->call_convention == "bare_ptr") args.push_back(&descs.back().data); - else + else { InsertMemDescToArgs(descs.back(), args); + } } - auto work_queue = static_cast(ctx.work_queue); auto cuda_env 
= work_queue->GetCudaEnv(); BRT_ENFORCE(cuda_env.IsPrimaryContext(), diff --git a/runtime/lib/backends/cuda/providers/default/math/matmul.cc b/runtime/lib/backends/cuda/providers/default/math/matmul.cc index 9363c2ac5..3eb30f912 100644 --- a/runtime/lib/backends/cuda/providers/default/math/matmul.cc +++ b/runtime/lib/backends/cuda/providers/default/math/matmul.cc @@ -68,6 +68,24 @@ template MatmulImpl::MatmulImpl(const OpAccessor &accessor) { } } +template +void MatmulImpl::ProloguePerExecute(const OpAccessor &accessor) { + auto shape_a = accessor.GetArgShape(0); + auto shape_b = accessor.GetArgShape(1); + if (!lhs_transpose) { + m = shape_a[0]; + k = shape_a[1]; + } else { + m = shape_a[1]; + k = shape_a[0]; + } + if (!rhs_transpose) { + n = shape_b[1]; + } else { + n = shape_b[0]; + } +} + template <> void MatmulImpl::Execute(const float *a_val, const float *b_val, float *c_val, cublasHandle_t handle, diff --git a/runtime/lib/backends/cuda/providers/default/math/matmul.h b/runtime/lib/backends/cuda/providers/default/math/matmul.h index 1efaa01d9..481d89c7c 100644 --- a/runtime/lib/backends/cuda/providers/default/math/matmul.h +++ b/runtime/lib/backends/cuda/providers/default/math/matmul.h @@ -30,6 +30,8 @@ template class MatmulImpl { public: explicit MatmulImpl(const OpAccessor &accessor); + void ProloguePerExecute(const OpAccessor &); + void Execute(const T *a_val, const T *b_val, T *c_val, cublasHandle_t handle, cudaStream_t stream); diff --git a/runtime/lib/core/context/execution_frame.cc b/runtime/lib/core/context/execution_frame.cc index a091a75f5..4af0d4ac3 100644 --- a/runtime/lib/core/context/execution_frame.cc +++ b/runtime/lib/core/context/execution_frame.cc @@ -186,6 +186,11 @@ void BRTInferenceExecutionFrame::BindArg(size_t idx, const void *ptr) { } void *BRTInferenceExecutionFrame::GetArg(size_t idx) { + // this only for debug : get weight ptr + if (idx >= info_.graph_info.io_count) { + return ctx_.weights_and_ios[idx - info_.graph_info.io_count]; + 
} + BRT_ENFORCE(idx < info_.graph_info.io_count); int i = info_.weights.size() + idx; diff --git a/runtime/lib/core/framework/execution_plan.cc b/runtime/lib/core/framework/execution_plan.cc index 154b2b006..cd4c9f9e9 100644 --- a/runtime/lib/core/framework/execution_plan.cc +++ b/runtime/lib/core/framework/execution_plan.cc @@ -337,16 +337,20 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession( return WalkResult::interrupt(); } - auto maybeSpace = brt::ir::GetSpace(op_arg); - if (!maybeSpace.has_value()) { - status_internal = Status(BRT, FAIL, "non-memref Arg of Op " + key); - return WalkResult::interrupt(); - } - - auto space = maybeSpace.value(); - IAllocator *cur_allocator = GetAllocator(allocators, space); - last_alloc = cur_allocator; + std::string space; + IAllocator *cur_allocator; + if (op_arg.getType().dyn_cast()) { + auto maybeSpace = brt::ir::GetSpace(op_arg); + if (!maybeSpace.has_value()) { + status_internal = + Status(BRT, FAIL, "non-memref Arg of Op " + key); + return WalkResult::interrupt(); + } + space = maybeSpace.value(); + cur_allocator = GetAllocator(allocators, space); + last_alloc = cur_allocator; + } // skip if visited if (visited_ptrs.count(arg_ptr) != 0) { continue; @@ -366,6 +370,10 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession( graph_info_.tensor_to_id.emplace(arg_ptr, graph_info_.tensors.size()); graph_info_.tensors.push_back(arg_ptr); + } else if (op_arg.getType().isa()) { + int64_t scalar_index = graph_info_.scalars.size(); + graph_info_.scalar_to_id.emplace(arg_ptr, scalar_index); + graph_info_.scalars.push_back(arg_ptr); } else { status_internal = Status(BRT, FAIL, " non-supported Arg Type of Op " + key); @@ -473,6 +481,11 @@ common::Status StaticBRTExecutionPlan::ProloguePerSession( return WalkResult::interrupt(); } + // PTXOp launch config? 
+ if (op_arg.getType().isa()) { + continue; + } + auto found_arg = graph_info_.tensor_to_id.find(arg_ptr); if (found_arg == graph_info_.tensor_to_id.end()) { status_internal = Status(BRT, FAIL, "cannot find arg"); diff --git a/runtime/lib/core/framework/op_kernel_info.cc b/runtime/lib/core/framework/op_kernel_info.cc index 25f4b1c60..e3704d91b 100644 --- a/runtime/lib/core/framework/op_kernel_info.cc +++ b/runtime/lib/core/framework/op_kernel_info.cc @@ -38,6 +38,18 @@ inline size_t GetTensorIndexFromOpArgIndexImpl(const OpKernelInfo &info, "at arg_idx " + std::to_string(arg_idx)); return found->second; } + +inline size_t GetScalarIndexFromOpArgIndexImpl(const OpKernelInfo &info, + unsigned int arg_idx) { + const std::unordered_map &arg_to_idx = + info.GetScalarToIndex(); + byre::ByreOp byre_op = cast(info.GetOperation()); + auto op_arg = byre_op->getOperand(arg_idx); + auto found = arg_to_idx.find(op_arg.getAsOpaquePointer()); + BRT_ENFORCE(found != arg_to_idx.end(), + "at arg_idx " + std::to_string(arg_idx)); + return found->second; +} } // namespace size_t GetTensorIndexFromOpArgIndex(const OpKernelInfo &info, @@ -45,6 +57,11 @@ size_t GetTensorIndexFromOpArgIndex(const OpKernelInfo &info, return GetTensorIndexFromOpArgIndexImpl(info, arg_idx); } +size_t GetScalarIndexFromOpArgIndex(const OpKernelInfo &info, + unsigned int arg_idx) { + return GetScalarIndexFromOpArgIndexImpl(info, arg_idx); +} + size_t GetTensorIndexFromMLIRValue(const OpKernelInfo &info, mlir::Value val) { const std::unordered_map &arg_to_idx = info.GetTensorToIndex(); diff --git a/runtime/test/backends/cuda/providers/default/request_context_test.cc b/runtime/test/backends/cuda/providers/default/request_context_test.cc index aea25fdd2..0b7bab74f 100644 --- a/runtime/test/backends/cuda/providers/default/request_context_test.cc +++ b/runtime/test/backends/cuda/providers/default/request_context_test.cc @@ -15,7 +15,10 @@ // 
//===----------------------------------------------------------------------===// +#include "brt/backends/cpu/providers/default/cpu_provider.h" +#include "brt/backends/cuda/device/common/cuda_call.h" #include "brt/backends/cuda/device/cuda_allocator.h" +#include "brt/backends/cuda/device/cuda_device_api.h" #include "brt/backends/cuda/providers/default/cuda_provider.h" #include "brt/core/common/status.h" #include "brt/core/session/request_context.h" @@ -214,3 +217,74 @@ TEST(CUDARequestContextTest, WeightSetting) { cudaFree(d_weight_0); cudaFree(d_arg_0); } + +TEST(SessionTest, GPUDynamicShape) { + Session session; + int device_id; + BRT_CUDA_CHECK(cudaGetDevice(&device_id)); + session.SetExecDevice(DeviceType::CUDA, device_id); + session.AddDeviceAPI(DeviceType::CUDA, GetCudaDeviceAPI()); + auto status_cpu_allocator = CPUAllocatorFactory(&session); + BRT_TEST_CHECK_STATUS(status_cpu_allocator); + auto status_cuda_allocator = CUDAAllocatorFactory(&session); + BRT_TEST_CHECK_STATUS(status_cuda_allocator); + auto status_cpu = NaiveCPUExecutionProviderFactory(&session); + BRT_TEST_CHECK_STATUS(status_cpu); + auto status_cuda = DefaultCUDAExecutionProviderFactory(&session); + BRT_TEST_CHECK_STATUS(status_cuda); + + std::string file_name = "test/test_files/DynamicShapes/MLP/entry.mlir"; + auto status_load = session.Load(file_name, "byre"); + BRT_TEST_CHECK_STATUS(status_load); + + std::unique_ptr request; + auto status_request = session.NewRequestContext(&request); + std::srand(std::time(0)); + BRT_TEST_CHECK_STATUS(status_request); + for (size_t t = 0; t < 10; ++t) { + int64_t N = 1024 - t; + // arg 0 & 1 is weight + // add weight offset in SetShape & SetType & GetShape .. + BRT_TEST_CHECK_STATUS(request->SetShape(2, {N, 10})); + BRT_TEST_CHECK_STATUS(request->SetShape(3, {N, 20})); + BRT_TEST_CHECK_STATUS(request->SetShape(4, {N, 20})); + + request->FinishIOBinding(); + // subtract the weight offset. + // TODO: refine this APIs to unify the input(SetShape/GetArg...) 
+ float *i0 = static_cast(request->GetArg(0)), + *i1 = static_cast(request->GetArg(1)), + *o0 = static_cast(request->GetArg(2)); + + // float i_val_0 = rand() % 10 / 10.0 - 0.5; + // float i_val_1 = rand() % 10 / 10.0 - 0.5; + // float w_val_0 = rand() % 10 / 10.0 - 0.5; + // float w_val_1 = rand() % 10 / 10.0 - 0.5; + float i_val_0 = rand() % 10 - 5; + float i_val_1 = rand() % 10 - 5; + float w_val_0 = rand() % 10 - 5; + float w_val_1 = rand() % 10 - 5; + + // weight offset = idx + io_cnt + float *w0 = static_cast(request->GetArg(3)); + float *w1 = static_cast(request->GetArg(4)); + + AssignCUDABuffer(i0, N * 10, i_val_0); + AssignCUDABuffer(i1, N * 20, i_val_1); + AssignCUDABuffer(w0, 10 * 20, w_val_0); + AssignCUDABuffer(w1, 20, w_val_1); + + auto status_run = session.Run(*request); + BRT_TEST_CHECK_STATUS(status_run); + auto status_sync = request->Sync(); + BRT_TEST_CHECK_STATUS(status_sync); + + float result = w_val_0 * 10 * i_val_0 + w_val_1; + if (result < 0) + result = 0; + // llvm::outs() << "relu = " << result << ", "; + result += i_val_1; + // llvm::outs() << "result = " << result << "\n"; + CheckResult(o0, N * 20, result); + } +} diff --git a/runtime/test/test_files/DynamicShapes/MLP/device_output.ptx b/runtime/test/test_files/DynamicShapes/MLP/device_output.ptx new file mode 100644 index 000000000..3566fd3c7 --- /dev/null +++ b/runtime/test/test_files/DynamicShapes/MLP/device_output.ptx @@ -0,0 +1,110 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 7.0 +.target sm_80 +.address_size 64 + + // .globl Unknown0 + +.visible .entry Unknown0( + .param .u64 Unknown0_param_0, + .param .u64 Unknown0_param_1, + .param .u64 Unknown0_param_2, + .param .u64 Unknown0_param_3, + .param .u64 Unknown0_param_4, + .param .u64 Unknown0_param_5, + .param .u64 Unknown0_param_6, + .param .u64 Unknown0_param_7, + .param .u64 Unknown0_param_8, + .param .u64 Unknown0_param_9, + .param .u64 Unknown0_param_10, + .param .u64 Unknown0_param_11, + .param .u64 
Unknown0_param_12, + .param .u64 Unknown0_param_13, + .param .u64 Unknown0_param_14, + .param .u64 Unknown0_param_15, + .param .u64 Unknown0_param_16, + .param .u64 Unknown0_param_17, + .param .u64 Unknown0_param_18, + .param .u64 Unknown0_param_19, + .param .u64 Unknown0_param_20, + .param .u64 Unknown0_param_21, + .param .u64 Unknown0_param_22, + .param .u64 Unknown0_param_23, + .param .u64 Unknown0_param_24, + .param .u64 Unknown0_param_25, + .param .u64 Unknown0_param_26, + .param .u64 Unknown0_param_27, + .param .u64 Unknown0_param_28, + .param .u64 Unknown0_param_29, + .param .u64 Unknown0_param_30, + .param .u64 Unknown0_param_31, + .param .u64 Unknown0_param_32 +) +{ + .reg .pred %p<4>; + .reg .b32 %r<5>; + .reg .f32 %f<7>; + .reg .b64 %rd<41>; + + mov.u32 %r1, %ctaid.x; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %tid.x; + cvt.s64.s32 %rd14, %r3; + mul.wide.s32 %rd15, %r2, %r1; + add.s64 %rd40, %rd15, %rd14; + ld.param.u64 %rd16, [Unknown0_param_15]; + mul.lo.s64 %rd7, %rd16, 20; + setp.ge.s64 %p1, %rd40, %rd7; + @%p1 bra $L__BB0_3; + ld.param.u64 %rd10, [Unknown0_param_27]; + cvta.to.global.u64 %rd1, %rd10; + ld.param.u64 %rd11, [Unknown0_param_20]; + cvta.to.global.u64 %rd2, %rd11; + ld.param.u64 %rd12, [Unknown0_param_13]; + cvta.to.global.u64 %rd3, %rd12; + ld.param.u64 %rd13, [Unknown0_param_8]; + cvta.to.global.u64 %rd4, %rd13; + mov.u32 %r4, %nctaid.x; + mul.wide.s32 %rd6, %r2, %r4; +$L__BB0_2: + mul.hi.s64 %rd17, %rd40, 7378697629483820647; + shr.u64 %rd18, %rd17, 63; + shr.s64 %rd19, %rd17, 3; + add.s64 %rd20, %rd19, %rd18; + mul.lo.s64 %rd21, %rd20, 20; + sub.s64 %rd22, %rd40, %rd21; + setp.lt.s64 %p2, %rd22, 0; + add.s64 %rd23, %rd22, 20; + selp.b64 %rd24, %rd23, %rd22, %p2; + shr.s64 %rd25, %rd40, 63; + xor.b64 %rd26, %rd25, %rd40; + mul.hi.s64 %rd27, %rd26, 7378697629483820647; + shr.u64 %rd28, %rd27, 63; + shr.s64 %rd29, %rd27, 3; + add.s64 %rd30, %rd29, %rd28; + xor.b64 %rd31, %rd30, %rd25; + shl.b64 %rd32, %rd24, 2; + add.s64 %rd33, %rd4, %rd32; 
+ ld.global.nc.f32 %f1, [%rd33]; + mul.lo.s64 %rd34, %rd31, 20; + add.s64 %rd35, %rd34, %rd24; + shl.b64 %rd36, %rd35, 2; + add.s64 %rd37, %rd3, %rd36; + ld.global.nc.f32 %f2, [%rd37]; + add.s64 %rd38, %rd2, %rd36; + ld.global.nc.f32 %f3, [%rd38]; + add.rn.f32 %f4, %f1, %f2; + max.NaN.f32 %f5, %f4, 0f00000000; + add.rn.f32 %f6, %f3, %f5; + add.s64 %rd39, %rd1, %rd36; + st.global.f32 [%rd39], %f6; + add.s64 %rd40, %rd40, %rd6; + setp.lt.s64 %p3, %rd40, %rd7; + @%p3 bra $L__BB0_2; +$L__BB0_3: + ret; + +} diff --git a/runtime/test/test_files/DynamicShapes/MLP/entry.mlir b/runtime/test/test_files/DynamicShapes/MLP/entry.mlir new file mode 100644 index 000000000..ce0879565 --- /dev/null +++ b/runtime/test/test_files/DynamicShapes/MLP/entry.mlir @@ -0,0 +1,11 @@ +// --byre-host="device-file-name=your_file target=cuda entry-func=forward" + +module attributes {byre.container_module, gpu.container_module, torch.debug_module_name = "MLPModule"} { + func.func @forward(%arg0: memref<10x20xf32, "cuda"> {byre.argname = "Weight0", byre.argtype = 4 : i32, byre.weight_value = 
dense<"0xEDAEFD3D6B88963E1051E33DFC7732BE055CEE3C07F9413E080B9CBE9B2F4ABE5608BD3BF8E6DFBD507F46BEC61183BEACE23ABEF010903E824129BEAFB6D83C779721BE953B2B3E8B44CB3BB09383BE2456463EA8E7983ECE3D9EBC690042BEF0D34CBD5AFCAB3DCC1AF13D8E3CF43B0BF2583EC82B583D658D653C79C131BE9AEF24BD85B4B03D46DBAF3DEB4013BD26A9693EB17CC43CEDAF77BEA24E5BBE409C203EA0BE27BCA3380BBEC03C5A3E2775633D62069C3E0DF3963D259F883E7AAD743EEB5AC8BD4B210F3ECE1F303E3D4983BDA3A3F63D3D993FBE868FF1BC89B98ABE13D72E3D5012703E826D35BE725C76BECA748B3ED59EA23D63F6103E2F3527BE354722BECE329DBE39E496BE46C71CBDC849E83D8BD2ADBD2596A73DF2DA0ABED46BA03D752989BDB8624CBE971808BEB1C184BDCC23743E36B4AB3C7D51683E7ADA4B3EDA6EDD3D2DFDF1BDC9A5D5BD948E7E3ECFCE7FBE6E9E813D85DBE43B6543143D379970BEA80F16BED17C18BD05E9193E0F403B3D26EDA43B808C9ABEA6A2823D08D786BE2314AABCD8D783BE562D9ABE2B60943E88C91ABEB2DC513E136825BEBB0DBD3D8E26E7BDF4464EBDD9E68E3E5B5115BE3FCF123E1E194FBE2CACBF3DC71E153D96F24B3E53E9E9BD407141BE714361BE4D2E963DA67B8F3EB62B78BE0D424E3D5EBB7F3D056098BEC5F011BDCB26033E1656DD3C60B90DBE7EFF583DE9C118BDDDF5673E451A58BD8E0C32BE121C05BC8D148E3ED80AEA3DBE27AA3A1F534B3E97023EBD85C231BEEB7CF33BF5FF363D32DCDBBDFBFED83DB640623D66E5EA3D785D2B3E4E7470BE651C673EDC57473E601A24BECF5FFABD38AA14BE9EFC60BCE3E5983E10EE7D3E2BE0453ECE3C41BE0A3B753E4225383D967059BDC197B8BD4FA407BE5C0C17BE0F16953E9884E7BC2E6CF6BDAF01453EB6F466BE37968ABE5BC056BECBD28E3ED9AC493DACC7103D46EB73BEBB39213DFD30F0BD835CD4BD426F4FBD620472BE67FB97BEFBAE8CBDF08F553EFAB77D3EFCC7D3BC76249D3EA2C404BE937B9B3E14197E3D30212D3E262F63BD7899D1BDF72C993E7537DE3D6FD39D3E988D4B3DDB132B3E097461BD8461313C"> : tensor<10x20xf32>}, %arg1: memref<20xf32, "cuda"> {byre.argname = "Weight1", byre.argtype = 4 : i32, byre.weight_value = dense<[0.124238588, -0.0375917405, -0.178324029, 2.1018261E-4, -0.0708629936, 0.179958493, 0.201986402, -0.0302014686, -0.0842267424, 0.0796111747, 0.0201944318, -0.183529228, -0.133614406, -0.0192934573, 0.193412527, 0.219010666, -0.0464102961, 0.00334274326, 
-0.0029087835, 0.0903228372]> : tensor<20xf32>}, %arg2: memref {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg3: memref {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg4: memref {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point, device_file_name = "device_output.ptx"} { + %0:4 = "byre.compute_shape"(%arg2) <{shape_fn = "shapeComputaionFunc_0"}> {device = "cuda", kernel_name = "shapeComputaionFunc_0", llvm_file_name = "host_kernels.ll"} : (memref) -> (index, index, index, index) + %alloc = memref.alloc(%0#1) : memref + byre.compute @MatmulOp_f32f32_f32(%arg2, %arg0, %alloc) {device = "cuda", lhs_contracting_dimension = 1 : i64, memory_effects = [1 : i32, 1 : i32, 2 : i32], rhs_contracting_dimension = 0 : i64} : memref, memref<10x20xf32, "cuda">, memref + byre.compute @PTXOp(%arg2, %arg1, %alloc, %arg3, %arg4, %0#2, %0#0, %0#0, %0#3, %0#0, %0#0) {__byteir_dynamic_config__, device = "cuda", kernel_name = "Unknown0"} : memref, memref<20xf32, "cuda">, memref, memref, memref, index, index, index, index, index, index + return + } +} diff --git a/runtime/test/test_files/DynamicShapes/MLP/host_kernels.ll b/runtime/test/test_files/DynamicShapes/MLP/host_kernels.ll new file mode 100644 index 000000000..17a64e4f6 --- /dev/null +++ b/runtime/test/test_files/DynamicShapes/MLP/host_kernels.ll @@ -0,0 +1,42 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" + +define { i64, i64, i64, i64 } @shapeComputaionFunc_0(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6) { + %8 = mul i64 %3, 20 + %9 = add i64 %8, -1 + %10 = sdiv i64 %9, 1024 + %11 = add i64 %10, 1 + %12 = sub i64 0, %8 + %13 = sdiv i64 %12, 1024 + %14 = sub i64 0, %13 + %15 = icmp sgt i64 %8, 0 + %16 = select i1 %15, i64 %11, i64 %14 + %17 = call i64 @llvm.smax.i64(i64 %16, i64 1) + %18 = insertvalue { i64, i64, i64, i64 } { i64 1, i64 undef, i64 undef, i64 undef }, i64 %3, 1 + %19 = insertvalue { i64, i64, i64, i64 } %18, i64 %17, 2 + %20 = 
insertvalue { i64, i64, i64, i64 } %19, i64 256, 3 + ret { i64, i64, i64, i64 } %20 +} + +define void @_mlir_ciface_shapeComputaionFunc_0(ptr %0, ptr %1) { + %3 = load { ptr, ptr, i64, [2 x i64], [2 x i64] }, ptr %1, align 8 + %4 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 0 + %5 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 1 + %6 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 2 + %7 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 0 + %8 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 3, 1 + %9 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 0 + %10 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %3, 4, 1 + %11 = call { i64, i64, i64, i64 } @shapeComputaionFunc_0(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10) + store { i64, i64, i64, i64 } %11, ptr %0, align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i64 @llvm.smax.i64(i64, i64) #0 + +attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + +!llvm.module.flags = !{!0} + +!0 = !{i32 2, !"Debug Info Version", i32 3} \ No newline at end of file