From 56387cc2d620d08b9f361ca80120765da95e79e7 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 29 Nov 2024 13:56:12 +0000 Subject: [PATCH] Support using callable `config` in `@ray.task` (#103) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Ray provider 0.2.1 allowed users to define a hard-coded configuration to materialize the Kubernetes cluster. This PR aims to enable users to define a function that can receive the Airflow context and generate the configuration dynamically using context properties. This request came from an Astronomer customer. There is an example DAG file illustrating how to use this feature. It has a parent DAG that triggers two child DAGs, which leverage the just introduced `@ray.task` callable configuration. The screenshots below show their success, when using the [local development instructions](https://github.com/astronomer/astro-provider-ray/blob/main/docs/getting_started/local_development_setup.rst) using Astro CLI. Parent DAG: Screenshot 2024-11-29 at 12 15 13 Child 1 DAG: Screenshot 2024-11-29 at 12 15 56 Example of logs that illustrate the RayCluster using dynamic configuration was created and used in Kubernetes, with its own IP address: ``` (...) [2024-11-29T12:14:52.276+0000] {standard_task_runner.py:104} INFO - Running: ['airflow', 'tasks', 'run', 'ray_dynamic_config_child_1', 'process_data_with_ray', 'manual__2024-11-29T12:14:50.273712+00:00', '--job-id', '773', '--raw', '--subdir', 'DAGS_FOLDER/ray_dynamic_config.py', '--cfg-path', '/tmp/tmpkggwlv23'] [2024-11-29T12:14:52.278+0000] {logging_mixin.py:190} WARNING - /usr/local/lib/python3.12/site-packages/airflow/task/task_runner/standard_task_runner.py:70 DeprecationWarning: This process (pid=238) is multi-threaded, use of fork() may lead to deadlocks in the child. (...) 
[2024-11-29T12:14:52.745+0000] {decorators.py:94} INFO - Using the following config {'conn_id': 'ray_conn', 'runtime_env': {'working_dir': '/usr/local/airflow/dags/ray_scripts', 'pip': ['numpy']}, 'num_cpus': 1, 'num_gpus': 0, 'memory': 0, 'poll_interval': 5, 'ray_cluster_yaml': '/usr/local/airflow/dags/scripts/first-254.yaml', 'xcom_task_key': 'dashboard'} (...) [2024-11-29T12:14:55.430+0000] {hooks.py:474} INFO - ::group::Create Ray Cluster [2024-11-29T12:14:55.430+0000] {hooks.py:475} INFO - Loading yaml content for Ray cluster CRD... [2024-11-29T12:14:55.451+0000] {hooks.py:410} INFO - Creating new Ray cluster: first-254 [2024-11-29T12:14:55.456+0000] {hooks.py:494} INFO - ::endgroup:: (...) [2024-11-29T12:14:55.663+0000] {hooks.py:498} INFO - ::group::Setup Load Balancer service [2024-11-29T12:14:55.663+0000] {hooks.py:334} INFO - Attempt 1: Checking LoadBalancer status... [2024-11-29T12:14:55.669+0000] {hooks.py:278} ERROR - Error getting service first-254-head-svc: (404) Reason: Not Found HTTP response headers: HTTPHeaderDict({'Audit-Id': '81b07ac4-db3b-48a6-b336-f52ae93bee55', 'Cache-Control': 'no-cache, private', 'Content-Type': 'application/json', 'X-Kubernetes-Pf-Flowschema-Uid': '955e8bb0-08b1-4d45-a768-e49387a9767c', 'X-Kubernetes-Pf-Prioritylevel-Uid': 'd5240328-288d-4366-b094-d8fd793c7431', 'Date': 'Fri, 29 Nov 2024 12:14:55 GMT', 'Content-Length': '212'}) HTTP response body: {"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"services \"first-254-head-svc\" not found","reason":"NotFound","details":{"name":"first-254-head-svc","kind":"services"},"code":404} [2024-11-29T12:14:55.669+0000] {hooks.py:355} INFO - LoadBalancer service is not available yet... [2024-11-29T12:15:35.670+0000] {hooks.py:334} INFO - Attempt 2: Checking LoadBalancer status... [2024-11-29T12:15:35.688+0000] {hooks.py:348} INFO - LoadBalancer is ready. 
[2024-11-29T12:15:35.688+0000] {hooks.py:441} INFO - {'ip': '172.18.255.1', 'hostname': None, 'ports': [{'name': 'client', 'port': 10001}, {'name': 'dashboard', 'port': 8265}, {'name': 'gcs', 'port': 6379}, {'name': 'metrics', 'port': 8080}, {'name': 'serve', 'port': 8000}], 'working_address': '172.18.255.1'} (...) [2024-11-29T12:15:38.345+0000] {triggers.py:124} INFO - ::group:: Trigger 1/2: Checking the job status [2024-11-29T12:15:38.345+0000] {triggers.py:125} INFO - Polling for job raysubmit_paxAkyLiKxEHPmwG every 5 seconds... (...) [2024-11-29T12:15:38.354+0000] {hooks.py:156} INFO - Dashboard URL is: http://172.18.255.1:8265 [2024-11-29T12:15:38.361+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: PENDING [2024-11-29T12:15:38.361+0000] {triggers.py:100} INFO - Status of job raysubmit_paxAkyLiKxEHPmwG is: PENDING [2024-11-29T12:15:38.361+0000] {triggers.py:108} INFO - ::group::raysubmit_paxAkyLiKxEHPmwG logs [2024-11-29T12:15:43.416+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: RUNNING [2024-11-29T12:15:43.416+0000] {triggers.py:100} INFO - Status of job raysubmit_paxAkyLiKxEHPmwG is: RUNNING [2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,813 INFO worker.py:1429 -- Using address 10.244.0.140:6379 set in the environment variable RAY_ADDRESS [2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,814 INFO worker.py:1564 -- Connecting to existing Ray cluster at address: 10.244.0.140:6379... [2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,820 INFO worker.py:1740 -- Connected to Ray cluster. View the dashboard at 10.244.0.140:8265  [2024-11-29T12:15:48.430+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: SUCCEEDED [2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - Mean of this population is 12.0 [2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - (autoscaler +5s) Tip: use `ray status` to view detailed cluster status. 
To disable these messages, set RAY_SCHEDULER_EVENTS=0. [2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - (autoscaler +5s) Adding 1 node(s) of type small-group. [2024-11-29T12:15:49.448+0000] {triggers.py:113} INFO - ::endgroup:: [2024-11-29T12:15:49.448+0000] {triggers.py:144} INFO - ::endgroup:: [2024-11-29T12:15:49.448+0000] {triggers.py:145} INFO - ::group:: Trigger 2/2: Job reached a terminal state [2024-11-29T12:15:49.448+0000] {triggers.py:146} INFO - Status of completed job raysubmit_paxAkyLiKxEHPmwG is: SUCCEEDED (...) ``` Child 2 DAG: Screenshot 2024-11-29 at 12 17 20 Kubernetes RayClusters spun up: Screenshot 2024-11-29 at 12 15 37 **Limitations** The example DAGs are not currently being executed in the CI, but there is a dedicated ticket for this work: https://github.com/astronomer/astro-provider-ray/pull/95 **References** This PR was inspired by: https://github.com/astronomer/astro-provider-ray/pull/67 --- dev/dags/ray_dynamic_config.py | 197 +++++++++++++++++++++++++++++++++ ray_provider/decorators.py | 101 +++++++++-------- tests/test_decorators.py | 62 +++++++++-- 3 files changed, 309 insertions(+), 51 deletions(-) create mode 100644 dev/dags/ray_dynamic_config.py diff --git a/dev/dags/ray_dynamic_config.py b/dev/dags/ray_dynamic_config.py new file mode 100644 index 0000000..046b60b --- /dev/null +++ b/dev/dags/ray_dynamic_config.py @@ -0,0 +1,197 @@ +""" +This example illustrates three DAGs: one parent DAG and two child DAGs. + +The parent DAG (ray_dynamic_config_parent) uses TriggerDagRunOperator to trigger the other two: +* ray_dynamic_config_child_1 +* ray_dynamic_config_child_2 + +Each child DAG reads the optional raycluster_name and raycluster_k8s_yml_filename values from dag_run.conf, which is passed by the parent DAG. + +The print_context tasks in the child DAGs output the received values to the logs. 
+""" + +import re +from pathlib import Path + +import yaml +from airflow import DAG +from airflow.decorators import task +from airflow.operators.empty import EmptyOperator +from airflow.operators.python import PythonOperator +from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.utils.dates import days_ago +from jinja2 import Template + +from ray_provider.decorators import ray + +CONN_ID = "ray_conn" +RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" +FOLDER_PATH = Path(__file__).parent / "ray_scripts" +RAY_TASK_CONFIG = { + "conn_id": CONN_ID, + "runtime_env": {"working_dir": str(FOLDER_PATH), "pip": ["numpy"]}, + "num_cpus": 1, + "num_gpus": 0, + "memory": 0, + "poll_interval": 5, + "ray_cluster_yaml": str(RAY_SPEC), + "xcom_task_key": "dashboard", +} + + +def slugify(value): + """ + Replace invalid characters with hyphens and make lowercase. + """ + return re.sub(r"[^\w\-\.]", "-", value).lower() + + +def create_config_from_context(context, **kwargs): + default_name = "{{ dag.dag_id }}-{{ dag_run.id }}" + + raycluster_name_template = context.get("dag_run").conf.get("raycluster_name", default_name) + raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-") + raycluster_name = slugify(raycluster_name) + + raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get( + "raycluster_k8s_yml_filename", default_name + ".yml" + ) + raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-") + raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename) + + with open(RAY_SPEC) as file: + data = yaml.safe_load(file) + data["metadata"]["name"] = raycluster_name + + NEW_RAY_K8S_SPEC = Path(__file__).parent / "scripts" / raycluster_k8s_yml_filename + with open(NEW_RAY_K8S_SPEC, "w") as file: + yaml.safe_dump(data, file, default_flow_style=False) + + config = dict(RAY_TASK_CONFIG) + config["ray_cluster_yaml"] = str(NEW_RAY_K8S_SPEC) + return config + 
+ +def print_context(**kwargs): + # Retrieve `conf` passed from the parent DAG + print(kwargs) + cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided") + raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get( + "raycluster_k8s_yml_filename", "No ray cluster YML filename provided" + ) + print(f"Received cluster name: {cluster_name}") + print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}") + + +# Downstream 1 +with DAG( + dag_id="ray_dynamic_config_child_1", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + print_context_task + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy as np + import ray + + @ray.remote + def cubic(x): + return x**3 + + ray.init() + data = np.array(data) + futures = [cubic.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + data = generate_data() + process_data_with_ray(data) + + +# Downstream 2 +with DAG( + dag_id="ray_dynamic_config_child_2", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy as np + import ray + + @ray.remote + def square(x): + return x**2 + + ray.init() + data = np.array(data) + futures = [square.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + data = generate_data() + process_data_with_ray(data) + + +# Upstream +with DAG( + 
dag_id="ray_dynamic_config_parent", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + empty_task = EmptyOperator(task_id="empty_task") + + trigger_dag_1 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_1", + trigger_dag_id="ray_dynamic_config_child_1", + conf={ + "raycluster_name": "first-{{ dag_run.id }}", + "raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml", + }, + ) + + trigger_dag_2 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_2", + trigger_dag_id="ray_dynamic_config_child_2", + conf={}, + ) + + # Illustrates that by default two DAG runs of the same DAG will be using different Ray clusters + # Disabled because in the local dev MacOS we're only managing to spin up two Ray Cluster services concurrently + # trigger_dag_3 = TriggerDagRunOperator( + # task_id="trigger_downstream_dag_3", + # trigger_dag_id="ray_dynamic_config_child_2", + # conf={}, + # ) + + empty_task >> trigger_dag_1 + trigger_dag_1 >> trigger_dag_2 + # trigger_dag_1 >> trigger_dag_3 diff --git a/ray_provider/decorators.py b/ray_provider/decorators.py index bcb150d..ed2f799 100644 --- a/ray_provider/decorators.py +++ b/ray_provider/decorators.py @@ -3,15 +3,16 @@ import inspect import os import re -import shutil +import tempfile import textwrap -from tempfile import mkdtemp +from datetime import timedelta +from pathlib import Path from typing import Any, Callable from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory -from airflow.exceptions import AirflowException from airflow.utils.context import Context +from ray_provider.exceptions import RayAirflowException from ray_provider.operators import SubmitRayJob @@ -28,10 +29,27 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): """ custom_operator_name = "@task.ray" + _config: dict[str, Any] | Callable[..., dict[str, Any]] = dict() template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs") - def __init__(self, 
config: dict[str, Any], **kwargs: Any) -> None: + def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None: + self._config = config + self.kwargs = kwargs + super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs) + + def _build_config(self, context: Context) -> dict[str, Any]: + if callable(self._config): + config_params = inspect.signature(self._config).parameters + config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"} + if "context" in config_params: + config_kwargs["context"] = context + config = self._config(**config_kwargs) + assert isinstance(config, dict) + return config + return self._config + + def _load_config(self, config: dict[str, Any]) -> None: self.conn_id: str = config.get("conn_id", "") self.is_decorated_function = False if "entrypoint" in config else True self.entrypoint: str = config.get("entrypoint", "python script.py") @@ -39,9 +57,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: self.num_cpus: int | float = config.get("num_cpus", 1) self.num_gpus: int | float = config.get("num_gpus", 0) - self.memory: int | float = config.get("memory", None) - self.ray_resources: dict[str, Any] | None = config.get("resources", None) - self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None) + self.memory: int | float = config.get("memory", 1) + self.ray_resources: dict[str, Any] | None = config.get("resources") + self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml") self.update_if_exists: bool = config.get("update_if_exists", False) self.kuberay_version: str = config.get("kuberay_version", "1.0.0") self.gpu_device_plugin_yaml: str = config.get( @@ -50,35 +68,19 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: ) self.fetch_logs: bool = config.get("fetch_logs", True) self.wait_for_completion: bool = config.get("wait_for_completion", True) - job_timeout_seconds: int = 
config.get("job_timeout_seconds", 600) + job_timeout_seconds = config.get("job_timeout_seconds", 600) + self.job_timeout_seconds: timedelta | None = ( + timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None + ) self.poll_interval: int = config.get("poll_interval", 60) - self.xcom_task_key: str | None = config.get("xcom_task_key", None) + self.xcom_task_key: str | None = config.get("xcom_task_key") + self.config = config if not isinstance(self.num_cpus, (int, float)): - raise TypeError("num_cpus should be an integer or float value") + raise RayAirflowException("num_cpus should be an integer or float value") if not isinstance(self.num_gpus, (int, float)): - raise TypeError("num_gpus should be an integer or float value") - - super().__init__( - conn_id=self.conn_id, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - num_cpus=self.num_cpus, - num_gpus=self.num_gpus, - memory=self.memory, - resources=self.ray_resources, - ray_cluster_yaml=self.ray_cluster_yaml, - update_if_exists=self.update_if_exists, - kuberay_version=self.kuberay_version, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - fetch_logs=self.fetch_logs, - wait_for_completion=self.wait_for_completion, - job_timeout_seconds=job_timeout_seconds, - poll_interval=self.poll_interval, - xcom_task_key=self.xcom_task_key, - **kwargs, - ) + raise RayAirflowException("num_gpus should be an integer or float value") def execute(self, context: Context) -> Any: """ @@ -88,21 +90,21 @@ def execute(self, context: Context) -> Any: :return: The result of the Ray job execution. :raises AirflowException: If job submission fails. 
""" - tmp_dir = None - try: + config = self._build_config(context) + self.log.info(f"Using the following config {config}") + self._load_config(config) + + with tempfile.TemporaryDirectory(prefix="ray_") as tmpdirname: + temp_dir = Path(tmpdirname) + if self.is_decorated_function: self.log.info( f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}" ) - # Create a temporary directory that won't be immediately deleted - temp_dir = mkdtemp(prefix="ray_") - script_filename = os.path.join(temp_dir, "script.py") # Get the Python source code and extract just the function body full_source = inspect.getsource(self.python_callable) function_body = self._extract_function_body(full_source) - if not function_body: - raise ValueError("Failed to retrieve Python source code") # Prepare the function call args_str = ", ".join(repr(arg) for arg in self.op_args) @@ -110,6 +112,7 @@ def execute(self, context: Context) -> Any: call_str = f"{self.python_callable.__name__}({args_str}, {kwargs_str})" # Write the script with function definition and call + script_filename = os.path.join(temp_dir, "script.py") with open(script_filename, "w") as file: file.write(function_body) file.write(f"\n\n# Execute the function\n{call_str}\n") @@ -122,21 +125,27 @@ def execute(self, context: Context) -> Any: result = super().execute(context) return result - except Exception as e: - self.log.error(f"Failed during execution with error: {e}") - raise AirflowException("Job submission failed") from e - finally: - if tmp_dir and os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) def _extract_function_body(self, source: str) -> str: """Extract the function, excluding only the ray.task decorator.""" + self.log.info(r"Ray pipeline intended to be executed: \n %s", source) + if "@ray.task" not in source: + raise RayAirflowException("Unable to parse this body. Expects the `@ray.task` decorator.") lines = source.split("\n") + # TODO: Review the current approach, that is quite hacky. 
+ # It feels a mistake to have a user-facing module named the same as the official ray SDK. + # In particular, the decorator is working in a very artificial way, where ray means two different things + # at the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK) # Find the line where the ray.task decorator is + # Additionally, if users imported the ray decorator as "from ray_provider.decorators import ray as ray_decorator + # The following will stop working. ray_task_line = next((i for i, line in enumerate(lines) if re.match(r"^\s*@ray\.task", line.strip())), -1) # Include everything except the ray.task decorator line body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :]) + + if not body: + raise RayAirflowException("Failed to extract Ray pipeline code decorated with @ray.task") # Dedent the body return textwrap.dedent(body) @@ -146,6 +155,7 @@ class ray: def task( python_callable: Callable[..., Any] | None = None, multiple_outputs: bool | None = None, + config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None, **kwargs: Any, ) -> TaskDecorator: """ @@ -153,12 +163,15 @@ def task( :param python_callable: The callable function to decorate. :param multiple_outputs: If True, will return multiple outputs. + :param config: A dictionary of configuration or a callable that returns a dictionary. :param kwargs: Additional keyword arguments. :return: The decorated task. 
""" + config = config or {} return task_decorator_factory( python_callable=python_callable, multiple_outputs=multiple_outputs, decorated_operator_class=_RayDecoratedOperator, + config=config, **kwargs, ) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index a6e6b15..70808a3 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest -from airflow.exceptions import AirflowException from airflow.utils.context import Context from ray_provider.decorators import _RayDecoratedOperator, ray +from ray_provider.exceptions import RayAirflowException class TestRayDecoratedOperator: @@ -29,6 +29,7 @@ def dummy_callable(): pass operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator._load_config(config) assert operator.conn_id == "ray_default" assert operator.entrypoint == "python my_script.py" @@ -50,13 +51,14 @@ def dummy_callable(): pass operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator._load_config(config) assert operator.conn_id == "" assert operator.entrypoint == "python script.py" assert operator.runtime_env == {} assert operator.num_cpus == 1 assert operator.num_gpus == 0 - assert operator.memory is None + assert operator.memory == 1 assert operator.ray_resources is None assert operator.fetch_logs == True assert operator.wait_for_completion == True @@ -64,6 +66,18 @@ def dummy_callable(): assert operator.poll_interval == 60 assert operator.xcom_task_key is None + def test_callable_config(self): + def dummy_callable(): + pass + + callable_config = lambda context: {"ray_cluster_yaml": "different.yml"} + + operator = _RayDecoratedOperator(task_id="test_task", config=callable_config, python_callable=dummy_callable) + new_config = operator._build_config(context={}) + operator._load_config(new_config) + + assert operator.ray_cluster_yaml == "different.yml" + 
def test_invalid_config_raises_exception(self): config = { "num_cpus": "invalid_number", @@ -72,13 +86,16 @@ def test_invalid_config_raises_exception(self): def dummy_callable(): pass - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException): + operator._load_config(config) config["num_cpus"] = 1 config["num_gpus"] = "invalid_number" - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException): + operator._load_config(config) @patch.object(_RayDecoratedOperator, "_extract_function_body") @patch("ray_provider.decorators.SubmitRayJob.execute") @@ -130,7 +147,7 @@ def dummy_callable(): operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) mock_super_execute.side_effect = Exception("Ray job failed") - with pytest.raises(AirflowException): + with pytest.raises(Exception): operator.execute(context) def test_extract_function_body(self): @@ -155,6 +172,37 @@ def dummy_callable(): """ ) + def test_extract_function_body_invalid_body(self): + config = {} + + @ray.task() + def dummy_callable(): + return "dummy" + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException) as exc_info: + operator._extract_function_body( + """@ray_decorator.task() + def dummy_callable(): + return "dummy" + """ + ) + assert str(exc_info.value) == "Unable to parse this body. Expects the `@ray.task` decorator." 
+ + def test_extract_function_body_empty_body(self): + config = {} + + @ray.task() + def dummy_callable(): + return "dummy" + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException) as exc_info: + operator._extract_function_body("""@ray.task()""") + assert str(exc_info.value) == "Failed to extract Ray pipeline code decorated with @ray.task" + class TestRayTaskDecorator: def test_ray_task_decorator(self):