From 8263c947b25dc818b9e36404bd7385fb9ceffd03 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 29 Jan 2025 22:15:53 +0000 Subject: [PATCH] Swap CeleryExecutor over to use TaskSDK for execution. Some points of note about this PR: - Logging is changed in Celery, but only for Airflow 3 Celery does it's own "capture stdout" logging, which conflicts with the ones we do in the TaskSDK, so we disable that; but to not change anything for Airflow 3. - Simplify task SDK logging redirection As part of this discovery that Celery captures stdout/stderr itself (and before disabling that) I discovered a simpler way to re-open the stdin/out/err so that the implementation needs fewer/no special casing. - Make JSON task logs more readable by giving them a consistent/useful order We re-order (by re-creating) the event_dict so that timestamp, level, and then even are always the first items in the dict - Makes the CeleryExecutor understand the concept of "workloads" instead a command tuple. This change isn't done in the best way, but until Kube executor is swapped over (and possibly the other in-tree executors, such as ECS) we need to support both styles concurrently. The change should be done in such a way that the provider still works with Airflow v2, if it's running on that version. - Upgrade Celery This turned out to not be 100% necessary but it does fix some deprecation warnings when running on Python 3.12 - Ensure that the forked process in TaskSDK _never ever_ exits Again, this isn't possible usually, but since the setup step of `_fork_main` died, it didn't call `os._exit()`, and was caught further up, which meant the process stayed alive as it never closed the sockets properly. We put and extra safety try/except block in place to catch that I have not yet included a newsfragment for changing the executor interface as the old style is _currently_ still supported. --- airflow/executors/base_executor.py | 41 ++++++++--- airflow/jobs/scheduler_job_runner.py | 4 +- generated/provider_dependencies.json | 2 +- providers/celery/pyproject.toml | 4 +- .../providers/celery/cli/celery_command.py | 5 ++ .../celery/executors/celery_executor.py | 37 +++++++++- .../celery/executors/celery_executor_utils.py | 70 ++++++++++++++----- .../celery/executors/default_celery.py | 15 ++-- .../airflow/sdk/execution_time/supervisor.py | 49 +++++++------ task_sdk/src/airflow/sdk/log.py | 53 ++++++++++---- 10 files changed, 201 insertions(+), 79 deletions(-) diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index a0f48c74b1356..a9cff9a88a59c 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -223,7 +223,12 @@ def has_task(self, task_instance: TaskInstance) -> bool: :param task_instance: TaskInstance :return: True if the task is known to this executor """ - return task_instance.key in self.queued_tasks or task_instance.key in self.running + return ( + task_instance.id in self.queued_tasks + or task_instance.id in self.running + or task_instance.key in self.queued_tasks + or task_instance.key in self.running + ) def sync(self) -> None: """ @@ -319,6 +324,20 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTa :return: List of tuples from the queued_tasks according to the priority. """ + from airflow.executors import workloads + + if not self.queued_tasks: + return [] + + kind = next(iter(self.queued_tasks.values())) + if isinstance(kind, workloads.BaseActivity): + # V3 + new executor that supports workloads + return sorted( + self.queued_tasks.items(), + key=lambda x: x[1].ti.priority_weight, + reverse=True, + ) + return sorted( self.queued_tasks.items(), key=lambda x: x[1][1], @@ -332,12 +351,12 @@ def trigger_tasks(self, open_slots: int) -> None: :param open_slots: Number of open slots """ - span = Trace.get_current_span() sorted_queue = self.order_queued_tasks_by_priority() task_tuples = [] + workloads = [] for _ in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, queue, ti) = sorted_queue.pop(0) + key, item = sorted_queue.pop(0) # If a task makes it here but is still understood by the executor # to be running, it generally means that the task has been killed @@ -375,15 +394,19 @@ def trigger_tasks(self, open_slots: int) -> None: else: if key in self.attempts: del self.attempts[key] - task_tuples.append((key, command, queue, ti.executor_config)) - if span.is_recording(): - span.add_event( - name="task to trigger", - attributes={"command": str(command), "conf": str(ti.executor_config)}, - ) + # TODO: TaskSDK: Compat, remove when KubeExecutor is fully moved over to TaskSDK too. + # TODO: TaskSDK: We need to minimum version requirements on executors with Airflow 3. + # How/where do we do that? Executor loader? + if hasattr(self, "_process_workloads"): + workloads.append(item) + else: + (command, _, queue, ti) = item + task_tuples.append((key, command, queue, getattr(ti, "executor_config", None))) if task_tuples: self._process_tasks(task_tuples) + elif workloads: + self._process_workloads(workloads) # type: ignore[attr-defined] @add_span def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 92b7c2b0010ed..b2c8c8220b993 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -836,7 +836,6 @@ def process_executor_events( ) if info is not None: msg += " Extra info: %s" % info # noqa: RUF100, UP031, flynt - cls.logger().error(msg) session.add(Log(event="state mismatch", extra=msg, task_instance=ti.key)) # Get task from the Serialized DAG @@ -849,6 +848,9 @@ def process_executor_events( continue ti.task = task if task.on_retry_callback or task.on_failure_callback: + # Only log the error/extra info here, since the `ti.handle_failure()` path will log it + # too, which would lead to double logging + cls.logger().error(msg) request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, ti=ti, diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 6b396cfb2925a..120d6c168ff5f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -335,7 +335,7 @@ "celery": { "deps": [ "apache-airflow>=2.9.0", - "celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2", + "celery[redis]>=5.4.0,<6", "flower>=1.0.0", "google-re2>=1.0" ], diff --git a/providers/celery/pyproject.toml b/providers/celery/pyproject.toml index 011a3812c4803..959a219c2ebc3 100644 --- a/providers/celery/pyproject.toml +++ b/providers/celery/pyproject.toml @@ -59,9 +59,7 @@ dependencies = [ # The Celery is known to introduce problems when upgraded to a MAJOR version. Airflow Core # Uses Celery for CeleryExecutor, and we also know that Kubernetes Python client follows SemVer # (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions). - # Make sure that the limit here is synchronized with [celery] extra in the airflow core - # The 5.3.3/5.3.2 limit comes from https://github.com/celery/celery/issues/8470 - "celery[redis]>=5.3.0,<6,!=5.3.3,!=5.3.2", + "celery[redis]>=5.4.0,<6", "flower>=1.0.0", "google-re2>=1.0", ] diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py index aaff91bff226c..aa0a0ec2ebe57 100644 --- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py +++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py @@ -154,6 +154,11 @@ def worker(args): # This needs to be imported locally to not trigger Providers Manager initialization from airflow.providers.celery.executors.celery_executor import app as celery_app + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.log import configure_logging + + configure_logging(output=sys.stdout.buffer) + # Disable connection pool so that celery worker does not hold an unnecessary db connection settings.reconfigure_orm(disable_connection_pool=True) if not settings.validate_session(): diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index d9121dcd7ab32..630b8afb3f69a 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -53,7 +53,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_3_0_PLUS from airflow.stats import Stats from airflow.utils.state import TaskInstanceState from celery import states as celery_states @@ -67,6 +67,9 @@ if TYPE_CHECKING: import argparse + from sqlalchemy.orm import Session + + from airflow.executors import workloads from airflow.executors.base_executor import CommandType, TaskTuple from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -228,6 +231,11 @@ class CeleryExecutor(BaseExecutor): supports_ad_hoc_ti_run: bool = True supports_sentry: bool = True + if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: + # In the v3 path, we store workloads, not commands as strings. + # TODO: TaskSDK: move this type change into BaseExecutor + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + def __init__(self): super().__init__() @@ -256,10 +264,25 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int: return max(1, math.ceil(to_send_count / self._sync_parallelism)) def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: + # Airflow V2 version from airflow.providers.celery.executors.celery_executor_utils import execute_command task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples] - first_task = next(t[3] for t in task_tuples_to_send) + + self._send_tasks(task_tuples_to_send) + + def _process_workloads(self, workloads: list[workloads.All]) -> None: + # Airflow V3 version + from airflow.providers.celery.executors.celery_executor_utils import execute_workload + + tasks = [ + (workload.ti.key, (workload.model_dump_json(),), workload.ti.queue, execute_workload) + for workload in workloads + ] + self._send_tasks(tasks) + + def _send_tasks(self, task_tuples_to_send): + first_task = next(t[-1] for t in task_tuples_to_send) # Celery state queries will stuck if we do not use one same backend # for all tasks. @@ -359,7 +382,7 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None self.success(key, info) elif state in (celery_states.FAILURE, celery_states.REVOKED): self.fail(key, info) - elif state in (celery_states.STARTED, celery_states.PENDING): + elif state in (celery_states.STARTED, celery_states.PENDING, celery_states.RETRY): pass else: self.log.info("Unexpected state for %s: %s", key, state) @@ -416,6 +439,10 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task for celery_task_id, (state, info) in states_by_celery_task_id.items(): result, ti = celery_tasks[celery_task_id] result.backend = cached_celery_backend + if isinstance(result.result, BaseException): + e = result.result + # Log the exception we got from the remote end + self.log.warning("Task %s tailed with error", ti.key, exc_info=e) # Set the correct elements of the state dicts, then update this # like we just queried it. @@ -475,6 +502,10 @@ def get_cli_commands() -> list[GroupCommand]: ), ] + def queue_workload(self, workload: workloads.ExecuteTask, session: Session | None) -> None: + ti = workload.ti + self.queued_tasks[ti.key] = workload + def _get_parser() -> argparse.ArgumentParser: """ diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 6d88d9f578d24..6f9b02f342ba5 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -40,8 +40,9 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.stats import Stats -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -125,21 +126,54 @@ def on_celery_import_modules(*args, **kwargs): import kubernetes.client # noqa: F401 -@app.task -def execute_command(command_to_exec: CommandType) -> None: - """Execute command.""" - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) +# Once Celery5 is out of beta, we can pass `pydantic=True` to the decorator and it will handle the validation +# and deserialization for us +@app.task(name="execute_workload") +def execute_workload(input: str) -> None: + from pydantic import TypeAdapter + + from airflow.configuration import conf + from airflow.executors import workloads + from airflow.sdk.execution_time.supervisor import supervise + + decoder = TypeAdapter(workloads.All) + workload = decoder.validate_json(input) + celery_task_id = app.current_task.request.id - log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec) - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - try: - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - _execute_in_subprocess(command_to_exec, celery_task_id) - else: - _execute_in_fork(command_to_exec, celery_task_id) - except Exception: - Stats.incr("celery.execute_command.failure") - raise + + if not isinstance(workload, workloads.ExecuteTask): + raise ValueError(f"CeleryExecutor does not now how to handle {type(workload)}") + + log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload) + + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=conf.get("workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"), + log_path=workload.log_path, + ) + + +if not AIRFLOW_V_3_0_PLUS: + + @app.task + def execute_command(command_to_exec: CommandType) -> None: + """Execute command.""" + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) + celery_task_id = app.current_task.request.id + log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec) + with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): + try: + if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: + _execute_in_subprocess(command_to_exec, celery_task_id) + else: + _execute_in_fork(command_to_exec, celery_task_id) + except Exception: + Stats.incr("celery.execute_command.failure") + raise def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None: @@ -213,15 +247,15 @@ def send_task_to_executor( task_tuple: TaskInstanceInCelery, ) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]: """Send task to executor.""" - key, command, queue, task_to_run = task_tuple + key, args, queue, task_to_run = task_tuple try: with timeout(seconds=OPERATION_TIMEOUT): - result = task_to_run.apply_async(args=[command], queue=queue) + result = task_to_run.apply_async(args=args, queue=queue) except (Exception, AirflowTaskTimeout) as e: exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}" result = ExceptionWithTraceback(e, exception_traceback) - return key, command, result + return key, args, result def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]: diff --git a/providers/celery/src/airflow/providers/celery/executors/default_celery.py b/providers/celery/src/airflow/providers/celery/executors/default_celery.py index 20c307a77b04f..9fb4a7e3bbbb6 100644 --- a/providers/celery/src/airflow/providers/celery/executors/default_celery.py +++ b/providers/celery/src/airflow/providers/celery/executors/default_celery.py @@ -27,6 +27,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS def _broker_supports_visibility_timeout(url): @@ -67,7 +68,7 @@ def _broker_supports_visibility_timeout(url): result_backend = conf.get_mandatory_value("celery", "RESULT_BACKEND") else: log.debug("Value for celery result_backend not found. Using sql_alchemy_conn with db+ prefix.") - result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}' + result_backend = f"db+{conf.get('database', 'SQL_ALCHEMY_CONN')}" extra_celery_config = conf.getjson("celery", "extra_celery_config", fallback={}) @@ -81,6 +82,9 @@ def _broker_supports_visibility_timeout(url): "task_track_started": conf.getboolean("celery", "task_track_started", fallback=True), "broker_url": broker_url, "broker_transport_options": broker_transport_options, + "broker_connection_retry_on_startup": conf.getboolean( + "celery", "broker_connection_retry_on_startup", fallback=True + ), "result_backend": result_backend, "database_engine_options": conf.getjson( "celery", "result_backend_sqlalchemy_engine_options", fallback={} @@ -90,6 +94,11 @@ def _broker_supports_visibility_timeout(url): **(extra_celery_config if isinstance(extra_celery_config, dict) else {}), } +# In order to not change anything pre Task Execution API, we leave this setting as it was (unset) in Airflow2 +if AIRFLOW_V_3_0_PLUS: + DEFAULT_CELERY_CONFIG.setdefault("worker_redirect_stdouts", False) + DEFAULT_CELERY_CONFIG.setdefault("worker_hijack_root_logger", False) + def _get_celery_ssl_active() -> bool: try: @@ -126,9 +135,7 @@ def _get_celery_ssl_active() -> bool: DEFAULT_CELERY_CONFIG["broker_use_ssl"] = broker_use_ssl except AirflowConfigException: raise AirflowException( - "AirflowConfigException: SSL_ACTIVE is True, " - "please ensure SSL_KEY, " - "SSL_CERT and SSL_CACERT are set" + "AirflowConfigException: SSL_ACTIVE is True, please ensure SSL_KEY, SSL_CERT and SSL_CACERT are set" ) except Exception as e: raise AirflowException( diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 569855016cfe6..dde84dd5229a4 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -165,38 +165,27 @@ def _configure_logs_over_json_channel(log_fd: int): from airflow.sdk.log import configure_logging log_io = os.fdopen(log_fd, "wb", buffering=0) - configure_logging(enable_pretty_log=False, output=log_io) + configure_logging(enable_pretty_log=False, output=log_io, sending_to_supervisor=True) def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): - if "PYTEST_CURRENT_TEST" in os.environ: - # When we are running in pytest, it's output capturing messes us up. This works around it - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, sock, mode in ( - ("stdin", child_stdin, "r"), - ("stdout", child_stdout, "w"), - ("stderr", child_stderr, "w"), + for handle_name, fd, sock, mode in ( + ("stdin", 0, child_stdin, "r"), + ("stdout", 1, child_stdout, "w"), + ("stderr", 2, child_stderr, "w"), ): handle = getattr(sys, handle_name) - try: - fd = handle.fileno() - os.dup2(sock.fileno(), fd) - # dup2 creates another open copy of the fd, we can close the "socket" copy of it. - sock.close() - except io.UnsupportedOperation: - if "PYTEST_CURRENT_TEST" in os.environ: - # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need - # to handle that differently - fd = sock.fileno() - else: - raise - # We can't open text mode fully unbuffered (python throws an exception if we try), but we can make it line buffered with `buffering=1` - handle = os.fdopen(fd, mode, buffering=1) + handle.close() + os.dup2(sock.fileno(), fd) + del sock + + # We open the socket/fd as binary, and then pass it to a TextIOWrapper so that it looks more like a + # normal sys.stdout etc. + binary = os.fdopen(fd, mode + "b") + handle = io.TextIOWrapper(binary, line_buffering=True) setattr(sys, handle_name, handle) @@ -352,7 +341,17 @@ def start( del logger # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + try: + _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + except BaseException as e: + try: + print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + except BaseException as e: + pass + + # It's really super super important we never exit this block. We are in the forked child, and if we + # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) + os._exit(124) requests_fd = child_comms.fileno() diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index fa5b113588bf5..8549518e205b2 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import io import itertools import logging.config import os @@ -39,7 +40,9 @@ ] -def exception_group_tracebacks(format_exception: Callable[[ExcInfo], list[dict[str, Any]]]) -> Processor: +def exception_group_tracebacks( + format_exception: Callable[[ExcInfo], list[dict[str, Any]]], +) -> Processor: # Make mypy happy if not hasattr(__builtins__, "BaseExceptionGroup"): T = TypeVar("T") @@ -178,13 +181,6 @@ def logging_processors( "console": console, } else: - # Imports to suppress showing code from these modules - import contextlib - - import click - import httpcore - import httpx - dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer( use_rich=False, show_locals=False, suppress=suppress ) @@ -197,9 +193,19 @@ def logging_processors( exc_group_processor = None def json_dumps(msg, default): + # Note: this is likely an "expensive" step, but lets massage the dict order for nice + # viewing of the raw JSON logs. + # Maybe we don't need this once the UI renders the JSON instead of displaying the raw text + msg = { + "timestamp": msg.pop("timestamp"), + "level": msg.pop("level"), + "event": msg.pop("event"), + **msg, + } return msgspec.json.encode(msg, enc_hook=default) def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: + # Stdlib logging doesn't need the re-ordering, it's fine as it is return msgspec.json.encode(event_dict).decode("utf-8") json = structlog.processors.JSONRenderer(serializer=json_dumps) @@ -224,13 +230,11 @@ def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: def configure_logging( enable_pretty_log: bool = True, log_level: str = "DEBUG", - output: BinaryIO | None = None, + output: BinaryIO | TextIO | None = None, cache_logger_on_first_use: bool = True, + sending_to_supervisor: bool = False, ): """Set up struct logging and stdlib logging config.""" - if enable_pretty_log and output is not None: - raise ValueError("output can only be set if enable_pretty_log is not") - lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()] if enable_pretty_log: @@ -263,13 +267,30 @@ def configure_logging( wrapper_class = structlog.make_filtering_bound_logger(lvl) if enable_pretty_log: + if output is not None and not isinstance(output, TextIO): + wrapper = io.TextIOWrapper(output, line_buffering=True) + logger_factory = structlog.WriteLoggerFactory(wrapper) + else: + logger_factory = structlog.WriteLoggerFactory(output) structlog.configure( processors=processors, cache_logger_on_first_use=cache_logger_on_first_use, wrapper_class=wrapper_class, + logger_factory=logger_factory, ) color_formatter.append(named["console"]) else: + if output is not None and "b" not in output.mode: + if not hasattr(output, "buffer"): + raise ValueError( + f"output needed to be a binary stream, but it didn't have a buffer attribute ({output=})" + ) + else: + output = output.buffer + if TYPE_CHECKING: + # Not all binary streams are isinstance of BinaryIO, so we check via looking at `mode` at + # runtime. mypy doesn't grok that though + assert isinstance(output, BinaryIO) structlog.configure( processors=processors, cache_logger_on_first_use=cache_logger_on_first_use, @@ -324,7 +345,7 @@ def configure_logging( "loggers": { # Set Airflow logging to the level requested, but most everything else at "INFO" "": { - "handlers": ["to_supervisor" if output else "default"], + "handlers": ["to_supervisor" if sending_to_supervisor else "default"], "level": "INFO", "propagate": True, }, @@ -413,10 +434,12 @@ def init_log_file(local_relative_path: str) -> Path: from airflow.configuration import conf new_file_permissions = int( - conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), 8 + conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), + 8, ) new_folder_permissions = int( - conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8 + conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), + 8, ) base_log_folder = conf.get("logging", "base_log_folder")