diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index c8bac5cef964a..879d234964473 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -210,7 +210,6 @@ def _serialize_dag_capturing_errors(
     except Exception:
         log.exception("Failed to write serialized DAG dag_id=%s fileloc=%s", dag.dag_id, dag.fileloc)
         dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth")
-        # todo AIP-66: this needs to use bundle name / rel fileloc instead
         return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]
 
 
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index fa68cb4fab1ba..a966f0a29c85f 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -102,16 +102,10 @@ class DagFileStat:
 class DagFileInfo:
     """Information about a DAG file."""
 
-    rel_path: Path
+    path: str  # absolute path of the file
     bundle_name: str
     bundle_path: Path | None = field(compare=False, default=None)
 
-    @property
-    def absolute_path(self) -> Path:
-        if not self.bundle_path:
-            raise ValueError("bundle_path not set")
-        return self.bundle_path / self.rel_path
-
 
 def _config_int_factory(section: str, key: str):
     return functools.partial(conf.getint, section, key)
@@ -247,9 +241,7 @@ def _scan_stale_dags(self):
         elapsed_time_since_refresh = now - self._last_deactivate_stale_dags_time
         if elapsed_time_since_refresh > self.parsing_cleanup_interval:
             last_parsed = {
-                file_info: stat.last_finish_time
-                for file_info, stat in self._file_stats.items()
-                if stat.last_finish_time
+                fp: stat.last_finish_time for fp, stat in self._file_stats.items() if stat.last_finish_time
             }
             self.deactivate_stale_dags(last_parsed=last_parsed)
             self._last_deactivate_stale_dags_time = time.monotonic()
@@ -277,11 +269,14 @@ def deactivate_stale_dags(
             # last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is
             # no longer present in the file. We have a stale_dag_threshold configured to prevent a
             # significant delay in deactivation of stale dags when a large timeout is configured
-            file_info = DagFileInfo(rel_path=Path(dag.relative_fileloc), bundle_name=dag.bundle_name)
-            if last_finish_time := last_parsed.get(file_info, None):
-                if dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold) < last_finish_time:
-                    self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
-                    to_deactivate.add(dag.dag_id)
+            dag_file_path = DagFileInfo(path=dag.fileloc, bundle_name=dag.bundle_name)
+            if (
+                dag_file_path in last_parsed
+                and (dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold))
+                < last_parsed[dag_file_path]
+            ):
+                self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
+                to_deactivate.add(dag.dag_id)
 
         if to_deactivate:
             deactivated_dagmodel = session.execute(
@@ -488,30 +483,30 @@ def _refresh_dag_bundles(self):
                         "Version changed for %s, new version: %s", bundle.name, version_after_refresh
                     )
 
-            found_file_infos = [
-                DagFileInfo(rel_path=p, bundle_name=bundle.name, bundle_path=bundle.path)
-                for p in self._find_files_in_bundle(bundle)
-            ]
+            bundle_file_paths = self._find_files_in_bundle(bundle)
 
             new_file_paths = [f for f in self._file_paths if f.bundle_name != bundle.name]
-            new_file_paths.extend(found_file_infos)
+            new_file_paths.extend(
+                DagFileInfo(path=path, bundle_path=bundle.path, bundle_name=bundle.name)
+                for path in bundle_file_paths
+            )
             self.set_file_paths(new_file_paths)
 
-            self.deactivate_deleted_dags(active_files=found_file_infos)
+            self.deactivate_deleted_dags(bundle_file_paths)
             self.clear_nonexistent_import_errors()
 
             self._bundle_versions[bundle.name] = bundle.get_current_version()
 
-    def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[Path]:
-        """Get relative file paths from bundle dir."""
+    def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[str]:
+        """Refresh file paths from bundle dir."""
         # Build up a list of Python files that could contain DAGs
         self.log.info("Searching for files in %s at %s", bundle.name, bundle.path)
-        file_paths = [Path(x).relative_to(bundle.path) for x in list_py_file_paths(bundle.path)]
+        file_paths = list_py_file_paths(bundle.path)
         self.log.info("Found %s files for bundle %s", len(file_paths), bundle.name)
 
         return file_paths
 
-    def deactivate_deleted_dags(self, active_files: list[DagFileInfo]) -> None:
+    def deactivate_deleted_dags(self, file_paths: set[str]) -> None:
         """Deactivate DAGs that come from files that are no longer present."""
 
         def _iter_dag_filelocs(fileloc: str) -> Iterator[str]:
@@ -531,20 +526,10 @@ def _iter_dag_filelocs(fileloc: str) -> Iterator[str]:
             except zipfile.BadZipFile:
                 self.log.exception("There was an error accessing ZIP file %s %s", fileloc)
 
-        active_subpaths: set[tuple[str, str]] = set()
-        """
-        'subpath' here means bundle + modified rel path. What does modified rel path mean?
-        Well, '_iter_dag_filelocs' walks through zip files and may return a "path" that is,
-        rel path to the zip, plus the rel path within the zip. So, since this is is a bit different
-        from most uses of the word "rel path", I wanted to call it something different.
-        A set is used presumably since many dags can be in one file.
- """ - - for info in active_files: - for path in _iter_dag_filelocs(str(info.absolute_path)): - active_subpaths.add((info.bundle_name, path)) + dag_filelocs = {full_loc for path in file_paths for full_loc in _iter_dag_filelocs(path)} - DagModel.deactivate_deleted_dags(active_subpaths) + # TODO: AIP-66: make bundle aware, as fileloc won't be unique long term. + DagModel.deactivate_deleted_dags(dag_filelocs) def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed.""" @@ -568,8 +553,7 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION): if self._file_paths: query = query.where( tuple_(ParseImportError.filename, ParseImportError.bundle_name).notin_( - # todo AIP-66: ParseImportError should have rel fileloce + bundle name - [(str(f.absolute_path), f.bundle_name) for f in self._file_paths] + [(f.path, f.bundle_name) for f in self._file_paths] ), ) @@ -614,7 +598,7 @@ def _log_file_processing_stats(self, known_file_paths): proc = self._processors.get(file_path) num_dags = stat.num_dags num_errors = stat.import_errors - file_name = Path(file_path.rel_path).stem + file_name = Path(file_path.path).stem processor_pid = proc.pid if proc else None processor_start_time = proc.start_time if proc else None runtime = (now - processor_start_time) if processor_start_time else None @@ -766,7 +750,7 @@ def _render_log_filename(self, dag_file: DagFileInfo) -> str: self._latest_log_symlink_date = datetime.today() bundle = next(b for b in self._dag_bundles if b.name == dag_file.bundle_name) - relative_path = Path(dag_file.rel_path) + relative_path = Path(dag_file.path).relative_to(bundle.path) return os.path.join(self._get_log_dir(), bundle.name, f"{relative_path}.log") def _get_logger_for_dag_file(self, dag_file: DagFileInfo): @@ -784,7 +768,7 @@ def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: return DagFileProcessorProcess.start( id=id, - path=dag_file.absolute_path, + path=dag_file.path, bundle_path=cast(Path, dag_file.bundle_path), callbacks=callback_to_execute_for_file, selector=self.selector, @@ -836,7 +820,7 @@ def prepare_file_path_queue(self): for file_path in self._file_paths: if is_mtime_mode: try: - files_with_mtime[file_path] = os.path.getmtime(file_path.absolute_path) + files_with_mtime[file_path] = os.path.getmtime(file_path.path) except FileNotFoundError: self.log.warning("Skipping processing of missing file: %s", file_path) self._file_stats.pop(file_path, None) @@ -862,7 +846,7 @@ def prepare_file_path_queue(self): if is_mtime_mode: file_paths = sorted(files_with_mtime, key=files_with_mtime.get, reverse=True) elif list_mode == "alphabetical": - file_paths.sort(key=lambda f: f.rel_path) + file_paths.sort(key=lambda f: f.path) elif list_mode == "random_seeded_by_host": # Shuffle the list seeded by hostname so multiple DAG processors can work on different # set of files. 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a0f6e901dc3e9..a22452a748cec 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -24,7 +24,7 @@
 import sys
 import time
 from collections import defaultdict
-from collections.abc import Collection, Generator, Iterable, Sequence
+from collections.abc import Collection, Container, Generator, Iterable, Sequence
 from contextlib import ExitStack
 from datetime import datetime, timedelta
 from functools import cache
@@ -2241,25 +2241,25 @@ def dag_display_name(self) -> str:
     @provide_session
     def deactivate_deleted_dags(
         cls,
-        active: set[tuple[str, str]],
+        alive_dag_filelocs: Container[str],
         session: Session = NEW_SESSION,
     ) -> None:
         """
         Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
 
-        :param active_paths: file paths of alive DAGs
+        :param alive_dag_filelocs: file paths of alive DAGs
         :param session: ORM Session
         """
 
         log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)
         dag_models = session.scalars(
             select(cls).where(
-                cls.relative_fileloc.is_not(None),
+                cls.fileloc.is_not(None),
             )
         )
-        for dm in dag_models:
-            if (dm.bundle_name, dm.relative_fileloc) not in active:
-                dm.is_active = False
+        for dag_model in dag_models:
+            if dag_model.fileloc not in alive_dag_filelocs:
+                dag_model.is_active = False
 
     @classmethod
     def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index a18b133f8e67e..96eceb4c52568 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -119,7 +119,7 @@ class DagBag(LoggingMixin):
 
     def __init__(
         self,
-        dag_folder: str | Path | None = None,  # todo AIP-66: rename this to path
+        dag_folder: str | Path | None = None,
         include_examples: bool | ArgNotSet = NOTSET,
         safe_mode: bool | ArgNotSet = NOTSET,
         read_dags_from_db: bool = False,
diff --git a/airflow/models/errors.py b/airflow/models/errors.py
index 748d56c46b462..21c2236e2c18b 100644
--- a/airflow/models/errors.py
+++ b/airflow/models/errors.py
@@ -29,6 +29,6 @@ class ParseImportError(Base):
     __tablename__ = "import_error"
     id = Column(Integer, primary_key=True)
     timestamp = Column(UtcDateTime)
-    filename = Column(String(1024))  # todo AIP-66: make this bundle and relative fileloc
+    filename = Column(String(1024))
     bundle_name = Column(StringID())
     stacktrace = Column(Text)
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 8b21fcfaf486a..7bcecc6343343 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -519,32 +519,31 @@ def run(ti: RuntimeTaskInstance, log: Logger):
             inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
             outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
             SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log)  # type: ignore
-            msg = SUPERVISOR_COMMS.get_message()  # type: ignore
-
-        if isinstance(msg, OKResponse) and not msg.ok:
-            log.info("Runtime checks failed for task, marking task as failed..")
-            msg = TaskState(
-                state=TerminalTIState.FAILED,
-                end_date=datetime.now(tz=timezone.utc),
-            )
-        else:
-            context = ti.get_template_context()
-            with set_current_context(context):
-                jinja_env = ti.task.dag.get_template_env()
-                ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
-                # TODO: Get things from _execute_task_with_callbacks
-                # - Pre Execute
-                # etc
-                result = _execute_task(context, ti.task)
-
-            _push_xcom_if_needed(result, ti)
-
-            task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
-            msg = SucceedTask(
-                end_date=datetime.now(tz=timezone.utc),
-                task_outlets=task_outlets,
-                outlet_events=outlet_events,
-            )
+            ok_response = SUPERVISOR_COMMS.get_message()  # type: ignore
+            if not isinstance(ok_response, OKResponse) or not ok_response.ok:
+                log.info("Runtime checks failed for task, marking task as failed..")
+                msg = TaskState(
+                    state=TerminalTIState.FAILED,
+                    end_date=datetime.now(tz=timezone.utc),
+                )
+                return
+        context = ti.get_template_context()
+        with set_current_context(context):
+            jinja_env = ti.task.dag.get_template_env()
+            ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
+            # TODO: Get things from _execute_task_with_callbacks
+            # - Pre Execute
+            # etc
+            result = _execute_task(context, ti.task)
+
+        _push_xcom_if_needed(result, ti)
+
+        task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
+        msg = SucceedTask(
+            end_date=datetime.now(tz=timezone.utc),
+            task_outlets=task_outlets,
+            outlet_events=outlet_events,
+        )
     except TaskDeferred as defer:
         # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
         log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
@@ -610,8 +609,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         log.exception("Task failed with exception")
         # TODO: Run task failure callbacks here
         msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
-    if msg:
-        SUPERVISOR_COMMS.send_request(msg=msg, log=log)
+    finally:
+        if msg:
+            SUPERVISOR_COMMS.send_request(msg=msg, log=log)
 
 
 def _execute_task(context: Context, task: BaseOperator):
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index fef2f505b2a55..b4a83955670b3 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -647,14 +647,58 @@ def test_run_with_asset_outlets(
     ti = create_runtime_ti(task=task, dag_id="dag_with_asset_outlet_task")
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
+    mock_supervisor_comms.get_message.return_value = OKResponse(
+        ok=True,
+    )
 
     run(ti, log=mock.MagicMock())
 
     mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)
 
 
-def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
+@pytest.mark.parametrize(
+    ["ok", "last_expected_msg"],
+    [
+        pytest.param(
+            True,
+            SucceedTask(
+                end_date=timezone.datetime(2024, 12, 3, 10, 0),
+                task_outlets=[
+                    AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"),
+                    AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"),
+                ],
+                outlet_events=[
+                    {
+                        "asset_alias_events": [],
+                        "extra": {},
+                        "key": {"name": "name", "uri": "s3://bucket/my-task"},
+                    },
+                    {
+                        "asset_alias_events": [],
+                        "extra": {},
+                        "key": {"name": "new-name", "uri": "s3://bucket/my-task"},
+                    },
+                ],
+            ),
+            id="runtime_checks_pass",
+        ),
+        pytest.param(
+            False,
+            TaskState(
+                state=TerminalTIState.FAILED,
+                end_date=timezone.datetime(2024, 12, 3, 10, 0),
+            ),
+            id="runtime_checks_fail",
+        ),
+    ],
+)
+def test_run_with_inlets_and_outlets(
+    create_runtime_ti, mock_supervisor_comms, time_machine, ok, last_expected_msg
+):
     """Test running a basic tasks with inlets and outlets."""
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
     from airflow.providers.standard.operators.bash import BashOperator
 
     task = BashOperator(
@@ -672,7 +716,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
     ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
 
     mock_supervisor_comms.get_message.return_value = OKResponse(
-        ok=True,
+        ok=ok,
     )
 
     run(ti, log=mock.MagicMock())
@@ -688,6 +732,7 @@ def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
         ],
     )
     mock_supervisor_comms.send_request.assert_any_call(msg=expected, log=mock.ANY)
+    mock_supervisor_comms.send_request.assert_any_call(msg=last_expected_msg, log=mock.ANY)
 
 
 class TestRuntimeTaskInstance:
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index a68749f831ee0..a33884e6d4e07 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -75,8 +75,8 @@
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 
 
-def _get_dag_file_paths(files: list[str | Path]) -> list[DagFileInfo]:
-    return [DagFileInfo(bundle_name="testing", bundle_path=TEST_DAGS_FOLDER, rel_path=Path(f)) for f in files]
+def _get_dag_file_paths(files: list[str]) -> list[DagFileInfo]:
+    return [DagFileInfo(bundle_name="testing", bundle_path=TEST_DAGS_FOLDER, path=f) for f in files]
 
 
 class TestDagFileProcessorManager:
@@ -164,9 +164,9 @@ def test_start_new_processes_with_same_filepath(self):
         """
         manager = DagFileProcessorManager(max_runs=1)
 
-        file_1 = DagFileInfo(bundle_name="testing", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER)
-        file_2 = DagFileInfo(bundle_name="testing", rel_path=Path("file_2.py"), bundle_path=TEST_DAGS_FOLDER)
-        file_3 = DagFileInfo(bundle_name="testing", rel_path=Path("file_3.py"), bundle_path=TEST_DAGS_FOLDER)
+        file_1 = DagFileInfo(bundle_name="testing", path="file_1.py", bundle_path=TEST_DAGS_FOLDER)
+        file_2 = DagFileInfo(bundle_name="testing", path="file_2.py", bundle_path=TEST_DAGS_FOLDER)
+        file_3 = DagFileInfo(bundle_name="testing", path="file_3.py", bundle_path=TEST_DAGS_FOLDER)
         manager._file_path_queue = deque([file_1, file_2, file_3])
 
         # Mock that only one processor exists. This processor runs with 'file_1'
@@ -188,9 +188,7 @@ def test_start_new_processes_with_same_filepath(self):
     def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self):
         """Ensure processors and file stats are removed when the file path is not in the new file paths"""
         manager = DagFileProcessorManager(max_runs=1)
-        file = DagFileInfo(
-            bundle_name="testing", rel_path=Path("missing_file.txt"), bundle_path=TEST_DAGS_FOLDER
-        )
+        file = DagFileInfo(bundle_name="testing", path="missing_file.txt", bundle_path=TEST_DAGS_FOLDER)
 
         manager._processors[file] = MagicMock()
         manager._file_stats[file] = DagFileStat()
@@ -201,7 +199,7 @@ def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self):
         manager = DagFileProcessorManager(max_runs=1)
 
-        file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER)
+        file = DagFileInfo(bundle_name="testing", path="abc.txt", bundle_path=TEST_DAGS_FOLDER)
         mock_processor = MagicMock()
 
         manager._processors[file] = mock_processor
@@ -289,10 +287,7 @@ def test_add_new_file_to_parsing_queue(self, mock_getmtime):
         assert manager._file_path_queue == deque(ordered_files)
 
         manager.set_file_paths(
-            [
-                *dag_files,
-                DagFileInfo(bundle_name="testing", rel_path=Path("file_4.py"), bundle_path=TEST_DAGS_FOLDER),
-            ]
+            [*dag_files, DagFileInfo(bundle_name="testing", path="file_4.py", bundle_path=TEST_DAGS_FOLDER)]
         )
         manager.add_new_file_path_to_queue()
         ordered_files = _get_dag_file_paths(["file_4.py", "file_3.py", "file_2.py", "file_1.py"])
@@ -306,9 +301,7 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime):
         """
         freezed_base_time = timezone.datetime(2020, 1, 5, 0, 0, 0)
         initial_file_1_mtime = (freezed_base_time - timedelta(minutes=5)).timestamp()
-        dag_file = DagFileInfo(
-            bundle_name="testing", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER
-        )
+        dag_file = DagFileInfo(bundle_name="testing", path="file_1.py", bundle_path=TEST_DAGS_FOLDER)
         dag_files = [dag_file]
 
         mock_getmtime.side_effect = [initial_file_1_mtime]
@@ -378,15 +371,10 @@ def test_scan_stale_dags(self, testing_dag_bundle):
 
         test_dag_path = DagFileInfo(
             bundle_name="testing",
-            rel_path=Path("test_example_bash_operator.py"),
+            path=str(TEST_DAG_FOLDER / "test_example_bash_operator.py"),
             bundle_path=TEST_DAGS_FOLDER,
         )
-        dagbag = DagBag(
-            test_dag_path.absolute_path,
-            read_dags_from_db=False,
-            include_examples=False,
-            bundle_path=test_dag_path.bundle_path,
-        )
+        dagbag = DagBag(test_dag_path.path, read_dags_from_db=False, include_examples=False)
 
         with create_session() as session:
             # Add stale DAG to the DB
@@ -409,11 +397,7 @@ def test_scan_stale_dags(self, testing_dag_bundle):
 
         active_dag_count = (
             session.query(func.count(DagModel.dag_id))
-            .filter(
-                DagModel.is_active,
-                DagModel.relative_fileloc == str(test_dag_path.rel_path),
-                DagModel.bundle_name == test_dag_path.bundle_name,
-            )
+            .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path)
             .scalar()
         )
         assert active_dag_count == 1
@@ -422,11 +406,7 @@
 
         active_dag_count = (
             session.query(func.count(DagModel.dag_id))
-            .filter(
-                DagModel.is_active,
-                DagModel.relative_fileloc == str(test_dag_path.rel_path),
-                DagModel.bundle_name == test_dag_path.bundle_name,
-            )
+            .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path)
             .scalar()
         )
         assert active_dag_count == 0
@@ -446,9 +426,7 @@ def test_kill_timed_out_processors_kill(self):
         processor = self.mock_processor()
         processor._process.create_time.return_value = timezone.make_aware(datetime.min).timestamp()
         manager._processors = {
-            DagFileInfo(
-                bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER
-            ): processor
+            DagFileInfo(bundle_name="testing", path="abc.txt", bundle_path=TEST_DAGS_FOLDER): processor
         }
         with mock.patch.object(type(processor), "kill") as mock_kill:
             manager._kill_timed_out_processors()
@@ -464,9 +442,7 @@ def test_kill_timed_out_processors_no_kill(self):
         processor = self.mock_processor()
         processor._process.create_time.return_value = timezone.make_aware(datetime.max).timestamp()
         manager._processors = {
-            DagFileInfo(
-                bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER
-            ): processor
+            DagFileInfo(bundle_name="testing", path="abc.txt", bundle_path=TEST_DAGS_FOLDER): processor
         }
         with mock.patch.object(type(processor), "kill") as mock_kill:
             manager._kill_timed_out_processors()
@@ -684,11 +660,6 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_te
 
     @pytest.mark.skip("AIP-66: callbacks are not implemented yet")
     def test_callback_queue(self, tmp_path):
-        """
-        This test has gotten a bit out of sync with the codebase.
-
-        I am just updating it to be consistent with the changes in DagFileInfo
-        """
         # given
         manager = DagFileProcessorManager(
             max_runs=1,
@@ -696,17 +667,17 @@ def test_callback_queue(self, tmp_path):
         )
 
         dag1_path = DagFileInfo(
-            bundle_name="testing", rel_path=Path("green_eggs/ham/file1.py"), bundle_path=TEST_DAGS_FOLDER
+            bundle_name="testing", path="/green_eggs/ham/file1.py", bundle_path=TEST_DAGS_FOLDER
         )
         dag1_req1 = DagCallbackRequest(
-            full_filepath=TEST_DAGS_FOLDER / "green_eggs/ham/file1.py",
+            full_filepath="/green_eggs/ham/file1.py",
             dag_id="dag1",
             run_id="run1",
             is_failure_callback=False,
             msg=None,
         )
         dag1_req2 = DagCallbackRequest(
-            full_filepath=TEST_DAGS_FOLDER / "green_eggs/ham/file1.py",
+            full_filepath="/green_eggs/ham/file1.py",
             dag_id="dag1",
             run_id="run1",
             is_failure_callback=False,
@@ -714,10 +685,10 @@ def test_callback_queue(self, tmp_path):
         )
 
         dag2_path = DagFileInfo(
-            bundle_name="testing", rel_path=Path("green_eggs/ham/file2.py"), bundle_path=TEST_DAGS_FOLDER
+            bundle_name="testing", path="/green_eggs/ham/file2.py", bundle_path=TEST_DAGS_FOLDER
        )
         dag2_req1 = DagCallbackRequest(
-            full_filepath=TEST_DAGS_FOLDER / "green_eggs/ham/file2.py",
+            full_filepath="/green_eggs/ham/file2.py",
             dag_id="dag2",
             run_id="run1",
             is_failure_callback=False,
@@ -800,12 +771,10 @@ def test_bundles_are_refreshed(self):
         bundleone = MagicMock()
         bundleone.name = "bundleone"
-        bundleone.path = "/dev/null"
         bundleone.refresh_interval = 0
         bundleone.get_current_version.return_value = None
 
         bundletwo = MagicMock()
         bundletwo.name = "bundletwo"
-        bundletwo.path = "/dev/null"
         bundletwo.refresh_interval = 300
         bundletwo.get_current_version.return_value = None
 
@@ -847,7 +816,7 @@ def _update_bundletwo_version():
             manager.run()
             assert bundletwo.refresh.call_count == 2
 
-    def test_bundles_versions_are_stored(self, session):
+    def test_bundles_versions_are_stored(self):
         config = [
             {
                 "name": "mybundle",
@@ -857,8 +826,7 @@ def test_bundles_versions_are_stored(self):
         ]
 
         mybundle = MagicMock()
-        mybundle.name = "mybundle"
-        mybundle.path = "/dev/null"
+        mybundle.name = "bundleone"
         mybundle.refresh_interval = 0
         mybundle.supports_versioning = True
         mybundle.get_current_version.return_value = "123"
@@ -868,11 +836,11 @@ def test_bundles_versions_are_stored(self, session):
         with mock.patch(
             "airflow.dag_processing.bundles.manager.DagBundlesManager"
         ) as mock_bundle_manager:
-            mock_bundle_manager.return_value._bundle_config = {"mybundle": None}
+            mock_bundle_manager.return_value._bundle_config = {"bundleone": None}
             mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [mybundle]
             manager = DagFileProcessorManager(max_runs=1)
             manager.run()
 
         with create_session() as session:
-            model = session.get(DagBundleModel, "mybundle")
+            model = session.get(DagBundleModel, "bundleone")
             assert model.version == "123"
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 78fffd45fd098..8766a89a8102d 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1081,16 +1081,18 @@ def add_failed_dag_run(dag, id, logical_date):
         dag.clear()
         self._clean_up(dag_id)
 
-    def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
+    def test_dag_is_deactivated_upon_dagfile_deletion(self):
         dag_id = "old_existing_dag"
-        with dag_maker(dag_id, schedule=None, is_paused_upon_creation=True) as dag:
-            ...
+        dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py"
+        dag = DAG(dag_id, schedule=None, is_paused_upon_creation=True)
+        dag.fileloc = dag_fileloc
 
         session = settings.Session()
         dag.sync_to_db(session=session)
 
         orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
 
         assert orm_dag.is_active
+        assert orm_dag.fileloc == dag_fileloc
 
         DagModel.deactivate_deleted_dags(list_py_file_paths(settings.DAGS_FOLDER))