From 357ef748424910dfe74cd663fd52dadae2dbd08d Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 31 Jan 2025 17:41:08 +0530 Subject: [PATCH] Nuking extra code branching in task runner for inlets, outlets (#46302) Co-authored-by: Ash Berlin-Taylor --- .../airflow/sdk/execution_time/task_runner.py | 56 +++++++++---------- .../tests/execution_time/test_task_runner.py | 49 +++++++++++++++- 2 files changed, 75 insertions(+), 30 deletions(-) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 8b21fcfaf486a..7bcecc6343343 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -519,32 +519,31 @@ def run(ti: RuntimeTaskInstance, log: Logger): inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore - msg = SUPERVISOR_COMMS.get_message() # type: ignore - - if isinstance(msg, OKResponse) and not msg.ok: - log.info("Runtime checks failed for task, marking task as failed..") - msg = TaskState( - state=TerminalTIState.FAILED, - end_date=datetime.now(tz=timezone.utc), - ) - else: - context = ti.get_template_context() - with set_current_context(context): - jinja_env = ti.task.dag.get_template_env() - ti.task = ti.render_templates(context=context, jinja_env=jinja_env) - # TODO: Get things from _execute_task_with_callbacks - # - Pre Execute - # etc - result = _execute_task(context, ti.task) - - _push_xcom_if_needed(result, ti) - - task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) - msg = SucceedTask( - end_date=datetime.now(tz=timezone.utc), - task_outlets=task_outlets, - outlet_events=outlet_events, - ) + ok_response = SUPERVISOR_COMMS.get_message() # type: ignore + if not isinstance(ok_response, OKResponse) or not ok_response.ok: + log.info("Runtime checks failed for task, marking task as failed..") + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + return + context = ti.get_template_context() + with set_current_context(context): + jinja_env = ti.task.dag.get_template_env() + ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + # TODO: Get things from _execute_task_with_callbacks + # - Pre Execute + # etc + result = _execute_task(context, ti.task) + + _push_xcom_if_needed(result, ti) + + task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) + msg = SucceedTask( + end_date=datetime.now(tz=timezone.utc), + task_outlets=task_outlets, + outlet_events=outlet_events, + ) except TaskDeferred as defer: # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) @@ -610,8 +609,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): log.exception("Task failed with exception") # TODO: Run task failure callbacks here msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) - if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + finally: + if msg: + SUPERVISOR_COMMS.send_request(msg=msg, log=log) def _execute_task(context: Context, task: BaseOperator): diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index fef2f505b2a55..b4a83955670b3 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -647,14 +647,58 @@ def test_run_with_asset_outlets( ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task") instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) + mock_supervisor_comms.get_message.return_value = OKResponse( + ok=True, + ) run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY) -def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms): +@pytest.mark.parametrize( + ["ok", "last_expected_msg"], + [ + pytest.param( + True, + SucceedTask( + end_date=timezone.datetime(2024, 12, 3, 10, 0), + task_outlets=[ + AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"), + AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"), + ], + outlet_events=[ + { + "asset_alias_events": [], + "extra": {}, + "key": {"name": "name", "uri": "s3://bucket/my-task"}, + }, + { + "asset_alias_events": [], + "extra": {}, + "key": {"name": "new-name", "uri": "s3://bucket/my-task"}, + }, + ], + ), + id="runtime_checks_pass", + ), + pytest.param( + False, + TaskState( + state=TerminalTIState.FAILED, + end_date=timezone.datetime(2024, 12, 3, 10, 0), + ), + id="runtime_checks_fail", + ), + ], +) +def test_run_with_inlets_and_outlets( + create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg +): """Test running a basic tasks with inlets and outlets.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + from airflow.providers.standard.operators.bash import BashOperator task = BashOperator( @@ -672,7 +716,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms): ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets") mock_supervisor_comms.get_message.return_value = OKResponse( - ok=True, + ok=ok, ) run(ti, log=mock.MagicMock()) @@ -688,6 +732,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms): ], ) mock_supervisor_comms.send_request.assert_any_call(msg=expected, log=mock.ANY) + mock_supervisor_comms.send_request.assert_any_call(msg=last_expected_msg, log=mock.ANY) class TestRuntimeTaskInstance: