Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC for GPU e2e feature testing #26367

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,3 +2444,47 @@ def setup_hypothesis(max_examples=30) -> None:
profile = HYPOTHESIS_PROFILE.value
logging.info("Using hypothesis profile: %s", profile)
hp.settings.load_profile(profile)

class XlaGpuFeatureTestCase(JaxTestCase):
  """Test case with helpers that inspect HLO text for XLA GPU features."""

  # Both patterns are anchored with ``^`` and compiled with re.MULTILINE, so
  # only occurrences at the very start of a line are counted.
  _UNROLLED_AG_RE = re.compile(
      r'^unrolled_windowed_dot_general_body_ag', re.MULTILINE)
  _UNROLLED_RS_RE = re.compile(
      r'^unrolled_windowed_dot_general_body_rs', re.MULTILINE)

  def _count_unrolled_bodies(self, hlo_content):
    """Return ``(ag_count, rs_count)`` of line-initial unrolled body names."""
    return (
        len(self._UNROLLED_AG_RE.findall(hlo_content)),
        len(self._UNROLLED_RS_RE.findall(hlo_content)),
    )

  def check_file_content(self, hlo_content, expected_unrolled_ag,
                         expected_unrolled_rs):
    """Check if the HLO contains the expected number of unrolled operations.

    Args:
      hlo_content: The HLO text to scan, as a string.
      expected_unrolled_ag: Expected number of lines starting with
        ``unrolled_windowed_dot_general_body_ag``. Coerced with ``int()``.
      expected_unrolled_rs: Expected number of lines starting with
        ``unrolled_windowed_dot_general_body_rs``. Coerced with ``int()``.

    Returns:
      bool: True if both counts match the expected values, False otherwise.
    """
    # Coerce first so that a malformed expectation is reported before the
    # HLO text is scanned.
    expected = (int(expected_unrolled_ag), int(expected_unrolled_rs))
    return self._count_unrolled_bodies(hlo_content) == expected

  def check_collective_matmul(self, hlo_content, expected_unrolled_ag,
                              expected_unrolled_rs):
    """Verify correctness of collective matmul in HLO content.

    Args:
      hlo_content: The HLO file content as a string.
      expected_unrolled_ag: Expected number of unrolled all-gather operations.
      expected_unrolled_rs: Expected number of unrolled reduce-scatter
        operations.

    Raises:
      unittest.SkipTest: If the current test device is not a GPU.
      AssertionError: If the HLO content is not a string or if counts don't
        match expected values.
    """
    if not test_device_matches(['gpu']):
      self.skipTest('Test only works with GPUs.')

    self.assertIsInstance(hlo_content, str)

    if not self.check_file_content(
        hlo_content, expected_unrolled_ag, expected_unrolled_rs):
      # Recount only on failure so the message can say what was actually
      # found, which makes a mismatch diagnosable from the test log alone.
      found_ag, found_rs = self._count_unrolled_bodies(hlo_content)
      self.fail(
          "Counts do not match expected values: "
          f"found {found_ag} unrolled all-gather bodies "
          f"(expected {int(expected_unrolled_ag)}) and "
          f"{found_rs} unrolled reduce-scatter bodies "
          f"(expected {int(expected_unrolled_rs)})."
      )

  def check_fp8_gemm(self):
    """Placeholder for an FP8 GEMM check; currently performs no verification."""
    # TODO: implement. Until then, calling this asserts nothing.
2 changes: 2 additions & 0 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
check_jvp as check_jvp,
check_vjp as check_vjp,
)

from jax._src.test_util import XlaGpuFeatureTestCase as XlaGpuFeatureTestCase