
softmax with EVT (draft) #177

Open
wants to merge 1 commit into base: sycl-develop

Conversation

jiyang1011
Collaborator

No description provided.

examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

gemm_op.can_implement(arguments);
Collaborator

Calling can_implement without checking and acting on the result is pointless.

}
}
auto synchronize = [&] () {};
cst_callbacks.reduce(nullptr, synchronize, 0, 0, true, trD);
Collaborator

Looking at the Nvidia implementation, it looks like the reduce call should be inside the epi_n/epi_m loops, and epi_n/epi_m should be passed to the call instead of the zeroes.

Collaborator Author

This version is a draft. I do know epi_n / epi_m should be passed to the call, but the performance will not be good.

Collaborator

This will break all other implementations that use reduce.

Collaborator

Why won’t the performance be good? Is it because we need to check for the last iteration? If so, the compiler should be able to optimise that since all the information is known at compile time.

include/cutlass/epilogue/fusion/xe_vistor_softmax.hpp
constexpr auto dim1 = decltype(size<1>(visit_results))::value;
constexpr auto dim2 = decltype(size<2>(visit_results))::value;

auto t1 = make_tensor(static_cast<decltype(visit_results) &&>(visit_results).data(),
Collaborator

Can you use better naming here?

Comment on lines 446 to 447
constexpr auto m0 = decltype(size<0>(t1))::value;
constexpr auto m1 = decltype(size<1>(t1))::value;
Collaborator

Are these two dimensions meant to define the reduce dimension and the non-reduce dimension?


auto smem = syclcompat::local_mem<float[Sg_Nums * vec_size]>();

auto t =
Collaborator
@mehdi-goli mehdi-goli Jan 13, 2025

Is t a slice/reshape of t1? And can we have better naming here?
Also, according to your example t1 is 16x2, but your t holds the 16x1 slice of the first row of M. So what happened to the next row of N there?


CUTLASS_PRAGMA_UNROLL
for (int loop = 0; loop < loop_cnt; loop++) {
auto loop_t = t(_, loop, _);
Collaborator

Same here: it would help if we had better naming.


template <uint32_t sg_num, class mem_t, class RTensor>
CUTLASS_DEVICE
auto group_reduce_max1(mem_t smem, RTensor const &t, float *out) {
Collaborator

What does max1 mean?

//

bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
return true;
Collaborator

This should have a TODO comment.

);

syclcompat::wait();
#define IDX (l * M * N + i * N + j)
Collaborator

No macros for computation (use a function or lambda if you want it inlined).

syclcompat::wait();
double hbm =
L *
(M * K * sizeof(ElementA) + K * N * sizeof(ElementB) +
Collaborator

Inconsistent: on line 329 we use options.m.

#include "cutlass/workspace.h"

#include "cute/tensor.hpp"
#include "sm90_visitor_tma_warpspecialized.hpp"
Collaborator

Why is that include needed?

inline x { assert(false); }
#endif

SYCL_DEVICE_OCL(float sub_group_reduce_add(float i));
Collaborator

Why is this using the OCL function?

#undef EXP
#undef DIV

#define MAX sycl::max
Collaborator

Suggested change
- #define MAX sycl::max
+ using sycl::max;


template<uint32_t sg_num, uint32_t N, class mem_t>
CUTLASS_DEVICE
void work_group_reduce_max(mem_t &mem, float* vec) {
Collaborator

This is identical to work_group_reduce_sum. Please template on the op and avoid the duplication.

}
}

work_group_reduce_sum<sg_per_wg_n, N, decltype(slm_base)>(slm_base, out);
Collaborator

This seems like an odd way to split this into two functions, because writing the data to SLM is an essential part of the work-group reduce. I suggest moving it into that function.

auto base = sg_local_id * N * step + sg_group_id_n;
auto local_max = mem[base];

if (sg_group_id_n < N) {
Collaborator

Is this if-statement the only reason for the if-statement on line 175? If so, why not:

Suggested change
- if (sg_group_id_n < N) {
+ if (sg_num <= N || sg_group_id_n < N) {

Because of short-circuiting, the second condition isn't evaluated when it isn't needed, and that way you don't need to duplicate the if/else branches starting on line 175.

template<class STensor, class SyncFn, class VTensor>
CUTLASS_DEVICE void
reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
constexpr auto dim0 = decltype(size<0>(visit_results))::value;
Collaborator

Why not just:

Suggested change
- constexpr auto dim0 = decltype(size<0>(visit_results))::value;
+ constexpr auto dim0 = size<0>(visit_results);
