Merge branch 'main' into release-0.22.0
tatiana authored Jan 10, 2025
2 parents bfb7697 + 4f2a57f commit 82ed897
Showing 7 changed files with 394 additions and 53 deletions.
184 changes: 141 additions & 43 deletions dagfactory/dagbuilder.py
@@ -2,14 +2,16 @@

from __future__ import annotations

import ast

# pylint: disable=ungrouped-imports
import inspect
import os
import re
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
- from typing import Any, Callable, Dict, List, Union
+ from typing import Any, Callable, Dict, List, Tuple, Union

from airflow import DAG, configuration
from airflow.models import BaseOperator, Variable
@@ -83,7 +85,7 @@
from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod

- from dagfactory import utils
+ from dagfactory import parsers, utils
from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException

# TimeTable is introduced in Airflow 2.2.0
@@ -293,8 +295,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if not task_params.get("python_callable"):
task_params["python_callable"]: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
@@ -312,8 +313,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Success checks
if task_params.get("success_check_file") and task_params.get("success_check_name"):
task_params["success"]: Callable = utils.get_python_callable(
task_params["success_check_name"],
task_params["success_check_file"],
task_params["success_check_name"], task_params["success_check_file"]
)
del task_params["success_check_name"]
del task_params["success_check_file"]
@@ -325,8 +325,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Failure checks
if task_params.get("failure_check_file") and task_params.get("failure_check_name"):
task_params["failure"]: Callable = utils.get_python_callable(
task_params["failure_check_name"],
task_params["failure_check_file"],
task_params["failure_check_name"], task_params["failure_check_file"]
)
del task_params["failure_check_name"]
del task_params["failure_check_file"]
@@ -347,8 +346,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if task_params.get("response_check_file"):
task_params["response_check"]: Callable = utils.get_python_callable(
task_params["response_check_name"],
task_params["response_check_file"],
task_params["response_check_name"], task_params["response_check_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
@@ -438,11 +436,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
# Getting expand and partial kwargs from task_params
- (
- task_params,
- expand_kwargs,
- partial_kwargs,
- ) = utils.get_expand_partial_kwargs(task_params)
+ (task_params, expand_kwargs, partial_kwargs) = utils.get_expand_partial_kwargs(task_params)

# If there are partial_kwargs we should merge them with existing task_params
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
@@ -626,6 +620,132 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]):
task_conf["expand"][expand_key] = tasks_dict[task_id].output
return task_conf

@staticmethod
def safe_eval(condition_string: str, dataset_map: dict) -> Any:
"""
Safely evaluates a condition string using the provided dataset map.
:param condition_string: A string representing the condition to evaluate.
Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3".
:type condition_string: str
:param dataset_map: A dictionary where keys are valid variable names (dataset aliases),
and values are Dataset objects.
:type dataset_map: dict
:returns: The result of evaluating the condition.
:rtype: Any
"""
tree = ast.parse(condition_string, mode="eval")
evaluator = parsers.SafeEvalVisitor(dataset_map)
return evaluator.evaluate(tree)

@staticmethod
def _extract_and_transform_datasets(datasets_conditions: str) -> Tuple[str, Dict[str, Any]]:
"""
Extracts dataset names and storage paths from the conditions string and transforms them into valid variable names.
:param datasets_conditions: A condition string containing dataset URIs to be evaluated.
:type datasets_conditions: str
:returns: A tuple containing the transformed conditions string and the dataset map.
:rtype: Tuple[str, Dict[str, Any]]
"""
dataset_map = {}
datasets_filter: List[str] = utils.extract_dataset_names(datasets_conditions) + utils.extract_storage_names(
datasets_conditions
)

for uri in datasets_filter:
valid_variable_name = utils.make_valid_variable_name(uri)
datasets_conditions = datasets_conditions.replace(uri, valid_variable_name)
dataset_map[valid_variable_name] = Dataset(uri)

return datasets_conditions, dataset_map

@staticmethod
def evaluate_condition_with_datasets(datasets_conditions: str) -> Any:
"""
Evaluates a condition using the dataset filter, transforming URIs into valid variable names.
:param datasets_conditions: A condition string containing dataset URIs to be evaluated.
:type datasets_conditions: str
:returns: The result of the logical condition evaluation with URIs replaced by valid variable names.
:rtype: Any
"""
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition

@staticmethod
def process_file_with_datasets(file: str, datasets_conditions: str) -> Any:
"""
Processes datasets from a file and evaluates conditions if provided.
:param file: The file path containing dataset information in a YAML or other structured format.
:type file: str
:param datasets_conditions: A string of dataset conditions to filter and process.
:type datasets_conditions: str
:returns: The result of the condition evaluation on Airflow >= 2.9, otherwise a list of `Dataset` objects parsed from the file.
:rtype: Any
"""
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)

if is_airflow_version_at_least_2_9:
map_datasets = utils.get_datasets_map_uri_yaml_file(file, list(dataset_map.keys()))
dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()}
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition
else:
datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys()))
return [Dataset(uri) for uri in datasets_uri]

@staticmethod
def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None:
"""
Configures the schedule for the DAG based on parameters and the Airflow version.
:param dag_params: A dictionary containing DAG parameters, including scheduling configuration.
Example: {"schedule": {"file": "datasets.yaml", "datasets": ["dataset_1"], "conditions": "dataset_1 & dataset_2"}}
:type dag_params: Dict[str, Any]
:param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG.
:type dag_kwargs: Dict[str, Any]
:raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters.
:returns: None. The function updates `dag_kwargs` in-place.
"""
is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
has_schedule_attr = utils.check_dict_key(dag_params, "schedule")
has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval")

if has_schedule_attr and not has_schedule_interval_attr and is_airflow_version_at_least_2_4:
schedule: Dict[str, Any] = dag_params.get("schedule")

has_file_attr = utils.check_dict_key(schedule, "file")
has_datasets_attr = utils.check_dict_key(schedule, "datasets")

if has_file_attr and has_datasets_attr:
file = schedule.get("file")
datasets: Union[List[str], str] = schedule.get("datasets")
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_conditions)

elif has_datasets_attr and is_airflow_version_at_least_2_9:
datasets = schedule["datasets"]
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_conditions)

else:
dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule]

if has_file_attr:
schedule.pop("file")
if has_datasets_attr:
schedule.pop("datasets")

# pylint: disable=too-many-locals
def build(self) -> Dict[str, Union[str, DAG]]:
"""
@@ -649,8 +769,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
dag_kwargs["max_active_tasks"] = dag_params.get(
"max_active_tasks",
configuration.conf.getint("core", "max_active_tasks_per_dag"),
"max_active_tasks", configuration.conf.getint("core", "max_active_tasks_per_dag")
)

if dag_params.get("timetable"):
@@ -668,8 +787,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
)

dag_kwargs["max_active_runs"] = dag_params.get(
"max_active_runs",
configuration.conf.getint("core", "max_active_runs_per_dag"),
"max_active_runs", configuration.conf.getint("core", "max_active_runs_per_dag")
)

dag_kwargs["dagrun_timeout"] = dag_params.get("dagrun_timeout", None)
@@ -702,24 +820,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None)

- if (
- utils.check_dict_key(dag_params, "schedule")
- and not utils.check_dict_key(dag_params, "schedule_interval")
- and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
- ):
- if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key(
- dag_params["schedule"], "datasets"
- ):
- file = dag_params["schedule"]["file"]
- datasets_filter = dag_params["schedule"]["datasets"]
- datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)
-
- del dag_params["schedule"]["file"]
- del dag_params["schedule"]["datasets"]
- else:
- datasets_uri = dag_params["schedule"]
-
- dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri]
+ DagBuilder.configure_schedule(dag_params, dag_kwargs)

dag_kwargs["params"] = dag_params.get("params", None)

@@ -734,8 +835,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"):
doc_md_callable = utils.get_python_callable(
dag_params.get("doc_md_python_callable_name"),
dag_params.get("doc_md_python_callable_file"),
dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file")
)
dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {}))

@@ -872,8 +972,7 @@ def adjust_general_task_params(task_params: dict(str, Any)):
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
task_params["execution_date_fn_name"], task_params["execution_date_fn_file"]
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]
Expand Down Expand Up @@ -937,8 +1036,7 @@ def make_decorator(
# Fetch the Python callable
if set(mandatory_keys_set1).issubset(task_params):
python_callable: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator
del task_params["python_callable_name"]
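The new DagBuilder helpers turn a boolean condition string over dataset aliases into an Airflow dataset condition that can be used as a DAG schedule. A minimal sketch of that evaluation path, assuming Airflow >= 2.9 (where Dataset objects compose with & and |); the aliases and URIs below are illustrative, not taken from the repository:

import ast  # noqa: F401  (safe_eval parses the string with ast internally)

from airflow.datasets import Dataset

from dagfactory.dagbuilder import DagBuilder

# Illustrative aliases mapped to Dataset objects, mirroring the safe_eval docstring example.
dataset_map = {
    "dataset_custom_1": Dataset("s3://bucket/dataset_custom_1.json"),
    "dataset_custom_2": Dataset("s3://bucket/dataset_custom_2.json"),
    "dataset_custom_3": Dataset("s3://bucket/dataset_custom_3.json"),
}

# The condition string is parsed with ast and reduced by SafeEvalVisitor into a
# dataset condition object that configure_schedule assigns to dag_kwargs["schedule"].
schedule = DagBuilder.safe_eval("(dataset_custom_1 & dataset_custom_2) | dataset_custom_3", dataset_map)
print(schedule)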
16 changes: 11 additions & 5 deletions dagfactory/dagfactory.py
@@ -92,15 +92,21 @@ def __join(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return "".join([str(i) for i in seq])

def __or(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " | ".join([f"({str(i)})" for i in seq])

def __and(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " & ".join([f"({str(i)})" for i in seq])

yaml.add_constructor("!join", __join, yaml.FullLoader)
yaml.add_constructor("!or", __or, yaml.FullLoader)
yaml.add_constructor("!and", __and, yaml.FullLoader)

with open(config_filepath, "r", encoding="utf-8") as fp:
yaml.add_constructor("!join", __join, yaml.FullLoader)
config_with_env = os.path.expandvars(fp.read())
- config: Dict[str, Any] = yaml.load(
- stream=config_with_env,
- Loader=yaml.FullLoader,
- )
+ config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader)
except Exception as err:
raise DagFactoryConfigException("Invalid DAG Factory config file") from err
return config
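The new !or and !and constructors flatten nested YAML sequences into a parenthesized condition string, which the DagBuilder helpers above then evaluate. A small standalone sketch of the same constructors; the YAML snippet and URIs are illustrative assumptions:

import yaml

# Re-register constructors equivalent to the ones added in dagfactory.py.
def _join(loader, node):
    return "".join(str(i) for i in loader.construct_sequence(node))

def _or(loader, node):
    return " | ".join(f"({i})" for i in loader.construct_sequence(node))

def _and(loader, node):
    return " & ".join(f"({i})" for i in loader.construct_sequence(node))

yaml.add_constructor("!join", _join, yaml.FullLoader)
yaml.add_constructor("!or", _or, yaml.FullLoader)
yaml.add_constructor("!and", _and, yaml.FullLoader)

snippet = """
datasets: !or
  - !and
    - s3://bucket/dataset_custom_1.json
    - s3://bucket/dataset_custom_2.json
  - s3://bucket/dataset_custom_3.json
"""

# Prints a single condition string:
# ((s3://bucket/dataset_custom_1.json) & (s3://bucket/dataset_custom_2.json)) | (s3://bucket/dataset_custom_3.json)
print(yaml.load(snippet, Loader=yaml.FullLoader)["datasets"])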
34 changes: 34 additions & 0 deletions dagfactory/parsers.py
@@ -0,0 +1,34 @@
import ast


class SafeEvalVisitor(ast.NodeVisitor):
def __init__(self, dataset_map):
self.dataset_map = dataset_map

def evaluate(self, tree):
return self.visit(tree)

def visit_Expression(self, node):
return self.visit(node.body)

def visit_BinOp(self, node):
left = self.visit(node.left)
right = self.visit(node.right)

if isinstance(node.op, ast.BitAnd):
return left & right
elif isinstance(node.op, ast.BitOr):
return left | right
else:
raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}")

def visit_Name(self, node):
if node.id in self.dataset_map:
return self.dataset_map[node.id]
raise NameError(f"Undefined variable: {node.id}")

def visit_Constant(self, node):
return node.value

def generic_visit(self, node):
raise ValueError(f"Unsupported syntax: {type(node).__name__}")