Skip to content

Commit

Permalink
renew test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu302 committed Aug 15, 2024
1 parent aa396a0 commit a9e8f3b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/numerical_test/testset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def _get_test_files_from_dir(directory):

CUDA_WITH_GEMM_CODEGEN_XFAIL_SET = {
"MatmulTransposeAF16Module_basic",
"MatmulTransposeBF16Module_basic",
"MatmulTransposeModule_basic",
# "MatmulTransposeBF16Module_basic",
# "MatmulTransposeModule_basic",
# TODO: Test passed on A10. But failed on CI machine.
"BatchMatmulAddF32Module_basic",
# "BatchMatmulAddF32Module_basic",
# TODO: fix bug
"gemm_crr_f16f16f32.mlir",
"bmm_rcr_f16f16f32.mlir",
Expand Down
6 changes: 3 additions & 3 deletions tests/numerical_test/torch_e2e_testing/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def forward(self, a, b):
@register_test_case(module_factory=lambda: MatmulTransposeAF16Module())
def MatmulTransposeAF16Module_basic(module, tu: TestUtils):
module.forward(tu.rand(64, 128).to(torch.float16),
tu.rand(64, 1024).to(torch.float16))
tu.rand(64, 128).to(torch.float16))


class MatmulTransposeBF16Module(torch.nn.Module):
Expand All @@ -86,8 +86,8 @@ def forward(self, a, b):

@register_test_case(module_factory=lambda: MatmulTransposeBF16Module())
def MatmulTransposeBF16Module_basic(module, tu: TestUtils):
module.forward(tu.rand(128, 32).to(torch.float32),
tu.rand(128, 32).to(torch.float32))
module.forward(tu.rand(128, 64).to(torch.float32),
tu.rand(128, 64).to(torch.float32))

class MatmulTransposeModule(torch.nn.Module):

Expand Down

0 comments on commit a9e8f3b

Please sign in to comment.