From d31c26260894e35a36a674b977b28d26eeb425a8 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Wed, 28 Aug 2024 23:37:37 +0400 Subject: [PATCH] feat: add support for cross join (#1118) Co-authored-by: tokoko --- ibis_substrait/compiler/translate.py | 30 ++++++++++++++------ ibis_substrait/tests/compiler/test_parity.py | 17 +++++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index f4b51932..a764f346 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -914,14 +914,25 @@ def join( for i, join_link in enumerate(op.rest): predicates = [pred.to_expr() for pred in join_link.predicates] - relation = stalg.Rel( - join=stalg.JoinRel( - left=( - translate(op.first.parent, compiler=compiler, **kwargs) - if i == 0 - else relation - ), - right=translate(join_link.table.parent, compiler=compiler, **kwargs), + left = ( + translate(op.first.parent, compiler=compiler, **kwargs) + if i == 0 + else relation + ) + + right = translate(join_link.table.parent, compiler=compiler, **kwargs) + + if join_link.how == "cross": + rel = stalg.CrossRel( + left=left, + right=right, + ) + + relation = stalg.Rel(cross=rel) + else: + rel = stalg.JoinRel( + left=left, + right=right, expression=translate( functools.reduce(operator.and_, predicates), compiler=compiler, @@ -930,7 +941,8 @@ def join( ), type=_translate_join_type(join_link.how), ) - ) + + relation = stalg.Rel(join=rel) return relation diff --git a/ibis_substrait/tests/compiler/test_parity.py b/ibis_substrait/tests/compiler/test_parity.py index f4e14ac9..394770aa 100644 --- a/ibis_substrait/tests/compiler/test_parity.py +++ b/ibis_substrait/tests/compiler/test_parity.py @@ -125,6 +125,23 @@ def test_inner_join(consumer: str, request): run_parity_test(request.getfixturevalue(consumer), expr) +@pytest.mark.parametrize( + "consumer", + [ + pytest.param( + "acero_consumer", + marks=[ + pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented") + ], + ), + "datafusion_consumer", + ], +) +def test_cross_join(consumer: str, request): + expr = orders.cross_join(stores) + run_parity_test(request.getfixturevalue(consumer), expr) + + @pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"]) def test_left_join(consumer: str, request): expr = orders.join(stores, orders["fk_store_id"] == stores["store_id"], how="left")