Skip to content

Commit

Permalink
Adding DataSync links (#46292)
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisms authored Jan 31, 2025
1 parent 000796a commit 6d49170
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 2 deletions.
37 changes: 37 additions & 0 deletions providers/src/airflow/providers/amazon/aws/links/datasync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink


class DataSyncTaskLink(BaseAwsLink):
"""Helper class for constructing AWS DataSync Task console link."""

name = "DataSync Task"
key = "datasync_task"
format_str = BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#" + "/tasks/{task_id}"


class DataSyncTaskExecutionLink(BaseAwsLink):
"""Helper class for constructing AWS DataSync TaskExecution console link."""

name = "DataSync Task Execution"
key = "datasync_task_execution"
format_str = (
BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#/history/{task_id}/{task_execution_id}"
)
42 changes: 40 additions & 2 deletions providers/src/airflow/providers/amazon/aws/operators/datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
from airflow.providers.amazon.aws.links.datasync import DataSyncTaskExecutionLink, DataSyncTaskLink
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

Expand Down Expand Up @@ -130,6 +131,8 @@ class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
}
ui_color = "#44b5e2"

operator_extra_links = (DataSyncTaskLink(), DataSyncTaskExecutionLink())

def __init__(
self,
*,
Expand Down Expand Up @@ -215,14 +218,31 @@ def execute(self, context: Context):
if not self.task_arn:
raise AirflowException("DataSync TaskArn could not be identified or created.")

task_id = self.task_arn.split("/")[-1]

task_url = DataSyncTaskLink.format_str.format(
aws_domain=DataSyncTaskLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
task_id=task_id,
)

DataSyncTaskLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
task_id=task_id,
)
self.log.info("You can view this DataSync task at %s", task_url)

self.log.info("Using DataSync TaskArn %s", self.task_arn)

# Update the DataSync Task
if self.update_task_kwargs:
self._update_datasync_task()

# Execute the DataSync Task
self._execute_datasync_task()
self._execute_datasync_task(context=context)

if not self.task_execution_arn:
raise AirflowException("Nothing was executed")
Expand Down Expand Up @@ -327,7 +347,7 @@ def _update_datasync_task(self) -> None:
self.hook.update_task(self.task_arn, **self.update_task_kwargs)
self.log.info("Updated TaskArn %s", self.task_arn)

def _execute_datasync_task(self) -> None:
def _execute_datasync_task(self, context: Context) -> None:
"""Create and monitor an AWS DataSync TaskExecution for a Task."""
if not self.task_arn:
raise AirflowException("Missing TaskArn")
Expand All @@ -337,6 +357,24 @@ def _execute_datasync_task(self) -> None:
self.task_execution_arn = self.hook.start_task_execution(self.task_arn, **self.task_execution_kwargs)
self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)

# Create the execution extra link
execution_url = DataSyncTaskExecutionLink.format_str.format(
aws_domain=DataSyncTaskExecutionLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
task_id=self.task_arn.split("/")[-1],
task_execution_id=self.task_execution_arn.split("/")[-1],
)
DataSyncTaskExecutionLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
task_id=self.task_arn.split("/")[-1],
task_execution_id=self.task_execution_arn.split("/")[-1],
)

self.log.info("You can view this DataSync task execution at %s", execution_url)

if not self.wait_for_completion:
return

Expand Down
2 changes: 2 additions & 0 deletions providers/src/airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,8 @@ extra-links:
- airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink
- airflow.providers.amazon.aws.links.comprehend.ComprehendPiiEntitiesDetectionLink
- airflow.providers.amazon.aws.links.comprehend.ComprehendDocumentClassifierLink
- airflow.providers.amazon.aws.links.datasync.DataSyncTaskLink
- airflow.providers.amazon.aws.links.datasync.DataSyncTaskExecutionLink


connection-types:
Expand Down
52 changes: 52 additions & 0 deletions providers/tests/amazon/aws/links/test_datasync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.datasync import DataSyncTaskExecutionLink, DataSyncTaskLink

from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase

TASK_ID = "task-0b36221bf94ad2bdd"
EXECUTION_ID = "exec-00000000000000004"


class TestDataSyncTaskLink(BaseAwsLinksTestCase):
link_class = DataSyncTaskLink

def test_extra_link(self):
task_id = TASK_ID
self.assert_extra_link_url(
expected_url=(f"https://console.aws.amazon.com/datasync/home?region=us-east-1#/tasks/{TASK_ID}"),
region_name="us-east-1",
aws_partition="aws",
task_id=task_id,
)


class TestDataSyncTaskExecutionLink(BaseAwsLinksTestCase):
link_class = DataSyncTaskExecutionLink

def test_extra_link(self):
self.assert_extra_link_url(
expected_url=(
f"https://console.aws.amazon.com/datasync/home?region=us-east-1#/history/{TASK_ID}/{EXECUTION_ID}"
),
region_name="us-east-1",
aws_partition="aws",
task_id=TASK_ID,
task_execution_id=EXECUTION_ID,
)
22 changes: 22 additions & 0 deletions providers/tests/amazon/aws/operators/test_datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
from airflow.providers.amazon.aws.links.datasync import DataSyncTaskLink
from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
from airflow.utils import timezone
from airflow.utils.state import DagRunState
Expand Down Expand Up @@ -748,6 +749,27 @@ def test_init_fails(self, mock_get_conn):
# ### Check mocks:
mock_get_conn.assert_not_called()

def test_task_extra_links(self, mock_get_conn):
mock_get_conn.return_value = self.client
self.set_up_operator()

region = "us-east-1"
aws_domain = DataSyncTaskLink.get_aws_domain("aws")
task_id = self.task_arn.split("/")[-1]

base_url = f"https://console.{aws_domain}/datasync/home?region={region}#"
task_url = f"{base_url}/tasks/{task_id}"

with mock.patch.object(self.datasync.log, "info") as mock_logging:
result = self.datasync.execute(None)
task_execution_arn = result["TaskExecutionArn"]
execution_id = task_execution_arn.split("/")[-1]
execution_url = f"{base_url}/history/{task_id}/{execution_id}"

assert self.datasync.task_arn == self.task_arn
mock_logging.assert_any_call("You can view this DataSync task at %s", task_url)
mock_logging.assert_any_call("You can view this DataSync task execution at %s", execution_url)

def test_execute_task(self, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
Expand Down

0 comments on commit 6d49170

Please sign in to comment.