
Update sharded_moe.py to support top2 gate with Tutel #6948

Open · wants to merge 4 commits into base: master
Conversation

xenshinu commented Jan 14, 2025

Tutel has been forced off for k > 1 since #2053.
Given that routing multiple experts per token is very common, and that the gather and scatter path without Tutel is quite inefficient, I added Tutel support to the top-2 gate and tested it on the pipeline engine. This can be done for any k, actually; I'll push that later when I have time to test it.
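For context on the inefficiency claim above: without Tutel, dispatch is done by contracting a dense one-hot mask over all tokens, experts, and capacity slots, while Tutel-style fast dispatch scatters each token straight to its buffer slot by index. A toy sketch (hypothetical sizes and routing, not DeepSpeed's actual code) showing the two paths produce the same dispatched buffer:

```python
import torch

torch.manual_seed(0)
tokens, d_model, num_experts, capacity = 8, 16, 4, 2
x = torch.randn(tokens, d_model)
# Fixed routing so every expert gets exactly `capacity` tokens.
expert = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
slot = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # position inside each expert's buffer

# Dense path (no Tutel): materialize a [tokens, experts, capacity]
# one-hot dispatch mask and contract it with the activations.
mask = torch.zeros(tokens, num_experts, capacity)
mask[torch.arange(tokens), expert, slot] = 1.0
dispatched_dense = torch.einsum('sec,sd->ecd', mask, x)

# Index path (roughly what Tutel's fast dispatch amounts to): scatter
# each token directly into its (expert, slot) position, no dense mask.
dispatched_idx = torch.zeros(num_experts, capacity, d_model)
dispatched_idx[expert, slot] = x

assert torch.allclose(dispatched_dense, dispatched_idx)
```

The dense einsum does work proportional to tokens × experts × capacity × d_model, while the index scatter only touches each token once, which is the gap the PR is trying to close.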
xenshinu requested a review from tohtana as a code owner on January 14, 2025

@xenshinu please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree company="University of Michigan"


xenshinu commented Jan 14, 2025

Not sure if this check is needed:

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices1_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices1_mask)
        indices2_mask = mask2.sum(dim=1) * num_experts - 1
        indices2_s = torch.min(indices2_s, indices2_mask)

I see this in top1gating, but when I refer to the examples from Tutel
https://github.com/microsoft/Tutel/blob/ab7937bb929bc78111d74261b490da25657a7e5c/tutel/impls/fast_dispatch.py#L143
I don't see any specific check for a non-zero mask there.
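To make the question concrete, the masking trick from the snippet can be reproduced in isolation (toy values are hypothetical; only the `mask1`/`indices1_s` logic is taken from the snippet above):

```python
import torch

# mask1 is a one-hot [tokens, experts] assignment where a dropped
# (capacity-overflow) token is an all-zero row.
num_experts = 4
mask1 = torch.tensor([[0, 1, 0, 0],   # token 0 -> expert 1
                      [0, 0, 0, 0],   # token 1 dropped (all-zero row)
                      [1, 0, 0, 0]])  # token 2 -> expert 0
indices1_s = torch.argmax(mask1, dim=1)  # dropped token wrongly yields index 0

# mask1.sum(dim=1) is 1 for kept tokens and 0 for dropped ones, so this
# evaluates to num_experts - 1 (a no-op under min) or -1 (the sentinel
# the comment says Tutel expects for dropped tokens).
indices1_mask = mask1.sum(dim=1) * num_experts - 1
indices1_s = torch.min(indices1_s, indices1_mask)

print(indices1_s.tolist())  # [1, -1, 0]
```

So the check only matters when tokens are actually dropped; whether Tutel's dispatcher needs the -1 sentinel in that case is the open question here.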

loadams requested a review from hwchen2017 on January 14, 2025
    @@ -517,7 +535,7 @@ def forward(self,

         elif self.k == 2:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
    -                                 self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
    +                                 self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling, use_tutel)
Contributor commented:
@xenshinu - thanks for this PR, could you run the pre-commit formatter on your branch to resolve the "Formatting" error? I believe it just wants the use_tutel on a new line here.

xenshinu commented Jan 15, 2025
Thanks for pointing that out. I've updated the file and it looks like the pre-commit check has passed.
Could someone also take a look at this question?
#6948 (comment)
