Skip to content

Commit

Permalink
Nuking extra code branching in task runner for inlets, outlets (apache#46302)
Browse files Browse the repository at this point in the history

Co-authored-by: Ash Berlin-Taylor <[email protected]>
  • Loading branch information
amoghrajesh and ashb authored Jan 31, 2025
1 parent 49a9fb6 commit 357ef74
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 30 deletions.
56 changes: 28 additions & 28 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,32 +519,31 @@ def run(ti: RuntimeTaskInstance, log: Logger):
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore
msg = SUPERVISOR_COMMS.get_message() # type: ignore

if isinstance(msg, OKResponse) and not msg.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
else:
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
# TODO: Get things from _execute_task_with_callbacks
# - Pre Execute
# etc
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
ok_response = SUPERVISOR_COMMS.get_message() # type: ignore
if not isinstance(ok_response, OKResponse) or not ok_response.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
return
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
# TODO: Get things from _execute_task_with_callbacks
# - Pre Execute
# etc
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
except TaskDeferred as defer:
# TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
Expand Down Expand Up @@ -610,8 +609,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
log.exception("Task failed with exception")
# TODO: Run task failure callbacks here
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
finally:
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _execute_task(context: Context, task: BaseOperator):
Expand Down
49 changes: 47 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,14 +647,58 @@ def test_run_with_asset_outlets(
ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task")
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
mock_supervisor_comms.get_message.return_value = OKResponse(
ok=True,
)

run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)


def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
@pytest.mark.parametrize(
["ok", "last_expected_msg"],
[
pytest.param(
True,
SucceedTask(
end_date=timezone.datetime(2024, 12, 3, 10, 0),
task_outlets=[
AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"),
AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"),
],
outlet_events=[
{
"asset_alias_events": [],
"extra": {},
"key": {"name": "name", "uri": "s3://bucket/my-task"},
},
{
"asset_alias_events": [],
"extra": {},
"key": {"name": "new-name", "uri": "s3://bucket/my-task"},
},
],
),
id="runtime_checks_pass",
),
pytest.param(
False,
TaskState(
state=TerminalTIState.FAILED,
end_date=timezone.datetime(2024, 12, 3, 10, 0),
),
id="runtime_checks_fail",
),
],
)
def test_run_with_inlets_and_outlets(
create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg
):
"""Test running a basic tasks with inlets and outlets."""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
Expand All @@ -672,7 +716,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):

ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
mock_supervisor_comms.get_message.return_value = OKResponse(
ok=True,
ok=ok,
)

run(ti, log=mock.MagicMock())
Expand All @@ -688,6 +732,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
],
)
mock_supervisor_comms.send_request.assert_any_call(msg=expected, log=mock.ANY)
mock_supervisor_comms.send_request.assert_any_call(msg=last_expected_msg, log=mock.ANY)


class TestRuntimeTaskInstance:
Expand Down

0 comments on commit 357ef74

Please sign in to comment.