From 1a413d6d9537c5ba82c7e5622b58382484d0dbba Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Fri, 24 Jan 2025 16:19:21 +0530 Subject: [PATCH] Test async dag --- cosmos/operators/airflow_async.py | 45 ++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py index 56056f143..ed73fc91d 100644 --- a/cosmos/operators/airflow_async.py +++ b/cosmos/operators/airflow_async.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import Any, Sequence from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator @@ -8,10 +9,12 @@ from cosmos.config import ProfileConfig from cosmos.constants import BIGQUERY_PROFILE_TYPE from cosmos.exceptions import CosmosValueError +from cosmos.operators.base import AbstractDbtBaseOperator from cosmos.operators.local import ( DbtBuildLocalOperator, DbtCloneLocalOperator, DbtCompileLocalOperator, + DbtLocalBaseOperator, DbtLSLocalOperator, DbtRunLocalOperator, DbtRunOperationLocalOperator, @@ -57,7 +60,13 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator, DbtRunLocalOperator): # type: ignore - template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("full_refresh", "project_dir", "location") # type: ignore[operator] + template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ( # type: ignore[operator] + "full_refresh", + "project_dir", + "gcp_project", + "dataset", + "location", + ) def __init__( # type: ignore self, @@ -84,15 +93,39 @@ def __init__( # type: ignore self.location = location self.configuration = configuration or {} self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore - - super().__init__( - project_dir=self.project_dir, - profile_config=self.profile_config, + profile = self.profile_config.profile_mapping.profile + self.gcp_project = profile["project"] + self.dataset = profile["dataset"] + + # Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept. + # We need to pop them. + async_op_kwargs = {} + cosmos_op_kwargs = {} + non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys()) + non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys()) + non_async_args -= {"task_id"} + + for arg_key, arg_value in kwargs.items(): + if arg_key not in non_async_args: + async_op_kwargs[arg_key] = arg_value + else: + cosmos_op_kwargs[arg_key] = arg_value + + # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode + BigQueryInsertJobOperator.__init__( + self, gcp_conn_id=self.gcp_conn_id, configuration=self.configuration, location=self.location, deferrable=True, - **kwargs, + **async_op_kwargs, + ) + + DbtRunLocalOperator.__init__( + self, + project_dir=self.project_dir, + profile_config=self.profile_config, + **cosmos_op_kwargs, ) self.async_context = extra_context or {} self.async_context["profile_type"] = self.profile_type