Merge branch 'main' into feature/http-extra-options-check-response
dabla authored Jan 31, 2025
2 parents bd295d4 + 357ef74 commit 3b9e524
Showing 9 changed files with 142 additions and 144 deletions.
1 change: 0 additions & 1 deletion airflow/dag_processing/collection.py
@@ -210,7 +210,6 @@ def _serialize_dag_capturing_errors(
except Exception:
log.exception("Failed to write serialized DAG dag_id=%s fileloc=%s", dag.dag_id, dag.fileloc)
dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth")
# todo AIP-66: this needs to use bundle name / rel fileloc instead
return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]


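A note on the surviving line above: `traceback.format_exc` is called with a negative `limit`, which keeps only the last `abs(limit)` stack entries, i.e. the frames closest to the failing DAG import. A minimal, self-contained sketch of that behaviour (the helper name and the failing import are invented for illustration):

import traceback

def format_import_error(depth: int = 2) -> str:
    # A negative ``limit`` makes format_exc keep the *last* ``abs(limit)``
    # stack entries, i.e. the frames nearest to where the error was raised.
    return traceback.format_exc(limit=-depth)

try:
    import not_a_real_module  # hypothetical failing DAG import
except Exception:
    print(format_import_error(depth=2))
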
74 changes: 29 additions & 45 deletions airflow/dag_processing/manager.py
@@ -102,16 +102,10 @@ class DagFileStat:
class DagFileInfo:
"""Information about a DAG file."""

rel_path: Path
path: str # absolute path of the file
bundle_name: str
bundle_path: Path | None = field(compare=False, default=None)

@property
def absolute_path(self) -> Path:
if not self.bundle_path:
raise ValueError("bundle_path not set")
return self.bundle_path / self.rel_path


def _config_int_factory(section: str, key: str):
return functools.partial(conf.getint, section, key)
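`_config_int_factory`, unchanged in this hunk, wraps `conf.getint` in `functools.partial` so the resulting zero-argument callable can serve as a dataclass `default_factory` and read the configuration lazily, at instantiation time. A rough sketch of the same pattern with a stand-in config object (`FakeConf` and its values are invented for the example):

import functools
from dataclasses import dataclass, field

class FakeConf:
    # Stand-in for airflow.configuration.conf; only getint is sketched.
    _values = {("scheduler", "parsing_processes"): "2"}

    def getint(self, section: str, key: str) -> int:
        return int(self._values[(section, key)])

fake_conf = FakeConf()

def _config_int_factory(section: str, key: str):
    # Zero-argument callable, suitable as a dataclasses default_factory.
    return functools.partial(fake_conf.getint, section, key)

@dataclass
class ManagerConfig:
    parsing_processes: int = field(
        default_factory=_config_int_factory("scheduler", "parsing_processes")
    )

print(ManagerConfig().parsing_processes)  # 2, read when the instance is created
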
@@ -247,9 +241,7 @@ def _scan_stale_dags(self):
elapsed_time_since_refresh = now - self._last_deactivate_stale_dags_time
if elapsed_time_since_refresh > self.parsing_cleanup_interval:
last_parsed = {
file_info: stat.last_finish_time
for file_info, stat in self._file_stats.items()
if stat.last_finish_time
fp: stat.last_finish_time for fp, stat in self._file_stats.items() if stat.last_finish_time
}
self.deactivate_stale_dags(last_parsed=last_parsed)
self._last_deactivate_stale_dags_time = time.monotonic()
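Both versions of `_scan_stale_dags` keep the same throttling idea: remember the last run as a `time.monotonic()` timestamp and only do the expensive work once `parsing_cleanup_interval` seconds have elapsed. A minimal sketch of that pattern, with the interval and the work as placeholders:

import time

class PeriodicWorker:
    def __init__(self, interval_seconds: float):
        self.interval = interval_seconds
        self._last_run = 0.0  # monotonic timestamp of the previous run

    def maybe_run(self, work) -> bool:
        # monotonic() is immune to wall-clock adjustments, so the elapsed-time
        # comparison stays valid even if the system clock jumps.
        now = time.monotonic()
        if now - self._last_run <= self.interval:
            return False
        work()
        self._last_run = now
        return True

worker = PeriodicWorker(interval_seconds=60.0)
worker.maybe_run(lambda: print("deactivating stale DAGs"))
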
@@ -277,11 +269,14 @@ def deactivate_stale_dags(
# last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is
# no longer present in the file. We have a stale_dag_threshold configured to prevent a
# significant delay in deactivation of stale dags when a large timeout is configured
file_info = DagFileInfo(rel_path=Path(dag.relative_fileloc), bundle_name=dag.bundle_name)
if last_finish_time := last_parsed.get(file_info, None):
if dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold) < last_finish_time:
self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
to_deactivate.add(dag.dag_id)
dag_file_path = DagFileInfo(path=dag.fileloc, bundle_name=dag.bundle_name)
if (
dag_file_path in last_parsed
and (dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold))
< last_parsed[dag_file_path]
):
self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
to_deactivate.add(dag.dag_id)

if to_deactivate:
deactivated_dagmodel = session.execute(
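Both variants of this hunk implement the check described in the comment above: a DAG is treated as gone when its file finished a parse more than `stale_dag_threshold` seconds after the DAG's own `last_parsed_time`. A toy version of the comparison with hard-coded timestamps (values are illustrative only):

from datetime import datetime, timedelta, timezone

stale_dag_threshold = 50  # seconds; stand-in for the configured value

dag_last_parsed_time = datetime(2025, 1, 31, 12, 0, 0, tzinfo=timezone.utc)
file_last_finish_time = datetime(2025, 1, 31, 12, 2, 0, tzinfo=timezone.utc)

# The file was fully parsed well after the DAG was last seen in it, so the
# DAG no longer appears in that file and can be deactivated.
is_stale = (
    dag_last_parsed_time + timedelta(seconds=stale_dag_threshold)
    < file_last_finish_time
)
print(is_stale)  # True
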
@@ -488,30 +483,30 @@ def _refresh_dag_bundles(self):
"Version changed for %s, new version: %s", bundle.name, version_after_refresh
)

found_file_infos = [
DagFileInfo(rel_path=p, bundle_name=bundle.name, bundle_path=bundle.path)
for p in self._find_files_in_bundle(bundle)
]
bundle_file_paths = self._find_files_in_bundle(bundle)

new_file_paths = [f for f in self._file_paths if f.bundle_name != bundle.name]
new_file_paths.extend(found_file_infos)
new_file_paths.extend(
DagFileInfo(path=path, bundle_path=bundle.path, bundle_name=bundle.name)
for path in bundle_file_paths
)
self.set_file_paths(new_file_paths)

self.deactivate_deleted_dags(active_files=found_file_infos)
self.deactivate_deleted_dags(bundle_file_paths)
self.clear_nonexistent_import_errors()

self._bundle_versions[bundle.name] = bundle.get_current_version()

def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[Path]:
"""Get relative file paths from bundle dir."""
def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[str]:
"""Refresh file paths from bundle dir."""
# Build up a list of Python files that could contain DAGs
self.log.info("Searching for files in %s at %s", bundle.name, bundle.path)
file_paths = [Path(x).relative_to(bundle.path) for x in list_py_file_paths(bundle.path)]
file_paths = list_py_file_paths(bundle.path)
self.log.info("Found %s files for bundle %s", len(file_paths), bundle.name)

return file_paths

def deactivate_deleted_dags(self, active_files: list[DagFileInfo]) -> None:
def deactivate_deleted_dags(self, file_paths: set[str]) -> None:
"""Deactivate DAGs that come from files that are no longer present."""

def _iter_dag_filelocs(fileloc: str) -> Iterator[str]:
@@ -531,20 +526,10 @@ def _iter_dag_filelocs(fileloc: str) -> Iterator[str]:
except zipfile.BadZipFile:
self.log.exception("There was an error accessing ZIP file %s %s", fileloc)

active_subpaths: set[tuple[str, str]] = set()
"""
'subpath' here means bundle + modified rel path. What does modified rel path mean?
Well, '_iter_dag_filelocs' walks through zip files and may return a "path" that is the
rel path to the zip plus the rel path within the zip. So, since this is a bit different
from most uses of the word "rel path", I wanted to call it something different.
A set is used presumably since many dags can be in one file.
"""

for info in active_files:
for path in _iter_dag_filelocs(str(info.absolute_path)):
active_subpaths.add((info.bundle_name, path))
dag_filelocs = {full_loc for path in file_paths for full_loc in _iter_dag_filelocs(path)}

DagModel.deactivate_deleted_dags(active_subpaths)
# TODO: AIP-66: make bundle aware, as fileloc won't be unique long term.
DagModel.deactivate_deleted_dags(dag_filelocs)

def _print_stat(self):
"""Occasionally print out stats about how fast the files are getting processed."""
@@ -568,8 +553,7 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION):
if self._file_paths:
query = query.where(
tuple_(ParseImportError.filename, ParseImportError.bundle_name).notin_(
# todo AIP-66: ParseImportError should have rel fileloc + bundle name
[(str(f.absolute_path), f.bundle_name) for f in self._file_paths]
[(f.path, f.bundle_name) for f in self._file_paths]
),
)
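The `notin_` filter above builds a composite NOT IN over `(filename, bundle_name)` pairs with SQLAlchemy's `tuple_`. A small Core-only sketch of the same construct against a stand-in table (the table and rows are invented, and the statement is only rendered, not executed, since composite IN support varies by backend):

from sqlalchemy import Column, MetaData, String, Table, select, tuple_

metadata = MetaData()
# Stand-in for the import_error table; only the two filtered columns are modelled.
import_error = Table(
    "import_error_demo",
    metadata,
    Column("filename", String(1024)),
    Column("bundle_name", String(250)),
)

known_files = [("/dags/a.py", "dags-folder"), ("/dags/b.py", "dags-folder")]

# Rows whose (filename, bundle_name) pair is no longer among the known files
# are the "nonexistent" import errors that get deleted.
stmt = select(import_error).where(
    tuple_(import_error.c.filename, import_error.c.bundle_name).notin_(known_files)
)
print(stmt)  # rendered SQL containing a (filename, bundle_name) NOT IN (...) clause
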

@@ -614,7 +598,7 @@ def _log_file_processing_stats(self, known_file_paths):
proc = self._processors.get(file_path)
num_dags = stat.num_dags
num_errors = stat.import_errors
file_name = Path(file_path.rel_path).stem
file_name = Path(file_path.path).stem
processor_pid = proc.pid if proc else None
processor_start_time = proc.start_time if proc else None
runtime = (now - processor_start_time) if processor_start_time else None
@@ -766,7 +750,7 @@ def _render_log_filename(self, dag_file: DagFileInfo) -> str:
self._latest_log_symlink_date = datetime.today()

bundle = next(b for b in self._dag_bundles if b.name == dag_file.bundle_name)
relative_path = Path(dag_file.rel_path)
relative_path = Path(dag_file.path).relative_to(bundle.path)
return os.path.join(self._get_log_dir(), bundle.name, f"{relative_path}.log")

def _get_logger_for_dag_file(self, dag_file: DagFileInfo):
@@ -784,7 +768,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:

return DagFileProcessorProcess.start(
id=id,
path=dag_file.absolute_path,
path=dag_file.path,
bundle_path=cast(Path, dag_file.bundle_path),
callbacks=callback_to_execute_for_file,
selector=self.selector,
@@ -836,7 +820,7 @@ def prepare_file_path_queue(self):
for file_path in self._file_paths:
if is_mtime_mode:
try:
files_with_mtime[file_path] = os.path.getmtime(file_path.absolute_path)
files_with_mtime[file_path] = os.path.getmtime(file_path.path)
except FileNotFoundError:
self.log.warning("Skipping processing of missing file: %s", file_path)
self._file_stats.pop(file_path, None)
@@ -862,7 +846,7 @@ if is_mtime_mode:
if is_mtime_mode:
file_paths = sorted(files_with_mtime, key=files_with_mtime.get, reverse=True)
elif list_mode == "alphabetical":
file_paths.sort(key=lambda f: f.rel_path)
file_paths.sort(key=lambda f: f.path)
elif list_mode == "random_seeded_by_host":
# Shuffle the list seeded by hostname so multiple DAG processors can work on different
# set of files. Since we set the seed, the sort order will remain same per host
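The tail of `prepare_file_path_queue` keeps the "random_seeded_by_host" ordering described in the comment: the shuffle is seeded with the hostname, so each host visits files in a different but locally stable order. A small sketch of that idea (file names are placeholders):

import random
import socket

file_paths = ["dags/a.py", "dags/b.py", "dags/c.py", "dags/d.py"]

# Seeding with the hostname keeps the order stable on one machine while letting
# different hosts start from different parts of the list.
random.Random(socket.gethostname()).shuffle(file_paths)
print(file_paths)
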
14 changes: 7 additions & 7 deletions airflow/models/dag.py
@@ -24,7 +24,7 @@
import sys
import time
from collections import defaultdict
from collections.abc import Collection, Generator, Iterable, Sequence
from collections.abc import Collection, Container, Generator, Iterable, Sequence
from contextlib import ExitStack
from datetime import datetime, timedelta
from functools import cache
@@ -2241,25 +2241,25 @@ def dag_display_name(self) -> str:
@provide_session
def deactivate_deleted_dags(
cls,
active: set[tuple[str, str]],
alive_dag_filelocs: Container[str],
session: Session = NEW_SESSION,
) -> None:
"""
Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
:param active_paths: file paths of alive DAGs
:param alive_dag_filelocs: file paths of alive DAGs
:param session: ORM Session
"""
log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)
dag_models = session.scalars(
select(cls).where(
cls.relative_fileloc.is_not(None),
cls.fileloc.is_not(None),
)
)

for dm in dag_models:
if (dm.bundle_name, dm.relative_fileloc) not in active:
dm.is_active = False
for dag_model in dag_models:
if dag_model.fileloc not in alive_dag_filelocs:
dag_model.is_active = False

@classmethod
def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
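One side of this hunk types the parameter as `Container[str]` because the method only needs membership tests on the alive file locations. A reduced sketch of that contract with plain objects instead of ORM rows (`DagStub` is invented for the example):

from collections.abc import Container
from dataclasses import dataclass

@dataclass
class DagStub:
    dag_id: str
    fileloc: str
    is_active: bool = True

def deactivate_deleted_dags(dags: list[DagStub], alive_dag_filelocs: Container[str]) -> None:
    # Only `in` is required, so any Container (set, frozenset, list, ...) works;
    # a set keeps the membership check O(1) for large DAG folders.
    for dag in dags:
        if dag.fileloc not in alive_dag_filelocs:
            dag.is_active = False

dags = [DagStub("keep_me", "/dags/a.py"), DagStub("drop_me", "/dags/deleted.py")]
deactivate_deleted_dags(dags, alive_dag_filelocs={"/dags/a.py"})
print([(d.dag_id, d.is_active) for d in dags])  # drop_me becomes inactive
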
2 changes: 1 addition & 1 deletion airflow/models/dagbag.py
@@ -119,7 +119,7 @@ class DagBag(LoggingMixin):

def __init__(
self,
dag_folder: str | Path | None = None, # todo AIP-66: rename this to path
dag_folder: str | Path | None = None,
include_examples: bool | ArgNotSet = NOTSET,
safe_mode: bool | ArgNotSet = NOTSET,
read_dags_from_db: bool = False,
2 changes: 1 addition & 1 deletion airflow/models/errors.py
@@ -29,6 +29,6 @@ class ParseImportError(Base):
__tablename__ = "import_error"
id = Column(Integer, primary_key=True)
timestamp = Column(UtcDateTime)
filename = Column(String(1024)) # todo AIP-66: make this bundle and relative fileloc
filename = Column(String(1024))
bundle_name = Column(StringID())
stacktrace = Column(Text)
56 changes: 28 additions & 28 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -519,32 +519,31 @@ def run(ti: RuntimeTaskInstance, log: Logger):
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore
msg = SUPERVISOR_COMMS.get_message() # type: ignore

if isinstance(msg, OKResponse) and not msg.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
else:
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
# TODO: Get things from _execute_task_with_callbacks
# - Pre Execute
# etc
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
ok_response = SUPERVISOR_COMMS.get_message() # type: ignore
if not isinstance(ok_response, OKResponse) or not ok_response.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
return
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
# TODO: Get things from _execute_task_with_callbacks
# - Pre Execute
# etc
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
except TaskDeferred as defer:
# TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
@@ -610,8 +609,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
log.exception("Task failed with exception")
# TODO: Run task failure callbacks here
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
finally:
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _execute_task(context: Context, task: BaseOperator):
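One side of this hunk restructures `run()` around an early return when the runtime checks fail and a `finally` block that sends the terminal message, so exactly one message goes out whether the body returns early, succeeds, or raises. A stripped-down sketch of that control-flow shape (the message classes here are stand-ins, not the real Task SDK types):

from dataclasses import dataclass

@dataclass
class FailedMsg:
    reason: str

@dataclass
class SuccessMsg:
    result: object

def run(checks_ok: bool, task) -> None:
    msg = None
    try:
        if not checks_ok:
            # Early return: the failure message is still delivered by `finally`.
            msg = FailedMsg(reason="runtime checks failed")
            return
        msg = SuccessMsg(result=task())
    except Exception:
        msg = FailedMsg(reason="task raised")
    finally:
        if msg:
            print(f"sending terminal state: {msg}")  # stand-in for SUPERVISOR_COMMS.send_request

run(checks_ok=True, task=lambda: 42)
run(checks_ok=False, task=lambda: 42)
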
49 changes: 47 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
@@ -647,14 +647,58 @@ def test_run_with_asset_outlets(
ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task")
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
mock_supervisor_comms.get_message.return_value = OKResponse(
ok=True,
)

run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)


def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
@pytest.mark.parametrize(
["ok", "last_expected_msg"],
[
pytest.param(
True,
SucceedTask(
end_date=timezone.datetime(2024, 12, 3, 10, 0),
task_outlets=[
AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"),
AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"),
],
outlet_events=[
{
"asset_alias_events": [],
"extra": {},
"key": {"name": "name", "uri": "s3://bucket/my-task"},
},
{
"asset_alias_events": [],
"extra": {},
"key": {"name": "new-name", "uri": "s3://bucket/my-task"},
},
],
),
id="runtime_checks_pass",
),
pytest.param(
False,
TaskState(
state=TerminalTIState.FAILED,
end_date=timezone.datetime(2024, 12, 3, 10, 0),
),
id="runtime_checks_fail",
),
],
)
def test_run_with_inlets_and_outlets(
create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg
):
"""Test running a basic tasks with inlets and outlets."""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
@@ -672,7 +716,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):

ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
mock_supervisor_comms.get_message.return_value = OKResponse(
ok=True,
ok=ok,
)

run(ti, log=mock.MagicMock())
@@ -688,6 +732,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
],
)
mock_supervisor_comms.send_request.assert_any_call(msg=expected, log=mock.ANY)
mock_supervisor_comms.send_request.assert_any_call(msg=last_expected_msg, log=mock.ANY)


class TestRuntimeTaskInstance: