Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def transfer_files_async(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3H
gcp_conn_id=self.gcp_conn_id,
job_names=job_names,
poll_interval=self.poll_interval,
files=files,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -334,16 +335,18 @@ def submit_transfer_jobs(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3H

return job_names

def execute_complete(self, context: Context, event: dict[str, Any]) -> list[str] | None:
    """
    Handle the trigger callback when the transfer jobs complete.

    Relies on the trigger to raise (via an ``error`` event) on failure;
    otherwise execution is assumed successful.

    Returns the list of copied file paths when available (deferrable mode
    with files passed via the trigger), so subsequent tasks can consume them
    via XCom. Returns ``None`` when the event does not contain files
    (e.g. legacy triggers serialized before ``files`` was added).

    :param context: Airflow task context (required by the callback contract).
    :param event: Trigger payload; always has ``status`` and ``message``,
        and optionally ``files``.
    :raises AirflowException: if the trigger reported ``status == "error"``.
    """
    if event["status"] == "error":
        raise AirflowException(event["message"])
    self.log.info("%s completed with response %s ", self.task_id, event["message"])
    # .get() keeps backward compatibility: older trigger payloads have no
    # "files" key, so this returns None for them instead of raising.
    return event.get("files")

def get_transfer_hook(self):
return CloudDataTransferServiceHook(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
:param project_id: GCP project id.
:param poll_interval: Interval in seconds between polls.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param files: Optional list of file paths being transferred. When provided, included in the
success event for the operator to return to subsequent tasks (e.g. for XCom).
"""

def __init__(
Expand All @@ -50,12 +52,14 @@ def __init__(
project_id: str = PROVIDE_PROJECT_ID,
poll_interval: int = 10,
gcp_conn_id: str = "google_cloud_default",
files: list[str] | None = None,
) -> None:
super().__init__()
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.job_names = job_names
self.poll_interval = poll_interval
self.files = files

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize StorageTransferJobsTrigger arguments and classpath."""
Expand All @@ -66,6 +70,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_names": self.job_names,
"poll_interval": self.poll_interval,
"gcp_conn_id": self.gcp_conn_id,
"files": self.files,
},
)

Expand Down Expand Up @@ -117,13 +122,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.log.info("Transfer jobs completed: %s of %s", jobs_completed_successfully, jobs_total)
if jobs_completed_successfully == jobs_total:
s = "s" if jobs_total > 1 else ""
job_names = ", ".join(j for j in self.job_names)
yield TriggerEvent(
{
"status": "success",
"message": f"Transfer job{s} {job_names} completed successfully",
}
)
job_names_str = ", ".join(j for j in self.job_names)
event_payload: dict[str, Any] = {
"status": "success",
"message": f"Transfer job{s} {job_names_str} completed successfully",
}
if self.files is not None:
event_payload["files"] = self.files
yield TriggerEvent(event_payload)
return

self.log.info("Sleeping for %s seconds", self.poll_interval)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def test_execute_deferrable(self, mock_gcs_hook, mock_s3_hook, mock_transfer_hoo
assert trigger.project_id == PROJECT_ID
assert trigger.job_names == [TRANSFER_JOB_ID_0]
assert trigger.poll_interval == operator.poll_interval
assert trigger.files == MOCK_FILES

assert hasattr(exception_info.value, "method_name")
assert exception_info.value.method_name == "execute_complete"
Expand Down Expand Up @@ -386,6 +387,7 @@ def test_transfer_files_async(
assert trigger.project_id == PROJECT_ID
assert trigger.job_names == expected_job_names
assert trigger.poll_interval == operator.poll_interval
assert trigger.files == MOCK_FILES

assert hasattr(exception_info.value, "method_name")
assert exception_info.value.method_name == expected_method_name
Expand Down Expand Up @@ -486,11 +488,28 @@ def test_execute_complete_success(self, mock_log):
"message": expected_event_message,
}
operator = S3ToGCSOperator(task_id=TASK_ID, bucket=S3_BUCKET)
operator.execute_complete(context=mock.MagicMock(), event=event)
result = operator.execute_complete(context={}, event=event)

mock_log.return_value.info.assert_called_once_with(
"%s completed with response %s ", TASK_ID, event["message"]
)
assert result is None

@mock.patch(
    "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
)
def test_execute_complete_success_returns_copied_files(self, mock_log):
    """Deferrable mode returns list of copied files for use in subsequent tasks via XCom."""
    copied = [MOCK_FILE_1, MOCK_FILE_2]
    success_event = {"status": "success", "message": "Transfer completed", "files": copied}

    op = S3ToGCSOperator(task_id=TASK_ID, bucket=S3_BUCKET)

    # The operator must surface the trigger-supplied file list as its return
    # value so it lands in XCom.
    assert op.execute_complete(context={}, event=success_event) == copied

@mock.patch(
"airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_serialize(self, trigger):
"job_names": JOB_NAMES,
"poll_interval": POLL_INTERVAL,
"gcp_conn_id": GCP_CONN_ID,
"files": None,
}

def test_get_async_hook(self, trigger):
Expand Down Expand Up @@ -119,6 +120,31 @@ async def test_run(self, get_jobs, get_latest_operation, trigger):

assert actual_event == expected_event

@pytest.mark.asyncio
@mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation")
@mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs")
async def test_run_includes_files_in_success_event(self, get_jobs, get_latest_operation):
    """When trigger has files, success event includes them for operator to return via XCom."""
    transferred = ["path/file1.csv", "path/file2.csv"]
    trigger = CloudStorageTransferServiceCreateJobsTrigger(
        project_id=PROJECT_ID,
        job_names=JOB_NAMES,
        poll_interval=POLL_INTERVAL,
        gcp_conn_id=GCP_CONN_ID,
        files=transferred,
    )
    # Every job reports a SUCCESS operation so the trigger emits its
    # terminal success event on the first poll.
    get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES)
    get_latest_operation.side_effect = [
        create_mock_operation(status=TransferOperation.Status.SUCCESS, name=f"operation_{job_name}")
        for job_name in JOB_NAMES
    ]

    event = await trigger.run().asend(None)

    assert event.payload["status"] == "success"
    assert event.payload["files"] == transferred

@pytest.mark.parametrize(
"status",
[
Expand Down
Loading