diff --git a/.buildkite/_forge.rayci.yml b/.buildkite/_forge.rayci.yml index 9c3374da23f51..ed2cb1eb2579e 100644 --- a/.buildkite/_forge.rayci.yml +++ b/.buildkite/_forge.rayci.yml @@ -25,6 +25,7 @@ steps: - "12.1.1-cudnn8" - "12.3.2-cudnn9" - "12.4.1-cudnn" + - "12.5.1-cudnn" env: PYTHON_VERSION: "{{matrix.python}}" CUDA_VERSION: "{{matrix.cuda}}" diff --git a/.buildkite/build.rayci.yml b/.buildkite/build.rayci.yml index b8048ad8de4be..e4d5b0c5ff9d1 100644 --- a/.buildkite/build.rayci.yml +++ b/.buildkite/build.rayci.yml @@ -59,7 +59,8 @@ steps: - bazel run //ci/ray_ci:build_in_docker -- docker --python-version {{matrix}} --platform cu11.7.1-cudnn8 --platform cu11.8.0-cudnn8 --platform cu12.1.1-cudnn8 --platform cu12.3.2-cudnn9 - --platform cu12.4.1-cudnn --platform cpu + --platform cu12.4.1-cudnn --platform cu12.5.1-cudnn + --platform cpu --image-type ray --upload depends_on: - manylinux diff --git a/.buildkite/linux_aarch64.rayci.yml b/.buildkite/linux_aarch64.rayci.yml index 5b0774aa55de2..57835ae2bbe57 100644 --- a/.buildkite/linux_aarch64.rayci.yml +++ b/.buildkite/linux_aarch64.rayci.yml @@ -33,6 +33,7 @@ steps: - "12.1.1-cudnn8" - "12.3.2-cudnn9" - "12.4.1-cudnn" + - "12.5.1-cudnn" instance_type: builder-arm64 env: PYTHON_VERSION: "{{matrix.python}}" @@ -82,7 +83,8 @@ steps: - bazel run //ci/ray_ci:build_in_docker -- docker --python-version {{matrix}} --platform cu11.7.1-cudnn8 --platform cu11.8.0-cudnn8 --platform cu12.1.1-cudnn8 --platform cu12.3.2-cudnn9 - --platform cu12.4.1-cudnn --platform cpu + --platform cu12.4.1-cudnn --platform cu12.5.1-cudnn + --platform cpu --image-type ray --architecture aarch64 --upload depends_on: - manylinux-aarch64 diff --git a/.buildkite/llm.rayci.yml b/.buildkite/llm.rayci.yml new file mode 100644 index 0000000000000..e966c293ff3bc --- /dev/null +++ b/.buildkite/llm.rayci.yml @@ -0,0 +1,17 @@ +group: llm tests +depends_on: + - forge + - oss-ci-base_ml +steps: + - name: llmbuild + wanda: ci/docker/llm.build.wanda.yaml + + - label: "llm tests" + key: "llm-tests" + tags: + - python + - llm + instance_type: medium + commands: + - bazel run //ci/ray_ci:test_in_docker -- //python/ray/llm/... 
llm + depends_on: llmbuild diff --git a/.buildkite/rllib.rayci.yml b/.buildkite/rllib.rayci.yml index 7eb7452b9cec3..4adad804654ee 100644 --- a/.buildkite/rllib.rayci.yml +++ b/.buildkite/rllib.rayci.yml @@ -83,6 +83,7 @@ steps: tags: - rllib_gpu - gpu + - skip-on-microcheck parallelism: 5 instance_type: gpu commands: @@ -137,6 +138,7 @@ steps: tags: - rllib_directly - doc + - skip-on-microcheck instance_type: medium commands: # doc tests @@ -159,6 +161,7 @@ steps: tags: - rllib_gpu - gpu + - skip-on-microcheck parallelism: 5 instance_type: gpu-large commands: diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f23bff84a6716..1e9d8fb6ef5fe 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -91,6 +91,9 @@ /python/ray/train/ @hongpeng-guo @justinvyu @matthewdeng @raulchen @woshiyyya /doc/source/train/ @hongpeng-guo @justinvyu @matthewdeng @raulchen @woshiyyya @ray-project/ray-docs +# LLM +/python/ray/llm/ @ray-project/ray-llm + # Serve (docs) /doc/source/serve/ @edoakes @zcin @GeneDer @akshay-anyscale @ray-project/ray-docs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e1a5bc7d06b1..26c384497ecdb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,8 @@ exclude: | python/ray/core/generated/| python/ray/serve/generated/| python/ray/cloudpickle/| + python/ray/dashboard/client/public/| + python/ray/tests/test_cli_patterns| python/ray/_private/runtime_env/_clonevirtualenv.py| doc/external/| doc/source/ diff --git a/.rayciversion b/.rayciversion index a3df0a6959e15..78bc1abd14f2c 100644 --- a/.rayciversion +++ b/.rayciversion @@ -1 +1 @@ -0.8.0 +0.10.0 diff --git a/.vale/styles/config/vocabularies/RLlib/accept.txt b/.vale/styles/config/vocabularies/RLlib/accept.txt index 6575bdbdc333d..9eb656d54afa5 100644 --- a/.vale/styles/config/vocabularies/RLlib/accept.txt +++ b/.vale/styles/config/vocabularies/RLlib/accept.txt @@ -3,11 +3,13 @@ [Aa]lgos? (APPO|appo) [Aa]utoscal(e|ing) +[Aa]utoregressive boolean [Cc]allables? [Cc]heckpoints?(ing)? [Cc]heckpointable classmethods? +CNNs? coeff config (DQN|dqn) @@ -15,10 +17,12 @@ config (IMPALA|impala) logits? log-probs? +LSTMs? hyperparameters? MARLModule (MARWIL|marwil) MLAgents +MLPs? multiagent [Pp]erceptrons? 
postprocessing diff --git a/.vale/styles/config/vocabularies/Train/accept.txt b/.vale/styles/config/vocabularies/Train/accept.txt index 38f7eed079981..64d29a7518dbb 100644 --- a/.vale/styles/config/vocabularies/Train/accept.txt +++ b/.vale/styles/config/vocabularies/Train/accept.txt @@ -5,4 +5,4 @@ LightGBM PyTorch PyTorch Lightning TensorFlow -XGBoost \ No newline at end of file +XGBoost diff --git a/BUILD.bazel b/BUILD.bazel index d3944f4c69f47..c87c6e6068405 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -386,6 +386,7 @@ ray_cc_library( "//src/ray/protobuf:common_cc_proto", "//src/ray/util", "//src/ray/util:compat", + "//src/ray/util:counter_map", "@msgpack", ], ) @@ -537,7 +538,10 @@ ray_cc_library( ":scheduler", ":worker_rpc", "//src/ray/protobuf:agent_manager_cc_proto", + "//src/ray/util:counter_map", "//src/ray/util:thread_checker", + "//src/ray/util:throttler", + "//src/ray/util:type_traits", "@boost//:bimap", "@com_github_grpc_grpc//src/proto/grpc/health/v1:health_proto", "@com_google_absl//absl/container:btree", @@ -718,6 +722,7 @@ ray_cc_library( "//src/ray/protobuf:common_cc_proto", "//src/ray/protobuf:runtime_env_agent_cc_proto", "//src/ray/util", + "//src/ray/util:throttler", "@boost//:asio", "@boost//:beast", "@boost//:system", @@ -1972,6 +1977,7 @@ ray_cc_library( ":pubsub_lib", ":ray_common", ":redis_store_client", + "//src/ray/util:sequencer", "//src/ray/protobuf:usage_cc_proto", ], ) @@ -2209,6 +2215,7 @@ ray_cc_library( ":ray_common", ":stats_lib", "//src/ray/util", + "//src/ray/util:exponential_backoff", "@boost//:asio", ], ) @@ -2224,7 +2231,7 @@ ray_cc_library( "src/ray/gcs/store_client/store_client.h", ], deps = [ - "redis_client", + ":redis_client", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", ], @@ -2370,6 +2377,7 @@ ray_cc_library( "//src/ray/protobuf:gcs_cc_proto", "//src/ray/protobuf:gcs_service_cc_proto", "//src/ray/util", + "//src/ray/util:exponential_backoff", "@boost//:asio", ], ) @@ -2510,6 +2518,7 @@ pyx_library( "//:stats_lib", "//src/ray/protobuf:serialization_cc_proto", "//src/ray/util", + "//src/ray/util:memory", ], ) diff --git a/binder/requirements.txt b/binder/requirements.txt index c0bd5da28a01f..eaf0b539f259f 100644 --- a/binder/requirements.txt +++ b/binder/requirements.txt @@ -1,4 +1,4 @@ # Set up requirements needed to build a binder server to launch into. 
-r ../doc/requirements-doc.txt -r ../python/requirements.txt -ray[rllib, serve, tune] \ No newline at end of file +ray[rllib, serve, tune] diff --git a/ci/docker/base.gpu.Dockerfile b/ci/docker/base.gpu.Dockerfile index 14cec40cca2d4..f875f3c39fb4c 100644 --- a/ci/docker/base.gpu.Dockerfile +++ b/ci/docker/base.gpu.Dockerfile @@ -18,12 +18,6 @@ ENV RAY_INSTALL_JAVA=0 ENV BUILDKITE_PULL_REQUEST=${BUILDKITE_PULL_REQUEST} ENV BUILDKITE_COMMIT=${BUILDKITE_COMMIT} ENV BUILDKITE_PULL_REQUEST_BASE_BRANCH=${BUILDKITE_PULL_REQUEST_BASE_BRANCH} -# For wheel build -# https://github.com/docker-library/docker/blob/master/20.10/docker-entrypoint.sh -ENV DOCKER_TLS_CERTDIR=/certs -ENV DOCKER_HOST=tcp://docker:2376 -ENV DOCKER_TLS_VERIFY=1 -ENV DOCKER_CERT_PATH=/certs/client ENV TRAVIS_COMMIT=${BUILDKITE_COMMIT} ENV BUILDKITE_BAZEL_CACHE_URL=${REMOTE_CACHE_URL} diff --git a/ci/docker/base.test.Dockerfile b/ci/docker/base.test.Dockerfile index c34b9210a506d..affc28abc9b26 100644 --- a/ci/docker/base.test.Dockerfile +++ b/ci/docker/base.test.Dockerfile @@ -14,12 +14,6 @@ ENV PYTHON=$PYTHON ENV RAY_USE_RANDOM_PORTS=1 ENV RAY_DEFAULT_BUILD=1 ENV RAY_INSTALL_JAVA=0 -# For wheel build -# https://github.com/docker-library/docker/blob/master/20.10/docker-entrypoint.sh -ENV DOCKER_TLS_CERTDIR=/certs -ENV DOCKER_HOST=tcp://docker:2376 -ENV DOCKER_TLS_VERIFY=1 -ENV DOCKER_CERT_PATH=/certs/client ENV BUILDKITE_BAZEL_CACHE_URL=${BUILDKITE_BAZEL_CACHE_URL} RUN < None: container = RayDockerContainer(v, "cu12.4.1-cudnn", "ray") assert container.get_platform_tag() == "-cu124" + container = RayDockerContainer(v, "cu12.5.1-cudnn", "ray") + assert container.get_platform_tag() == "-cu125" + def test_should_upload(self) -> None: v = DEFAULT_PYTHON_VERSION test_cases = [ diff --git a/ci/repro-ci-requirements.txt b/ci/repro-ci-requirements.txt index 0543f4e84c94e..2f3dd219e7c10 100644 --- a/ci/repro-ci-requirements.txt +++ b/ci/repro-ci-requirements.txt @@ -3,4 +3,4 @@ boto3 click paramiko pyyaml -pybuildkite \ No newline at end of file +pybuildkite diff --git a/cpp/include/ray/api/arguments.h b/cpp/include/ray/api/arguments.h index c33f54b2beaa9..43c75c5abe0f5 100644 --- a/cpp/include/ray/api/arguments.h +++ b/cpp/include/ray/api/arguments.h @@ -120,4 +120,4 @@ class Arguments { }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/common_types.h b/cpp/include/ray/api/common_types.h index 71defe2fd2ec3..4c3b9e2d8727d 100644 --- a/cpp/include/ray/api/common_types.h +++ b/cpp/include/ray/api/common_types.h @@ -53,4 +53,4 @@ using RemoteMemberFunction = using RemoteMemberFunctionMap_t = std::unordered_map; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/function_manager.h b/cpp/include/ray/api/function_manager.h index 678213e387230..4b683d99a89e6 100644 --- a/cpp/include/ray/api/function_manager.h +++ b/cpp/include/ray/api/function_manager.h @@ -351,4 +351,4 @@ class FunctionManager { std::map, std::string> mem_func_to_key_map_; }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/object_ref.h b/cpp/include/ray/api/object_ref.h index 9a554815df787..bd3c24fe32317 100644 --- a/cpp/include/ray/api/object_ref.h +++ b/cpp/include/ray/api/object_ref.h @@ -222,4 +222,4 @@ class ObjectRef { std::string id_; }; -} // namespace ray \ No newline at end of file +} // namespace ray diff --git 
a/cpp/include/ray/api/ray_exception.h b/cpp/include/ray/api/ray_exception.h index 6f5a10e6d9ff8..51a595afe7503 100644 --- a/cpp/include/ray/api/ray_exception.h +++ b/cpp/include/ray/api/ray_exception.h @@ -65,4 +65,4 @@ class RayTimeoutException : public RayException { }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/ray_remote.h b/cpp/include/ray/api/ray_remote.h index 4ab42c42e70ee..7381111240a5d 100644 --- a/cpp/include/ray/api/ray_remote.h +++ b/cpp/include/ray/api/ray_remote.h @@ -80,4 +80,4 @@ inline static int RegisterRemoteFunctions(const T &t, U... u) { #define RAY_FUNC(f, ...) ray::internal::underload<__VA_ARGS__>(f) -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/ray_runtime.h b/cpp/include/ray/api/ray_runtime.h index 597fab770e56d..8a8bf35e83ce1 100644 --- a/cpp/include/ray/api/ray_runtime.h +++ b/cpp/include/ray/api/ray_runtime.h @@ -102,4 +102,4 @@ class RayRuntime { const std::string &serialized_actor_handle) = 0; }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/ray_runtime_holder.h b/cpp/include/ray/api/ray_runtime_holder.h index 3afb3405e5b26..8e2acffb3e932 100644 --- a/cpp/include/ray/api/ray_runtime_holder.h +++ b/cpp/include/ray/api/ray_runtime_holder.h @@ -45,4 +45,4 @@ inline static std::shared_ptr GetRayRuntime() { } } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/serializer.h b/cpp/include/ray/api/serializer.h index 7c91c7952a3df..415b7ebba7b0e 100644 --- a/cpp/include/ray/api/serializer.h +++ b/cpp/include/ray/api/serializer.h @@ -83,4 +83,4 @@ class Serializer { }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/type_traits.h b/cpp/include/ray/api/type_traits.h index ffc8bd395fb36..229837d4b862a 100644 --- a/cpp/include/ray/api/type_traits.h +++ b/cpp/include/ray/api/type_traits.h @@ -85,4 +85,4 @@ template auto constexpr is_x_lang_v = is_java_v || is_python_v; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/wait_result.h b/cpp/include/ray/api/wait_result.h index bbe333745b7c7..526c716e4d6cc 100644 --- a/cpp/include/ray/api/wait_result.h +++ b/cpp/include/ray/api/wait_result.h @@ -34,4 +34,4 @@ class WaitResult { : ready(ready_objects), unready(unready_objects){}; }; -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/include/ray/api/xlang_function.h b/cpp/include/ray/api/xlang_function.h index bb78f1ea49636..f517a643784dd 100644 --- a/cpp/include/ray/api/xlang_function.h +++ b/cpp/include/ray/api/xlang_function.h @@ -82,4 +82,4 @@ inline constexpr std::string_view METADATA_STR_XLANG = "XLANG"; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/runtime/object/object_store.cc b/cpp/src/ray/runtime/object/object_store.cc index 018c6912a2694..9b34dc6b09dc0 100644 --- a/cpp/src/ray/runtime/object/object_store.cc +++ b/cpp/src/ray/runtime/object/object_store.cc @@ -49,4 +49,4 @@ ObjectStore::GetAllReferenceCounts() const { return core_worker.GetAllReferenceCounts(); } } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/runtime/object/object_store.h 
b/cpp/src/ray/runtime/object/object_store.h index c5648083340c5..9c12eb443f33e 100644 --- a/cpp/src/ray/runtime/object/object_store.h +++ b/cpp/src/ray/runtime/object/object_store.h @@ -102,4 +102,4 @@ class ObjectStore { const std::vector &ids, int timeout_ms) = 0; }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.h b/cpp/src/ray/runtime/task/local_mode_task_submitter.h index 210adfb5d2c9f..57a0e4648ad0e 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.h +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.h @@ -65,4 +65,4 @@ class LocalModeTaskSubmitter : public TaskSubmitter { std::unordered_map placement_groups_; }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/runtime/task/native_task_submitter.h b/cpp/src/ray/runtime/task/native_task_submitter.h index eb055474da129..49c34988853b1 100644 --- a/cpp/src/ray/runtime/task/native_task_submitter.h +++ b/cpp/src/ray/runtime/task/native_task_submitter.h @@ -41,4 +41,4 @@ class NativeTaskSubmitter : public TaskSubmitter { ObjectID Submit(InvocationSpec &invocation, const CallOptions &call_options); }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/runtime/task/task_submitter.h b/cpp/src/ray/runtime/task/task_submitter.h index 032217b7797d8..eb36aa0fb1359 100644 --- a/cpp/src/ray/runtime/task/task_submitter.h +++ b/cpp/src/ray/runtime/task/task_submitter.h @@ -52,4 +52,4 @@ class TaskSubmitter { } }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/test/cluster/counter.h b/cpp/src/ray/test/cluster/counter.h index 1e7c50afe5c9d..d34147d0c8e66 100644 --- a/cpp/src/ray/test/cluster/counter.h +++ b/cpp/src/ray/test/cluster/counter.h @@ -164,4 +164,4 @@ class ActorConcurrentCall { private: CountDownLatch contdown_{3}; -}; \ No newline at end of file +}; diff --git a/cpp/src/ray/test/serialization_test.cc b/cpp/src/ray/test/serialization_test.cc index b157ee51ab404..5b533f21697c8 100644 --- a/cpp/src/ray/test/serialization_test.cc +++ b/cpp/src/ray/test/serialization_test.cc @@ -123,4 +123,4 @@ TEST(SerializationTest, BoundaryValueTest) { auto out_arg3 = ray::internal::Serializer::Deserialize>( buffer1.data(), buffer1.size()); EXPECT_EQ(std::vector(), out_arg3); -} \ No newline at end of file +} diff --git a/cpp/src/ray/util/function_helper.cc b/cpp/src/ray/util/function_helper.cc index 035eea69aa3ac..58c3d32e9460d 100644 --- a/cpp/src/ray/util/function_helper.cc +++ b/cpp/src/ray/util/function_helper.cc @@ -183,4 +183,4 @@ const EntryFuntion &FunctionHelper::GetExecutableMemberFunctions( } } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/cpp/src/ray/util/function_helper.h b/cpp/src/ray/util/function_helper.h index ac9ee63e8adb2..6c0ce6011a8bf 100644 --- a/cpp/src/ray/util/function_helper.h +++ b/cpp/src/ray/util/function_helper.h @@ -62,4 +62,4 @@ class FunctionHelper { std::unordered_map remote_member_funcs_; }; } // namespace internal -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/doc/BUILD b/doc/BUILD index ed85a4b60249c..5cf8a9d0e402d 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -448,11 +448,25 @@ doctest( "source/rllib/**/*.rst", "source/rllib/**/*.md", ], + exclude = [ + "source/rllib/getting-started.rst", + ] ), data = 
["//rllib:cartpole-v1_large"], tags = ["team:rllib"], ) +doctest( + name = "doctest[rllib2]", + size = "large", + files = glob( + include = [ + "source/rllib/getting-started.rst", + ], + ), + tags = ["team:rllib"], +) + doctest( files = [ "source/data/batch_inference.rst", diff --git a/doc/azure/azure-init.sh b/doc/azure/azure-init.sh index 0add7a6b2762b..cc748c4ab8f8f 100755 --- a/doc/azure/azure-init.sh +++ b/doc/azure/azure-init.sh @@ -96,4 +96,4 @@ if [ "$type" = "head" ]; then echo "Starting TensorBoard..." systemctl start tensorboard -fi \ No newline at end of file +fi diff --git a/doc/azure/azure-ray-template.json b/doc/azure/azure-ray-template.json index 71ef21aa0257d..c5b5db6475f10 100644 --- a/doc/azure/azure-ray-template.json +++ b/doc/azure/azure-ray-template.json @@ -517,4 +517,4 @@ "condition": "[parameters('PublicWebUI')]" } } -} \ No newline at end of file +} diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 6460dc483e592..239bbfece315a 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -38,4 +38,4 @@ urllib3 < 1.27 # See doc/source/conf.py for examples of how to mock out external dependencies. click==8.1.7 boto3==1.34.69 -requests==2.32.3 \ No newline at end of file +requests==2.32.3 diff --git a/doc/source/_static/css/custom.css b/doc/source/_static/css/custom.css index 2ac45d4e9a566..bf20f3c5dc049 100644 --- a/doc/source/_static/css/custom.css +++ b/doc/source/_static/css/custom.css @@ -398,6 +398,6 @@ table.autosummary tr > td:first-child > p > a > code > span { } /* Hide the RTD version switcher since we are using PyData theme one */ -#rtd-footer-container { - display: none; +readthedocs-flyout { + display: none !important; } diff --git a/doc/source/_static/img/run-quickstart-anyscale.svg b/doc/source/_static/img/run-quickstart-anyscale.svg index 07d8f872d4a10..5e5c95ac687d5 100644 --- a/doc/source/_static/img/run-quickstart-anyscale.svg +++ b/doc/source/_static/img/run-quickstart-anyscale.svg @@ -1,6 +1,22 @@ - - - + + + + + + + + + + + + + + + + + + + diff --git a/doc/source/cluster/kubernetes/user-guides/azure-aks-gpu-cluster.md b/doc/source/cluster/kubernetes/user-guides/azure-aks-gpu-cluster.md new file mode 100644 index 0000000000000..6239330ea5a66 --- /dev/null +++ b/doc/source/cluster/kubernetes/user-guides/azure-aks-gpu-cluster.md @@ -0,0 +1,55 @@ +(kuberay-aks-gpu-cluster-setup)= + +# Start Azure AKS Cluster with GPUs for KubeRay + +This guide walks you through the steps to create an Azure AKS cluster with GPU nodes specifically for KubeRay. +The configuration outlined here can be applied to most KubeRay examples found in the documentation. + +You can find the landing page for AKS [here](https://azure.microsoft.com/en-us/services/kubernetes-service/). +If you have an account set up, you can immediately start experimenting with Kubernetes clusters in the provider's console. Alternatively, check out the [documentation](https://docs.microsoft.com/en-us/azure/aks/) and [quickstart guides](https://docs.microsoft.com/en-us/azure/aks/learn/quick-kubernetes-deploy-portal?tabs=azure-cli). +To successfully deploy Ray on Kubernetes, you will need to use node pools following the guidance [here](https://docs.microsoft.com/en-us/azure/aks/use-multiple-node-pools). 
+
+## Step 1: Create a Resource Group
+
+To create a resource group in a particular region:
+
+```
+az group create -l eastus -n kuberay-rg
+```
+
+## Step 2: Create an AKS Cluster
+
+To create an AKS cluster with a system node pool:
+```
+az aks create \
+    -g kuberay-rg \
+    -n kuberay-gpu-cluster \
+    --nodepool-name system \
+    --node-vm-size Standard_D8s_v3 \
+    --node-count 3
+```
+
+## Step 3: Add a GPU node pool
+
+To add a GPU node pool with autoscaling:
+```
+az aks nodepool add \
+    -g kuberay-rg \
+    --cluster-name kuberay-gpu-cluster \
+    --nodepool-name gpupool \
+    --node-vm-size Standard_NC6s_v3 \
+    --node-taints nvidia.com/gpu=present:NoSchedule \
+    --min-count 0 \
+    --max-count 3 \
+    --enable-cluster-autoscaler
+```
+Alternatively, to use the NVIDIA GPU Operator, follow the instructions [here](https://learn.microsoft.com/en-us/azure/aks/gpu-cluster?tabs=add-ubuntu-gpu-node-pool#skip-gpu-driver-installation-preview).
+
+## Step 4: Get kubeconfig
+
+To get kubeconfig:
+```
+az aks get-credentials --resource-group kuberay-rg \
+    --name kuberay-gpu-cluster \
+    --overwrite-existing
+```
\ No newline at end of file
diff --git a/doc/source/cluster/kubernetes/user-guides/k8s-cluster-setup.md b/doc/source/cluster/kubernetes/user-guides/k8s-cluster-setup.md
index 87acff1fc52e1..66bc514c08610 100644
--- a/doc/source/cluster/kubernetes/user-guides/k8s-cluster-setup.md
+++ b/doc/source/cluster/kubernetes/user-guides/k8s-cluster-setup.md
@@ -8,6 +8,7 @@
 aws-eks-gpu-cluster
 gcp-gke-gpu-cluster
 gcp-gke-tpu-cluster
+azure-aks-gpu-cluster
 ```
 
 Most KubeRay documentation examples only require a local Kubernetes cluster such as [Kind](https://kind.sigs.k8s.io/).
@@ -26,10 +27,5 @@ We collect a few helpful links for users who are getting started with a managed
 - {ref}`kuberay-eks-gpu-cluster-setup`
 
 (aks-setup)=
-# Setting up an AKS (Microsoft Azure)
-You can find the landing page for AKS [here](https://azure.microsoft.com/en-us/services/kubernetes-service/).
-If you have an account set up, you can immediately start experimenting with Kubernetes clusters in the provider's console.
-Alternatively, check out the [documentation](https://docs.microsoft.com/en-us/azure/aks/) and
-[quickstart guides](https://docs.microsoft.com/en-us/azure/aks/learn/quick-kubernetes-deploy-portal?tabs=azure-cli). To successfully deploy Ray on Kubernetes,
-you will need to configure pools of Kubernetes nodes;
-find guidance [here](https://docs.microsoft.com/en-us/azure/aks/use-multiple-node-pools).
+# Set up an AKS cluster (Microsoft Azure)
+- {ref}`kuberay-aks-gpu-cluster-setup`
diff --git a/doc/source/cluster/kubernetes/user-guides/kubectl-plugin.md b/doc/source/cluster/kubernetes/user-guides/kubectl-plugin.md
index 80980a40d4b4c..96191196e8c6b 100644
--- a/doc/source/cluster/kubernetes/user-guides/kubectl-plugin.md
+++ b/doc/source/cluster/kubernetes/user-guides/kubectl-plugin.md
@@ -1,14 +1,14 @@
 (kubectl-plugin)=
 
-# Use kubectl Plugin (alpha)
+# Use kubectl plugin (beta)
 
-Starting from KubeRay v1.2.2, you can use the `kubectl ray` plugin to simplify common workflows when deploying Ray on Kubernetes. If you aren't familiar with Kubernetes, this plugin simplifies running Ray on Kubernetes.
+Starting from KubeRay v1.3, you can use the `kubectl ray` plugin to simplify common workflows when deploying Ray on Kubernetes. If you aren't familiar with Kubernetes, the plugin makes it much easier to run Ray on Kubernetes.
 
 ## Installation
 
 See [KubeRay kubectl Plugin](https://github.com/ray-project/kuberay/tree/master/kubectl-plugin) to install the plugin.
-Install the Kuberay kubectl plugin using one of the following methods:
+Install the KubeRay kubectl plugin using one of the following methods:
 
 - Install using Krew kubectl plugin manager (recommended)
 - Download from GitHub releases
@@ -40,11 +40,20 @@ After installing the plugin, you can use `kubectl ray --help` to see the availab
 
 ## Example
 
-This example assumes you have a Ray cluster running on Kubernetes. See {ref}`RayCluster Quickstart ` if you don't have a Ray cluster running on Kubernetes.
+Most of this example assumes you have a Ray cluster running on Kubernetes. See {ref}`RayCluster Quickstart ` if you don't already have one running.
+
+### Get all Ray clusters
+
+```text
+$ kubectl ray get cluster
+NAME                             NAMESPACE   DESIRED WORKERS   AVAILABLE WORKERS   CPUS   GPUS   TPUS   MEMORY   AGE
+rayjob-sample-raycluster-zwbc6   default     1                 1                   4      0      0      8Gi      71s
+sample-cluster                   default     1                 1                   4      0      0      8Gi      21d
+```
 
 ### Forward local ports to Ray cluster
 
-```shell
+```text
 $ kubectl ray session ray-cluster-kuberay
 
 Forwarding ports to service ray-cluster-kuberay-head-svc
@@ -59,10 +68,251 @@ Forwarding from [::1]:10001 -> 10001
 
 ### Get Ray cluster logs
 
-```shell
-$ kubectl ray logs rayjob-sample-raycluster-kfhl6
+```text
+$ kubectl ray logs sample-raycluster-kfhl6
 No output directory specified, creating dir under current directory using cluster name.
 Command set to retrieve both head and worker node logs.
-Downloading log for Ray Node rayjob-sample-raycluster-kfhl6-head-87xpb
-Downloading log for Ray Node rayjob-sample-raycluster-kfhl6-small-group-worker-54qfm
+Downloading log for Ray Node sample-raycluster-kfhl6-head-87xpb
+Downloading log for Ray Node sample-raycluster-kfhl6-small-group-worker-54qfm
+```
+
+#### Get Ray job logs
+
+You can also access the Ray logs of RayJobs and RayServices.
+
+```text
+$ kubectl ray logs rayjob/rayjob-interactivemode
+No output directory specified, creating dir under current directory using resource name.
+Command set to retrieve both head and worker node logs.
+Downloading log for Ray Node rayjob-interactivemode-raycluster-qbkr8-head-dwm84
+Downloading log for Ray Node rayjob-interactivemode-raycluster-qbkr8-small-grou-worker-hr2jp
+```
+
+### Create a Ray cluster
+
+The `kubectl ray create cluster` command allows you to create a valid RayCluster without an existing YAML file. The default values are as follows:
+
+| Parameter | Default |
+| -------- | ---------- |
+| ray version | 2.39.0 |
+| ray image | rayproject/ray:\ |
+| head cpu | 2 |
+| head memory | 4Gi |
+| worker replicas | 1 |
+| worker cpu | 2 |
+| worker memory | 4Gi |
+| worker gpu | 0 |
+
+Currently, only one worker group is created.
+
+```text
+$ kubectl ray create cluster raycluster-sample
+Created Ray Cluster: raycluster-sample
+$ kubectl ray get cluster
+NAME                NAMESPACE   DESIRED WORKERS   AVAILABLE WORKERS   CPUS   GPUS   TPUS   MEMORY   AGE
+raycluster-sample   default     1                 1                   4      0      0      8Gi      25s
+```
+
+### Submit a Ray job
+
+This is a wrapper around the `ray job submit` command. The plugin can automatically forward the ports to the Ray cluster and submit the job. This command can also provision an ephemeral cluster to execute the job if you don't provide a RayJob.
+
+Assume that the current directory contains a file named `sample_code.py`.
+
+```python
+import ray
+ray.init(address="auto")
+
+@ray.remote
+def f(x):
+    return x * x
+
+futures = [f.remote(i) for i in range(4)]
+print(ray.get(futures)) # [0, 1, 4, 9]
+```
+
+#### Submit a Ray job without a YAML file
+
+You can submit a RayJob without specifying a YAML file. 
The command generates a RayJob based on the following: + +| Parameter | Default | +| -------- | ---------- | +| ray version | 2.39.0 | +| ray image | rayproject/ray:\ | +| head cpu | 2 | +| head memory | 4Gi | +| worker replicas | 1 | +| worker cpu | 2 | +| worker memory | 4Gi | +| worker gpu | 0 | + +```text +$ kubectl ray job submit --name rayjob-sample --working-dir . -- python sample_code.py +Submitted RayJob rayjob-sample. +Waiting for RayCluster +Checking Cluster Status for cluster rayjob-sample-raycluster-2qgsj... +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Waiting for portforwarding...Port Forwarding service rayjob-sample-raycluster-2qgsj-head-svc +Forwarding from 127.0.0.1:8265 -> 8265 +Forwarding from [::1]:8265 -> 8265 +Handling connection for 8265 +Portforwarding started on http://localhost:8265 +Ray command: [ray job submit --address http://localhost:8265 --working-dir . -- python sample_code.py] +Running ray submit job command... +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +2025-01-06 11:53:32,710 INFO dashboard_sdk.py:338 -- Uploading package gcs://_ray_pkg_bd1a1af41a246cb2.zip. +2025-01-06 11:53:32,714 INFO packaging.py:601 -- Creating a file package for local module '.'. +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +2025-01-06 11:53:32,552 INFO cli.py:39 -- Job submission server address: http://localhost:8265 +2025-01-06 11:53:33,629 SUCC cli.py:63 -- ------------------------------------------------------- +2025-01-06 11:53:33,630 SUCC cli.py:64 -- Job 'raysubmit_9NfCvwcmcyMNFCvX' submitted successfully +2025-01-06 11:53:33,630 SUCC cli.py:65 -- ------------------------------------------------------- +2025-01-06 11:53:33,630 INFO cli.py:289 -- Next steps +2025-01-06 11:53:33,630 INFO cli.py:290 -- Query the logs of the job: +2025-01-06 11:53:33,631 INFO cli.py:292 -- ray job logs raysubmit_9NfCvwcmcyMNFCvX +2025-01-06 11:53:33,631 INFO cli.py:294 -- Query the status of the job: +2025-01-06 11:53:33,631 INFO cli.py:296 -- ray job status raysubmit_9NfCvwcmcyMNFCvX +2025-01-06 11:53:33,631 INFO cli.py:298 -- Request the job to be stopped: +2025-01-06 11:53:33,631 INFO cli.py:300 -- ray job stop raysubmit_9NfCvwcmcyMNFCvX +2025-01-06 11:53:33,786 INFO cli.py:307 -- Tailing logs until the job exits (disable with --no-wait): +2025-01-06 11:53:33,410 INFO job_manager.py:530 -- Runtime env is setting up. +2025-01-06 11:53:34,806 INFO worker.py:1494 -- Using address 10.12.0.9:6379 set in the environment variable RAY_ADDRESS +2025-01-06 11:53:34,806 INFO worker.py:1634 -- Connecting to existing Ray cluster at address: 10.12.0.9:6379... +2025-01-06 11:53:34,814 INFO worker.py:1810 -- Connected to Ray cluster. 
View the dashboard at 10.12.0.9:8265 +[0, 1, 4, 9] +2025-01-06 11:53:38,368 SUCC cli.py:63 -- ------------------------------------------ +2025-01-06 11:53:38,368 SUCC cli.py:64 -- Job 'raysubmit_9NfCvwcmcyMNFCvX' succeeded +2025-01-06 11:53:38,368 SUCC cli.py:65 -- ------------------------------------------ +``` + +#### Submit a Ray job with a RayJob YAML + +Users can also designate a specific RayJob YAML to submit a Ray job. + +Add the following YAML file `ray-job.interactivemode.yaml`: + +```yaml +apiVersion: ray.io/v1 +kind: RayJob +metadata: + name: rayjob-interactivemode +spec: + submissionMode: InteractiveMode + rayClusterSpec: + rayVersion: '2.39.0' + headGroupSpec: + rayStartParams: + dashboard-host: '0.0.0.0' + template: + spec: + containers: + - name: ray-head + image: rayproject/ray:2.39.0 + ports: + - containerPort: 6379 + name: gcs-server + - containerPort: 8265 + name: dashboard + - containerPort: 10001 + name: client + resources: + limits: + cpu: "1" + requests: + cpu: "200m" + workerGroupSpecs: + - replicas: 1 + minReplicas: 1 + maxReplicas: 3 + groupName: small-group + rayStartParams: {} + template: + spec: + containers: + - name: ray-worker + image: rayproject/ray:2.39.0 + lifecycle: + preStop: + exec: + command: [ "/bin/sh","-c","ray stop" ] + resources: + limits: + cpu: "1" + requests: + cpu: "200m" +``` + +Note that in the RayJob spec, `submissionMode` is set to `InteractiveMode`. + +```text +$ kubectl ray job submit -f ray-job.interactivemode.yaml --working-dir . -- python sample_code.py +Submitted RayJob rayjob-interactivemode. +Waiting for RayCluster +Checking Cluster Status for cluster rayjob-interactivemode-raycluster-qbkr8... +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Cluster is not ready: Cannot determine cluster state +Waiting for portforwarding...Port Forwarding service rayjob-interactivemode-raycluster-qbkr8-head-svc +Forwarding from 127.0.0.1:8265 -> 8265 +Forwarding from [::1]:8265 -> 8265 +Handling connection for 8265 +Portforwarding started on http://localhost:8265 +Ray command: [ray job submit --address http://localhost:8265 --working-dir . -- python sample_code.py] +Running ray submit job command... +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +2025-01-06 12:44:41,234 INFO dashboard_sdk.py:338 -- Uploading package gcs://_ray_pkg_3ddba7608d86c45a.zip. +2025-01-06 12:44:41,238 INFO packaging.py:601 -- Creating a file package for local module '.'. 
+Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +Handling connection for 8265 +2025-01-06 12:44:41,077 INFO cli.py:39 -- Job submission server address: http://localhost:8265 +2025-01-06 12:44:42,312 SUCC cli.py:63 -- ------------------------------------------------------- +2025-01-06 12:44:42,313 SUCC cli.py:64 -- Job 'raysubmit_fuBdjGnecFggejhR' submitted successfully +2025-01-06 12:44:42,313 SUCC cli.py:65 -- ------------------------------------------------------- +2025-01-06 12:44:42,313 INFO cli.py:289 -- Next steps +2025-01-06 12:44:42,313 INFO cli.py:290 -- Query the logs of the job: +2025-01-06 12:44:42,313 INFO cli.py:292 -- ray job logs raysubmit_fuBdjGnecFggejhR +2025-01-06 12:44:42,313 INFO cli.py:294 -- Query the status of the job: +2025-01-06 12:44:42,313 INFO cli.py:296 -- ray job status raysubmit_fuBdjGnecFggejhR +2025-01-06 12:44:42,313 INFO cli.py:298 -- Request the job to be stopped: +2025-01-06 12:44:42,313 INFO cli.py:300 -- ray job stop raysubmit_fuBdjGnecFggejhR +2025-01-06 12:44:42,472 INFO cli.py:307 -- Tailing logs until the job exits (disable with --no-wait): +2025-01-06 12:44:41,931 INFO job_manager.py:530 -- Runtime env is setting up. +2025-01-06 12:44:43,542 INFO worker.py:1494 -- Using address 10.12.0.10:6379 set in the environment variable RAY_ADDRESS +2025-01-06 12:44:43,542 INFO worker.py:1634 -- Connecting to existing Ray cluster at address: 10.12.0.10:6379... +2025-01-06 12:44:43,551 INFO worker.py:1810 -- Connected to Ray cluster. View the dashboard at 10.12.0.10:8265 +[0, 1, 4, 9] +2025-01-06 12:44:47,830 SUCC cli.py:63 -- ------------------------------------------ +2025-01-06 12:44:47,830 SUCC cli.py:64 -- Job 'raysubmit_fuBdjGnecFggejhR' succeeded +2025-01-06 12:44:47,830 SUCC cli.py:65 -- ------------------------------------------ +``` + +### Delete a Ray cluster + +```text +$ kubectl ray delete raycluster-sample +Are you sure you want to delete raycluster raycluster-sample? (y/yes/n/no) y +Delete raycluster raycluster-sample ``` diff --git a/doc/source/conf.py b/doc/source/conf.py index cd1ce65093a95..c11a48b7a3c39 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -304,22 +304,19 @@ def render_svg_logo(path): # documentation. html_theme_options = { "use_edit_page_button": True, - "announcement": """Influence the future of Ray with our Ray Community Pulse survey. 
Complete it by Monday, January 27th, 2025 to get exclusive swag for eligible participants.""",
+    "announcement": False,
     "logo": {
         "svg": render_svg_logo("_static/img/ray_logo.svg"),
     },
     "navbar_start": ["navbar-ray-logo"],
     "navbar_end": [
+        "theme-switcher",
         "version-switcher",
         "navbar-icon-links",
         "navbar-anyscale",
     ],
     "navbar_center": ["navbar-links"],
     "navbar_align": "left",
-    "navbar_persistent": [
-        "search-button-field",
-        "theme-switcher",
-    ],
     "secondary_sidebar_items": [
         "page-toc",
         "edit-on-github",
@@ -332,7 +329,9 @@ def render_svg_logo(path):
     "pygment_dark_style": "stata-dark",
     "switcher": {
         "json_url": "https://docs.ray.io/en/master/_static/versions.json",
-        "version_match": os.getenv("READTHEDOCS_VERSION", "master"),
+        "version_match": (
+            lambda v: v if v in ["master", "latest"] else f"releases/{v}"
+        )(os.getenv("READTHEDOCS_VERSION", "master")),
     },
 }
 
@@ -345,9 +344,11 @@ def render_svg_logo(path):
 
 html_sidebars = {
     "**": [
-        "main-sidebar-readthedocs"
-        if os.getenv("READTHEDOCS") == "True"
-        else "main-sidebar"
+        (
+            "main-sidebar-readthedocs"
+            if os.getenv("READTHEDOCS") == "True"
+            else "main-sidebar"
+        )
     ],
     "ray-overview/examples": [],
 }
diff --git a/doc/source/data/api/api.rst b/doc/source/data/api/api.rst
index b82d011b1d125..17be24e64f648 100644
--- a/doc/source/data/api/api.rst
+++ b/doc/source/data/api/api.rst
@@ -13,4 +13,5 @@ Ray Data API
    grouped_data.rst
    data_context.rst
    preprocessor.rst
+   llm.rst
    from_other_data_libs.rst
diff --git a/doc/source/data/api/input_output.rst b/doc/source/data/api/input_output.rst
index 338bd45c5936e..2bddf5ba8e471 100644
--- a/doc/source/data/api/input_output.rst
+++ b/doc/source/data/api/input_output.rst
@@ -64,6 +64,15 @@ Text
 
    read_text
 
+Audio
+-----
+
+.. autosummary::
+   :nosignatures:
+   :toctree: doc/
+
+   read_audio
+
 Avro
 ----
 
@@ -289,6 +298,15 @@ TensorFlow
 
    from_tf
 
+Video
+-----
+
+.. autosummary::
+   :nosignatures:
+   :toctree: doc/
+
+   read_videos
+
 WebDataset
 ----------
 
diff --git a/doc/source/data/api/llm.rst b/doc/source/data/api/llm.rst
new file mode 100644
index 0000000000000..7056becba8045
--- /dev/null
+++ b/doc/source/data/api/llm.rst
@@ -0,0 +1,34 @@
+.. _llm-ref:
+
+Large Language Model (LLM) API
+==============================
+
+.. currentmodule:: ray.data.llm
+
+LLM Processor Builder
+---------------------
+
+.. autosummary::
+   :nosignatures:
+   :toctree: doc/
+
+   ~build_llm_processor
+
+Processor
+---------
+
+.. autosummary::
+   :nosignatures:
+   :toctree: doc/
+
+   ~Processor
+
+Processor Configs
+-----------------
+
+.. autosummary::
+   :nosignatures:
+   :toctree: doc/
+
+   ~ProcessorConfig
+   ~HttpRequestProcessorConfig
diff --git a/doc/source/data/saving-data.rst b/doc/source/data/saving-data.rst
index 24130f6796e04..60c81d7ec2e1a 100644
--- a/doc/source/data/saving-data.rst
+++ b/doc/source/data/saving-data.rst
@@ -143,12 +143,15 @@ Changing the number of output files
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 When you call a write method, Ray Data writes your data to several files. To control the
-number of output files, configure ``num_rows_per_file``.
+number of output files, configure ``min_rows_per_write``.
 
 .. note::
 
-    ``num_rows_per_file`` is a hint, not a strict limit. Ray Data might write more or
-    fewer rows to each file.
+    ``min_rows_per_write`` is a hint, not a strict limit. Ray Data might write more or
+    fewer rows to each file. Under the hood, if a block contains more rows than the
+    specified value, Ray Data writes all of that block's rows to a single file.
+
 .. 
testcode:: @@ -156,7 +159,7 @@ number of output files, configure ``num_rows_per_file``. import ray ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv") - ds.write_csv("/tmp/few_files/", num_rows_per_file=75) + ds.write_csv("/tmp/few_files/", min_rows_per_write=75) print(os.listdir("/tmp/few_files/")) diff --git a/doc/source/ray-contribute/development.rst b/doc/source/ray-contribute/development.rst index e0499a9c660b6..0d5671ab45620 100644 --- a/doc/source/ray-contribute/development.rst +++ b/doc/source/ray-contribute/development.rst @@ -298,8 +298,9 @@ You can tweak the build with the following environment variables (when running ` ``cpp``) build will not provide some ``cpp`` interfaces - ``SKIP_BAZEL_BUILD``: If set and equal to ``1``, no Bazel build steps will be executed -- ``SKIP_THIRDPARTY_INSTALL``: If set will skip installation of third-party - python packages +- ``SKIP_THIRDPARTY_INSTALL_CONDA_FORGE``: If set, setup will skip installation of + third-party packages required for build. This is active on conda-forge where + pip is not used to create a build environment. - ``RAY_DEBUG_BUILD``: Can be set to ``debug``, ``asan``, or ``tsan``. Any other value will be ignored - ``BAZEL_ARGS``: If set, pass a space-separated set of arguments to Bazel. This can be useful diff --git a/doc/source/ray-overview/use-cases.rst b/doc/source/ray-overview/use-cases.rst index 899d8c93fd9c2..feb22203eacd0 100644 --- a/doc/source/ray-overview/use-cases.rst +++ b/doc/source/ray-overview/use-cases.rst @@ -144,7 +144,7 @@ Learn more about reinforcement learning with the following resources. - `[Course] Applied Reinforcement Learning with RLlib `_ - `[Blog] Intro to RLlib: Example Environments `_ -- :doc:`[Guide] Getting Started with RLlib ` +- :doc:`[Guide] Getting Started with RLlib ` - `[Talk] Deep reinforcement learning at Riot Games `_ - :doc:`[Gallery] RLlib Examples Gallery ` - `[Gallery] More RL Use Cases on the Blog `_ diff --git a/doc/source/rllib/algorithm-config.rst b/doc/source/rllib/algorithm-config.rst index d6b53763a54f6..800ff8de3d0e5 100644 --- a/doc/source/rllib/algorithm-config.rst +++ b/doc/source/rllib/algorithm-config.rst @@ -127,6 +127,7 @@ instance into the constructor of the :py:class:`~ray.tune.tuner.Tuner`: results = tuner.fit() +.. _rllib-algo-configuration-generic-settings: Generic config settings ----------------------- diff --git a/doc/source/rllib/checkpoints.rst b/doc/source/rllib/checkpoints.rst index ed98b263ad405..9f52d2f20f31d 100644 --- a/doc/source/rllib/checkpoints.rst +++ b/doc/source/rllib/checkpoints.rst @@ -35,7 +35,7 @@ For example, you can deploy a previously trained :py:class:`~ray.rllib.core.rl_m any of the other RLlib components, into production. .. 
figure:: images/checkpointing/from_checkpoint.svg - :width: 500 + :width: 750 :align: left **Creating a new instance directly from a checkpoint**: Use the ``classmethod`` diff --git a/doc/source/rllib/doc_code/getting_started.py b/doc/source/rllib/doc_code/getting_started.py deleted file mode 100644 index 951b3acee8dad..0000000000000 --- a/doc/source/rllib/doc_code/getting_started.py +++ /dev/null @@ -1,151 +0,0 @@ -# flake8: noqa - -# __rllib-first-config-begin__ -from pprint import pprint - -from ray.rllib.algorithms.ppo import PPOConfig - -config = ( - PPOConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .environment("CartPole-v1") - .env_runners(num_env_runners=1) -) - -algo = config.build() - -for i in range(10): - result = algo.train() - result.pop("config") - pprint(result) - - if i % 5 == 0: - checkpoint_dir = algo.save_to_path() - print(f"Checkpoint saved in directory {checkpoint_dir}") -# __rllib-first-config-end__ - -algo.stop() - -if False: - # __rllib-tune-config-begin__ - from ray import train, tune - - config = ( - PPOConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .environment("CartPole-v1") - .training( - lr=tune.grid_search([0.01, 0.001, 0.0001]), - ) - ) - - tuner = tune.Tuner( - "PPO", - param_space=config, - run_config=train.RunConfig( - stop={"env_runners/episode_return_mean": 150.0}, - ), - ) - - tuner.fit() - # __rllib-tune-config-end__ - - -# __rllib-tuner-begin__ -from ray import train, tune - -# Tuner.fit() allows setting a custom log directory (other than ~/ray-results). -tuner = tune.Tuner( - "PPO", - param_space=config, - run_config=train.RunConfig( - stop={"num_env_steps_sampled_lifetime": 20000}, - checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True), - ), -) - -results = tuner.fit() - -# Get the best result based on a particular metric. -best_result = results.get_best_result( - metric="env_runners/episode_return_mean", mode="max" -) - -# Get the best checkpoint corresponding to the best result. -best_checkpoint = best_result.checkpoint -# __rllib-tuner-end__ - - -# __rllib-compute-action-begin__ -import pathlib -import gymnasium as gym -import numpy as np -import torch -from ray.rllib.core.rl_module import RLModule - -env = gym.make("CartPole-v1") - -# Create only the neural network (RLModule) from our checkpoint. -rl_module = RLModule.from_checkpoint( - pathlib.Path(best_checkpoint.path) / "learner_group" / "learner" / "rl_module" -)["default_policy"] - -episode_return = 0 -terminated = truncated = False - -obs, info = env.reset() - -while not terminated and not truncated: - # Compute the next action from a batch (B=1) of observations. - torch_obs_batch = torch.from_numpy(np.array([obs])) - action_logits = rl_module.forward_inference({"obs": torch_obs_batch})[ - "action_dist_inputs" - ] - # The default RLModule used here produces action logits (from which - # we'll have to sample an action or use the max-likelihood one). 
- action = torch.argmax(action_logits[0]).numpy() - obs, reward, terminated, truncated, info = env.step(action) - episode_return += reward - -print(f"Reached episode return of {episode_return}.") -# __rllib-compute-action-end__ - - -del rl_module - - -# __rllib-get-state-begin__ -from ray.rllib.algorithms.ppo import PPOConfig - -algo = ( - PPOConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .environment("CartPole-v1") - .env_runners(num_env_runners=2) -).build() - -# Get weights of the algo's RLModule. -algo.get_module().get_state() - -# Same as above -algo.env_runner.module.get_state() - -# Get list of weights of each EnvRunner, including remote replicas. -algo.env_runner_group.foreach_worker(lambda env_runner: env_runner.module.get_state()) - -# Same as above, but with index. -algo.env_runner_group.foreach_worker_with_id( - lambda _id, env_runner: env_runner.module.get_state() -) -# __rllib-get-state-end__ - -algo.stop() diff --git a/doc/source/rllib/doc_code/rllib_in_60s.py b/doc/source/rllib/doc_code/rllib_in_60s.py deleted file mode 100644 index 6d214504f15d8..0000000000000 --- a/doc/source/rllib/doc_code/rllib_in_60s.py +++ /dev/null @@ -1,25 +0,0 @@ -# flake8: noqa - -# __rllib-in-60s-begin__ -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.connectors.env_to_module import FlattenObservations - -# 1. Configure the algorithm, -config = ( - PPOConfig() - .environment("Taxi-v3") - .env_runners( - num_env_runners=2, - # Observations are discrete (ints) -> We need to flatten (one-hot) them. - env_to_module_connector=lambda env: FlattenObservations(), - ) - .evaluation(evaluation_num_env_runners=1) -) -# 2. build the algorithm .. -algo = config.build() -# 3. .. train it .. -for _ in range(5): - print(algo.train()) -# 4. .. and evaluate it. -algo.evaluate() -# __rllib-in-60s-end__ diff --git a/doc/source/rllib/doc_code/training.py b/doc/source/rllib/doc_code/training.py deleted file mode 100644 index 75bf8a48f18c1..0000000000000 --- a/doc/source/rllib/doc_code/training.py +++ /dev/null @@ -1,170 +0,0 @@ -# flake8: noqa - -# __preprocessing_observations_start__ -try: - import gymnasium as gym - - env = gym.make("ale_py:ALE/Pong-v5") - obs, infos = env.reset() -except Exception: - import gym - - env = gym.make("PongNoFrameskip-v4") - obs = env.reset() - -# RLlib uses preprocessors to implement transforms such as one-hot encoding -# and flattening of tuple and dict observations. -from ray.rllib.models.preprocessors import get_preprocessor - -prep = get_preprocessor(env.observation_space)(env.observation_space) -# - -# Observations should be preprocessed prior to feeding into a model -obs.shape -# (210, 160, 3) -prep.transform(obs).shape -# (84, 84, 3) -# __preprocessing_observations_end__ - -# __query_action_dist_start__ -# Get a reference to the policy -import numpy as np -import torch - -from ray.rllib.algorithms.dqn import DQNConfig - -algo = ( - DQNConfig() - .api_stack( - enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False - ) - .framework("torch") - .environment("CartPole-v1") - .env_runners(num_env_runners=0) - .training( - replay_buffer_config={ - "type": "MultiAgentPrioritizedReplayBuffer", - } - ) -).build() -# - -policy = algo.get_policy() -# - -# Run a forward pass to get model output logits. Note that complex observations -# must be preprocessed as in the above code block. 
-logits, _ = policy.model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))}) -# (, []) - -# Compute action distribution given logits -policy.dist_class -# -dist = policy.dist_class(logits, policy.model) -# - -# Query the distribution for samples, sample logps -dist.sample() -# -dist.logp(torch.tensor([1])) -# - -# Get the estimated values for the most recent forward pass -policy.model.value_function() -# - -print(policy.model) -""" -Model: "model" -_____________________________________________________________________ -Layer (type) Output Shape Param # Connected to -===================================================================== -observations (InputLayer) [(None, 4)] 0 -_____________________________________________________________________ -fc_1 (Dense) (None, 256) 1280 observations[0][0] -_____________________________________________________________________ -fc_value_1 (Dense) (None, 256) 1280 observations[0][0] -_____________________________________________________________________ -fc_2 (Dense) (None, 256) 65792 fc_1[0][0] -_____________________________________________________________________ -fc_value_2 (Dense) (None, 256) 65792 fc_value_1[0][0] -_____________________________________________________________________ -fc_out (Dense) (None, 2) 514 fc_2[0][0] -_____________________________________________________________________ -value_out (Dense) (None, 1) 257 fc_value_2[0][0] -===================================================================== -Total params: 134,915 -Trainable params: 134,915 -Non-trainable params: 0 -_____________________________________________________________________ -""" -# __query_action_dist_end__ - - -# __get_q_values_dqn_start__ -# Get a reference to the model through the policy -import numpy as np -import torch - -from ray.rllib.algorithms.dqn import DQNConfig - -algo = ( - DQNConfig() - .api_stack( - enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False - ) - .framework("torch") - .environment("CartPole-v1") - .training( - replay_buffer_config={ - "type": "MultiAgentPrioritizedReplayBuffer", - } - ) -).build() -model = algo.get_policy().model -# - -# List of all model variables -list(model.parameters()) - -# Run a forward pass to get base model output. Note that complex observations -# must be preprocessed. An example of preprocessing is -# examples/offline_rl/saving_experiences.py -model_out = model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))}) -# (` +and your :ref:`neural network model `. + +Installation +~~~~~~~~~~~~ + +First, install RLlib, `PyTorch `__, and `Farama Gymnasium `__ as shown below: + +.. code-block:: bash + + pip install "ray[rllib]" torch "gymnasium[atari,accept-rom-license,mujoco]" + + +.. _rllib-python-api: + +Python API +~~~~~~~~~~ + +RLlib's Python API provides all the flexibility required for applying the library to any +type of RL problem. + +You manage RLlib experiments through an instance of the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` +class. An :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` typically holds a neural +network for computing actions, called ``policy``, the :ref:`RL environment ` +that you want to optimize against, a loss function, an optimizer, and some code describing the +algorithm's execution logic, like determining when to collect samples, when to update your model, etc.. + +In :ref:`multi-agent training `, +:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` manages the querying and optimization of multiple policies at once. 
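+
+For example, the following sketch configures two policies through the config's
+:py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.multi_agent` method.
+The environment name and agent ID below are illustrative placeholders, assuming a
+registered multi-agent env:
+
+.. code-block:: python
+
+    from ray.rllib.algorithms.ppo import PPOConfig
+
+    config = (
+        PPOConfig()
+        # Assumes "my_multi_agent_env" has been registered beforehand.
+        .environment("my_multi_agent_env")
+        .multi_agent(
+            # Train two policies side by side.
+            policies={"policy_1", "policy_2"},
+            # Map each agent ID in the env to one of the two policies.
+            policy_mapping_fn=lambda agent_id, episode, **kwargs: (
+                "policy_1" if agent_id == "agent_0" else "policy_2"
+            ),
+        )
+    )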
+
+Through the algorithm's interface, you can train the policy, compute actions, or store your
+algorithm's state via checkpointing.
+
+
+Configure and build the algorithm
++++++++++++++++++++++++++++++++++
+
+You first create an :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` instance
+and change some default settings through the config object's various methods.
+
+For example, you can set the :ref:`RL environment `
+you want to use by calling the config's :py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.environment`
+method:
+
+.. testcode::
+
+    from ray.rllib.algorithms.ppo import PPOConfig
+
+    # Create a config instance for the PPO algorithm.
+    config = (
+        PPOConfig()
+        .environment("Pendulum-v1")
+    )
+
+
+To scale your setup and define how many :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors you want to leverage,
+call the :py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.env_runners` method.
+``EnvRunners`` collect samples for training updates from your :ref:`environment `.
+
+.. testcode::
+
+    config.env_runners(num_env_runners=2)
+
+For training-related settings or any algorithm-specific settings, use the
+:py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.training` method:
+
+.. testcode::
+
+    config.training(
+        lr=0.0002,
+        train_batch_size_per_learner=2000,
+        num_epochs=10,
+    )
+
+Finally, you build the actual :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` instance
+by calling your config's :py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.build_algo`
+method.
+
+.. testcode::
+
+    # Build the Algorithm (PPO).
+    ppo = config.build_algo()
+
+
+.. note::
+
+    See here to learn about all the :ref:`methods you can use to configure your Algorithm `.
+
+
+Run the algorithm
++++++++++++++++++
+
+After you've built your :ref:`PPO ` from its configuration, you can ``train`` it for a number of
+iterations by calling the :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.train` method,
+which returns a result dictionary that you can pretty-print for debugging purposes:
+
+.. testcode::
+
+    from pprint import pprint
+
+    for _ in range(4):
+        pprint(ppo.train())
+
+
+Checkpoint the algorithm
+++++++++++++++++++++++++
+
+To save the current state of your :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`,
+create a checkpoint by calling its :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.save_to_path` method,
+which returns the directory of the saved checkpoint.
+
+If you don't pass any arguments to this call, the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` decides where to save
+the checkpoint; alternatively, you can provide a checkpoint directory yourself:
+
+.. testcode::
+
+    checkpoint_path = ppo.save_to_path()
+
+    # OR:
+    # ppo.save_to_path([a checkpoint location of your choice])
+
+
+Evaluate the algorithm
+++++++++++++++++++++++
+
+RLlib supports setting up a separate :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup`
+for the sole purpose of evaluating your model from time to time on the :ref:`RL environment `.
+
+Use your config's :py:meth:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig.evaluation` method
+to set up the details. By default, RLlib doesn't perform evaluation during training and only reports the
+results of collecting training samples with its "regular" :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup`.
+
+
+.. testcode::
+    :hide:
+
+    ppo.stop()
+
+
+.. testcode::
+
+    config.evaluation(
+        # Run one evaluation round every iteration.
+        evaluation_interval=1,
+
+        # Create 2 eval EnvRunners in the extra EnvRunnerGroup.
+        evaluation_num_env_runners=2,
+
+        # Run evaluation for exactly 10 episodes. Note that because you have
+        # 2 EnvRunners, each one runs through 5 episodes.
+        evaluation_duration_unit="episodes",
+        evaluation_duration=10,
+    )
+
+    # Rebuild the PPO, but with the extra evaluation EnvRunnerGroup.
+    ppo_with_evaluation = config.build_algo()
+
+    for _ in range(3):
+        pprint(ppo_with_evaluation.train())
+
+.. testcode::
+    :hide:
+
+    ppo_with_evaluation.stop()
+
+
+.. _rllib-with-ray-tune:
+
+RLlib with Ray Tune
++++++++++++++++++++
+
+All online RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` classes are compatible with
+the :ref:`Ray Tune API `.
+
+.. note::
+
+    The offline RL algorithms, like :ref:`BC `, :ref:`CQL `, and :ref:`MARWIL `,
+    require more work on :ref:`Tune ` and :ref:`Ray Data `
+    to add Ray Tune support.
+
+This integration lets you use your configured :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` in
+:ref:`Ray Tune ` experiments.
+
+For example, the following code performs a hyperparameter sweep of your :ref:`PPO `, creating three ``Trials``,
+one for each of the configured learning rates:
+
+.. testcode::
+
+    from ray import train, tune
+    from ray.rllib.algorithms.ppo import PPOConfig
+
+    config = (
+        PPOConfig()
+        .environment("Pendulum-v1")
+        # Specify a simple tune hyperparameter sweep.
+        .training(
+            lr=tune.grid_search([0.001, 0.0005, 0.0001]),
+        )
+    )
+
+    # Create a Tuner instance to manage the trials.
+    tuner = tune.Tuner(
+        config.algo_class,
+        param_space=config,
+        # Specify a stopping criterion. Note that the criterion has to match one of the
+        # pretty printed result metrics from the results returned previously by
+        # ``.train()``. Also note that -1100 is not a good episode return for
+        # Pendulum-v1; it's only used here to shorten the experiment time.
+        run_config=train.RunConfig(
+            stop={"env_runners/episode_return_mean": -1100.0},
+        ),
+    )
+    # Run the Tuner and capture the results.
+    results = tuner.fit()
+
+Note that Ray Tune creates a separate
+:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` instance as a :ref:`Ray actor `
+for each :py:class:`~ray.tune.trial.Trial`, assigns compute resources to it, and runs the ``Trials``
+in parallel, if possible, on your Ray cluster:
+
+.. code-block:: text
+
+    Trial status: 3 RUNNING
+    Current time: 2025-01-17 18:47:33. Total running time: 3min 0s
+    Logical resource usage: 9.0/12 CPUs, 0/0 GPUs
+    ╭───────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
+    │ Trial name                     status       lr     iter   total time (s)   episode_return_mean   .._sampled_lifetime │
+    ├───────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
+    │ PPO_Pendulum-v1_b5c41_00000    RUNNING    0.001      29          86.2426              -998.449                108000 │
+    │ PPO_Pendulum-v1_b5c41_00001    RUNNING    0.0005     25          74.4335              -997.079                100000 │
+    │ PPO_Pendulum-v1_b5c41_00002    RUNNING    0.0001     20          60.0421              -960.293                 80000 │
+    ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
+
+``Tuner.fit()`` returns a ``ResultGrid`` object that lets you analyze the
+training process in detail and retrieve the :ref:`checkpoints ` of the trained
+algorithms and their models:
+
+.. testcode::
+
+    # Get the best result of the final iteration, based on a particular metric.
+    best_result = results.get_best_result(
+        metric="env_runners/episode_return_mean",
+        mode="max",
+        scope="last",
+    )
+
+    # Get the best checkpoint corresponding to the best result
+    # from the preceding experiment.
+    best_checkpoint = best_result.checkpoint
+
+
+Deploy a trained model for inference
+++++++++++++++++++++++++++++++++++++
+
+After training, you might want to deploy your models into a new environment, for example
+to run inference in production. For this purpose, you can use the checkpoint directory created
+in the preceding example. To read more about checkpoints, model deployments, and restoring algorithm state,
+see this :ref:`page on checkpointing `.
+
+The following is how you create a new model instance from the checkpoint and run inference through
+a single episode of your RL environment. Note in particular the use of the
+:py:meth:`~ray.rllib.utils.checkpoints.Checkpointable.from_checkpoint` method to create
+the model and the
+:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`
+method to compute actions:
+
+
+.. testcode::
+
+    from pathlib import Path
+    import gymnasium as gym
+    import numpy as np
+    import torch
+    from ray.rllib.core.rl_module import RLModule
+
+    # Create only the neural network (RLModule) from your algorithm checkpoint.
+    # See here (https://docs.ray.io/en/master/rllib/checkpoints.html)
+    # to learn more about checkpointing and the specific "path" used.
+    rl_module = RLModule.from_checkpoint(
+        Path(best_checkpoint.path)
+        / "learner_group"
+        / "learner"
+        / "rl_module"
+        / "default_policy"
+    )
+
+    # Create the RL environment to test against (same as was used for training earlier).
+    env = gym.make("Pendulum-v1", render_mode="human")
+
+    episode_return = 0.0
+    done = False
+
+    # Reset the env to get the initial observation.
+    obs, info = env.reset()
+
+    while not done:
+        # Uncomment this line to render the env.
+        # env.render()
+
+        # Compute the next action from a batch (B=1) of observations.
+        obs_batch = torch.from_numpy(obs).unsqueeze(0)  # add batch B=1 dimension
+        model_outputs = rl_module.forward_inference({"obs": obs_batch})
+
+        # Extract the action distribution parameters from the output and remove the batch dim.
+        action_dist_params = model_outputs["action_dist_inputs"][0].numpy()
+
+        # The actions are continuous -> take the mean (max likelihood).
+        greedy_action = np.clip(
+            action_dist_params[0:1],  # 0=mean, 1=log(stddev), [0:1]=use mean, but keep shape=(1,)
+            a_min=env.action_space.low[0],
+            a_max=env.action_space.high[0],
+        )
+        # For discrete actions, you should take the argmax over the logits:
+        # greedy_action = np.argmax(action_dist_params)
+
+        # Send the action to the environment for the next step.
+        obs, reward, terminated, truncated, info = env.step(greedy_action)
+
+        # Perform env-loop bookkeeping.
+        episode_return += reward
+        done = terminated or truncated
+
+    print(f"Reached episode return of {episode_return}.")
+
+
+Alternatively, if you still have an :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` instance up and running
+in your script, you can get the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` through the
+:py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.get_module` method:
+
+.. code-block:: python
+
+    rl_module = ppo.get_module("default_policy")  # Equivalent to `rl_module = ppo.get_module()`
+
+
+Customizing your RL environment
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In the preceding examples, your :ref:`RL environment ` was
+a `Farama gymnasium `__ pre-registered one,
+like ``Pendulum-v1`` or ``CartPole-v1``. However, if you would like to run your
+experiments against a custom one, see the following tab for an example in fewer than 50 lines of code.
+
+See here for an :ref:`in-depth guide on how to set up RL environments in RLlib `
+and how to customize them.
+
+.. dropdown:: Quickstart: Custom RL environment
+    :animate: fade-in-slide-down
+
+    .. testcode::
+
+        import gymnasium as gym
+        import numpy as np
+
+        from ray.rllib.algorithms.ppo import PPOConfig
+
+        # Define your custom env class by subclassing gymnasium.Env:
+
+        class ParrotEnv(gym.Env):
+            """Environment in which the agent learns to repeat the seen observations.
+
+            Observations are float numbers indicating the to-be-repeated values,
+            e.g. -1.0, 5.1, or 3.2.
+            The action space is the same as the observation space.
+            Rewards are `r=-abs([observation] - [action])`, for all steps.
+            """
+            def __init__(self, config=None):
+                # `config` may be None when the env is constructed without an env_config.
+                config = config or {}
+                # Since actions should repeat observations, their spaces must be the same.
+                self.observation_space = config.get(
+                    "obs_act_space",
+                    gym.spaces.Box(-1.0, 1.0, (1,), np.float32),
+                )
+                self.action_space = self.observation_space
+                self._cur_obs = None
+                self._episode_len = 0
+
+            def reset(self, *, seed=None, options=None):
+                """Resets the environment, starting a new episode."""
+                # Reset the episode len.
+                self._episode_len = 0
+                # Sample a random number from our observation space.
+                self._cur_obs = self.observation_space.sample()
+                # Return initial observation.
+                return self._cur_obs, {}
+
+            def step(self, action):
+                """Takes a single step in the episode given `action`."""
+                # Set `terminated` and `truncated` flags to True after 10 steps.
+                self._episode_len += 1
+                terminated = truncated = self._episode_len >= 10
+                # Compute the reward: `r = -abs([obs] - [action])`
+                reward = -sum(abs(self._cur_obs - action))
+                # Set a new observation (random sample).
+                self._cur_obs = self.observation_space.sample()
+                return self._cur_obs, reward, terminated, truncated, {}
+
+        # Point your config to your custom env class:
+        config = (
+            PPOConfig()
+            .environment(
+                ParrotEnv,
+                # Add `env_config={"obs_act_space": [some Box space]}` to customize.
+            )
+        )
+
+        # Build a PPO algorithm and train it.
+        ppo_w_custom_env = config.build_algo()
+        ppo_w_custom_env.train()
+
+    .. testcode::
+        :hide:
+
+        ppo_w_custom_env.stop()
+
+
+Customizing your models
+~~~~~~~~~~~~~~~~~~~~~~~
+
+In the preceding examples, because you didn't specify anything in your
+:py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`, RLlib provided a default
+neural network model. If you would like to either reconfigure the type and size of RLlib's default models,
+for example to define the number of hidden layers and their activation functions,
+or even write your own custom models from scratch using PyTorch, see here
+for a :ref:`detailed guide on the RLModule class `.
+
+See the following tab for a roughly 30-line example.
+
+.. dropdown:: Quickstart: Custom RLModule
+    :animate: fade-in-slide-down
+
+    .. testcode::
+
+        import torch
+
+        from ray.rllib.core.columns import Columns
+        from ray.rllib.core.rl_module.torch import TorchRLModule
+
+        # Define your custom RLModule class by subclassing `TorchRLModule`:
+        class CustomTorchRLModule(TorchRLModule):
+            def setup(self):
+                # You have access here to the following already set attributes:
+                # self.observation_space
+                # self.action_space
+                # self.inference_only
+                # self.model_config  # <- a dict with custom settings
+                input_dim = self.observation_space.shape[0]
+                hidden_dim = self.model_config["hidden_dim"]
+                output_dim = self.action_space.n
+
+                # Define and assign your torch subcomponents.
+                self._policy_net = torch.nn.Sequential(
+                    torch.nn.Linear(input_dim, hidden_dim),
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(hidden_dim, output_dim),
+                )
+
+            def _forward(self, batch, **kwargs):
+                # Push the observations from the batch through our `self._policy_net`.
+                action_logits = self._policy_net(batch[Columns.OBS])
+                # Return parameters for the default action distribution, which is
+                # `TorchCategorical` (due to our action space being `gym.spaces.Discrete`).
+                return {Columns.ACTION_DIST_INPUTS: action_logits}
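+
+    The class above only defines the network. The following is a minimal sketch of
+    constructing and running it standalone through an
+    :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleSpec`; the spaces, the dummy
+    observation, and the ``hidden_dim`` value are assumptions for illustration:
+
+    .. code-block:: python
+
+        import gymnasium as gym
+        import numpy as np
+
+        from ray.rllib.core.rl_module.rl_module import RLModuleSpec
+
+        # Describe how to construct the module: class, spaces, and custom settings.
+        spec = RLModuleSpec(
+            module_class=CustomTorchRLModule,
+            observation_space=gym.spaces.Box(-1.0, 1.0, (4,), np.float32),
+            action_space=gym.spaces.Discrete(2),
+            model_config={"hidden_dim": 256},  # read by `setup()` above
+        )
+        rl_module = spec.build()
+
+        # Run a forward pass on a dummy observation batch (B=1).
+        batch = {Columns.OBS: torch.from_numpy(np.array([[0.1, -0.2, 0.3, -0.4]], np.float32))}
+        logits = rl_module.forward_inference(batch)[Columns.ACTION_DIST_INPUTS]
+
+    To train such a module with a particular algorithm, you would pass the spec through
+    ``config.rl_module(rl_module_spec=...)``; note that algorithms typically require their
+    RLModules to implement additional APIs, for example the ``ValueFunctionAPI`` for PPO.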
diff --git a/doc/source/rllib/images/metrics_logger_overview.svg b/doc/source/rllib/images/metrics_logger_overview.svg
new file mode 100644
index 0000000000000..93924f039ec03
--- /dev/null
+++ b/doc/source/rllib/images/metrics_logger_overview.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/doc/source/rllib/images/rllib-api.svg b/doc/source/rllib/images/rllib-api.svg
deleted file mode 100644
index 6eb03dac2e494..0000000000000
--- a/doc/source/rllib/images/rllib-api.svg
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/doc/source/rllib/images/rllib-components.svg b/doc/source/rllib/images/rllib-components.svg
deleted file mode 100644
index 66625a1374bf1..0000000000000
--- a/doc/source/rllib/images/rllib-components.svg
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/doc/source/rllib/index.rst b/doc/source/rllib/index.rst
index eb4932f098267..1c1b5f7c7d5b1 100644
--- a/doc/source/rllib/index.rst
+++ b/doc/source/rllib/index.rst
@@ -16,7 +16,7 @@ RLlib: Industry-Grade, Scalable Reinforcement Learning
 .. todo (sven): redo toctree:
     suggestion:
-    getting-started (replaces rllib-training)
+    getting-started
     key-concepts
     rllib-env (single-agent) ... <- multi-agent
@@ -33,7 +33,7 @@ RLlib: Industry-Grade, Scalable Reinforcement Learning
     metrics-logger
     rllib-advanced-api
     algorithm (general description of how algos work)
-    rllib-rlmodule
+    rl-modules
     rllib-offline
     single-agent-episode
     multi-agent-episode
@@ -47,7 +47,7 @@ RLlib: Industry-Grade, Scalable Reinforcement Learning
 .. toctree::
     :hidden:
 
-    rllib-training
+    getting-started
     key-concepts
     rllib-env
     algorithm-config
@@ -99,34 +99,71 @@ Install RLlib and `PyTorch `__, as shown below:
 .. note::
     To be able to run the Atari or MuJoCo examples, you also need to do:
-    `pip install "gymnasium[atari,accept-rom-license,mujoco]"`.
-This is all. You can now start coding against RLlib. Here is an example for running the PPO Algorithm on the
+    .. code-block:: bash
+
+        pip install "gymnasium[atari,accept-rom-license,mujoco]"
+
+This is all. You can now start coding against RLlib. Here is an example for running the :ref:`PPO Algorithm ` on the
 `Taxi domain `__.
-You first create a `config` for the algorithm, which defines the RL environment and
-any other needed settings and parameters.
+You first create a `config` for the algorithm, which defines the :ref:`RL environment ` and any other needed settings and parameters.
+
+.. testcode::
+
+    from ray.rllib.algorithms.ppo import PPOConfig
+    from ray.rllib.connectors.env_to_module import FlattenObservations
+
+    # Configure the algorithm.
+    config = (
+        PPOConfig()
+        .environment("Taxi-v3")
+        .env_runners(
+            num_env_runners=2,
+            # Observations are discrete (ints) -> We need to flatten (one-hot) them.
+            env_to_module_connector=lambda env: FlattenObservations(),
+        )
+        .evaluation(evaluation_num_env_runners=1)
+    )
+
+
+Next, ``build`` the algorithm and ``train`` it for a total of five iterations.
+One training iteration includes parallel, distributed sample collection by the
+:py:class:`~ray.rllib.env.env_runner.EnvRunner` actors, followed by loss calculation
+on the collected data, and a model update step.
+
+.. testcode::
+
+    from pprint import pprint
+
+    # Build the algorithm.
+    algo = config.build_algo()
+
+    # Train it for 5 iterations ...
+    for _ in range(5):
+        pprint(algo.train())
+
+At the end of your script, you evaluate the trained Algorithm and release all its resources:
+
+.. testcode::
-Next, `build` the algorithm and `train` it for a total of five iterations.
-One training iteration includes parallel, distributed sample collection by the :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors,
-followed by loss calculation on the collected data, and a model update step.
+    # ... and evaluate it.
+    pprint(algo.evaluate())
-At the end of your script, RLlib evaluates the trained Algorithm:
+    # Release the algo's resources (remote actors, like EnvRunners and Learners).
+    algo.stop()
-.. literalinclude:: doc_code/rllib_in_60s.py
-    :language: python
-    :start-after: __rllib-in-60s-begin__
-    :end-before: __rllib-in-60s-end__
 
 You can use any `Farama-Foundation Gymnasium `__ registered environment
-with the `env` argument.
+with the ``env`` argument.
 
-In `config.env_runners()` you can specify - amongst many other things - the number of parallel
+In ``config.env_runners()`` you can specify - amongst many other things - the number of parallel
 :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors to collect samples from the environment.
-You can also tweak the NN architecture used by tweaking RLlib's `DefaultModelConfig`, as well as, set up a separate
-config for the evaluation :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors through the `config.evaluation()` method.
+You can also adjust the NN architecture by tweaking RLlib's :py:class:`~ray.rllib.core.rl_module.default_model_config.DefaultModelConfig`,
+as well as set up a separate config for the evaluation
+:py:class:`~ray.rllib.env.env_runner.EnvRunner` actors through the ``config.evaluation()`` method.
 
-`See here `_, if you want to learn more about the RLlib training APIs.
+:ref:`See here ` if you want to learn more about the RLlib training APIs.
 Also, `see here `__ for a simple example
 on how to write an action inference loop after training.
@@ -349,7 +386,7 @@ production training-workflows.
 For example, you may code your own `environments `__
 in python using the `Farama Foundation's gymnasium `__ or DeepMind's OpenSpiel,
 provide custom `PyTorch models `_,
-write your own `optimizer setups and loss definitions `__,
+write your own `optimizer setups and loss definitions `__,
 or define custom `exploratory behavior `_.
 .. figure:: images/rllib-new-api-stack-simple.svg
diff --git a/doc/source/rllib/key-concepts.rst b/doc/source/rllib/key-concepts.rst
index b5496dd38279c..4759e07aef996 100644
--- a/doc/source/rllib/key-concepts.rst
+++ b/doc/source/rllib/key-concepts.rst
@@ -164,7 +164,7 @@ RLModules
 The following is a quick overview of **RLlib RLModules**.
 See :ref:`here for a detailed description of the RLModule class `.
 
-`RLModules `__ are deep-learning framework-specific neural network containers.
+`RLModules `__ are deep-learning framework-specific neural network wrappers.
 RLlib's :ref:`EnvRunners ` use them for computing actions when stepping through the
 :ref:`RL environment ` and RLlib's :ref:`Learners `
 use :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` instances for computing losses and gradients before updating them.
diff --git a/doc/source/rllib/metrics-logger.rst b/doc/source/rllib/metrics-logger.rst
new file mode 100644
index 0000000000000..76b985cc6e5f4
--- /dev/null
+++ b/doc/source/rllib/metrics-logger.rst
@@ -0,0 +1,462 @@
+.. include:: /_includes/rllib/we_are_hiring.rst
+
+.. _rllib-metric-logger-docs:
+
+MetricsLogger API
+==================
+
+.. include:: /_includes/rllib/new_api_stack.rst
+
+The RLlib team designed the :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` API
+to unify the logging and processing of stats and metrics during
+reinforcement learning (RL) experiments and to make them easily accessible. RLlib's :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
+class and all its subcomponents each have one :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`
+instance managing metrics and statistics for that component. When a subcomponent reports back to its
+parent component, it "reduces" the logged results and sends them upstream.
+
+The RLlib team recommends this API for all your custom code, like in
+:py:class:`~ray.rllib.env.env_runner.EnvRunner`-based :ref:`callbacks `,
+in `custom loss functions `__, or in custom `training_step() `__
+implementations.
+
+.. figure:: images/metrics_logger_overview.svg
+    :width: 750
+    :align: left
+
+    **RLlib's MetricsLogger system**: Every subcomponent of an RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` has-a
+    :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance
+    and uses it to locally log values. When a component completes a distinct task,
+    for example, an :py:class:`~ray.rllib.env.env_runner.EnvRunner` finishing a sampling request, the local metrics of the subcomponent
+    (``EnvRunner``) are "reduced" and sent upstream to the containing parent component (``Algorithm``).
+    The parent component merges the received results into its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` and,
+    at the end of its own task cycle, "reduces" as well for final reporting to the user or to Ray Tune.
+
+
+.. note::
+    So far, RLlib components owning a :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`
+    instance are :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`, :py:class:`~ray.rllib.env.env_runner.EnvRunner`,
+    :py:class:`~ray.rllib.core.learner.learner.Learner`, all :py:class:`~ray.rllib.connectors.connector_v2.ConnectorV2` classes,
+    and all :py:class:`~ray.rllib.utils.replay_buffers.EpisodeReplayBuffer` classes.
+    The Ray team is considering expanding access to this API to other components as well.
+
+
+Features of MetricsLogger
+-------------------------
+
+The :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` API offers the following functionalities:
+
+- Log scalar values over time, such as losses, individual rewards, or episode returns.
+- Configure different reduction types, in particular ``mean``, ``min``, ``max``, or ``sum``. Also, users can choose not to
+  reduce at all through the ``reduce=None`` setting, leaving the logged values untouched.
+  A separate ``clear_on_reduce=True`` setting allows for automatically clearing all logged values on each ``reduce`` event.
+- Specify sliding windows, over which reductions take place, for example ``window=100`` to average over the
+  last 100 logged values, or specify exponential moving average (EMA) coefficients, so that the weight of older values
+  in the computed mean decays over time.
+- Merge ``n`` result dicts from ``n`` parallel subcomponents into the local :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`.
+  Each of these ``n`` dicts is the result of a ``reduce`` operation on each subcomponent's own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`
+  instance.
+- Log execution times for distinct code blocks through convenient ``with ...`` blocks.
+- Add up lifetime counts and automatically compute the corresponding throughput metrics per second along the way.
+
+
+Built-in usages of MetricsLogger
+--------------------------------
+
+RLlib uses the :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` API extensively in the
+existing code-base. The following is an overview of a typical information flow resulting from this:
+
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` sends parallel sample requests to its ``n`` :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors.
+#. Each :py:class:`~ray.rllib.env.env_runner.EnvRunner` collects training data by stepping through its :ref:`RL environment ` and logs standard stats to its :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`, such as episode return or episode length.
+#. Each :py:class:`~ray.rllib.env.env_runner.EnvRunner` calls :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.reduce` on its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance and returns the resulting stats dict.
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` merges the ``n`` received stats dicts into its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance under the top-level key "env_runners", thereby keeping all log-settings chosen by the :py:class:`~ray.rllib.env.env_runner.EnvRunner` actors.
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` sends parallel update requests to its ``m`` :py:class:`~ray.rllib.core.learner.learner.Learner` actors.
+#. Each :py:class:`~ray.rllib.core.learner.learner.Learner` performs a model update through computing losses and gradients and logs standard stats to its :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`, such as total loss or mean gradients.
+#. Each :py:class:`~ray.rllib.core.learner.learner.Learner` calls :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.reduce` on its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance and returns the resulting stats dict.
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` merges the ``m`` received stats dicts into its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance under the top-level key "learners", thereby keeping all log-settings chosen by the :py:class:`~ray.rllib.core.learner.learner.Learner` actors.
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` may add standard stats to its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance, for example the average time of a parallel sample request.
+#. The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` calls :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.reduce` on its own :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` instance, compiling and returning a complete and final stats dict to the user or Ray Tune.
+
+
+The MetricsLogger APIs in detail
+--------------------------------
+
+Before you use :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` in your custom code, familiarize
+yourself with its APIs.
+
+Logging scalar values
+~~~~~~~~~~~~~~~~~~~~~
+
+To log a scalar value under some string key in your :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`,
+use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_value` method:
+
+.. testcode::
+
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    # Log a scalar float value under the `loss` key. By default, all logged
+    # values under that key are averaged, once `reduce()` is called.
+    logger.log_value("loss", 0.01, reduce="mean", window=2)
+
+By default, :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` reduces values through averaging them (``reduce="mean"``).
+Other available reduce types are ``reduce="min"``, ``reduce="max"``, and ``reduce="sum"``.
+
+Specifying a ``window`` causes the reduction to take place over the last ``window`` logged values.
+For example, you can continue logging new values under the ``loss`` key:
+
+.. testcode::
+
+    logger.log_value("loss", 0.02)  # don't have to repeat `reduce` or `window` args,
+                                    # because the key already exists.
+    logger.log_value("loss", 0.03)
+    logger.log_value("loss", 0.04)
+    logger.log_value("loss", 0.05)
+
+Because you specified a window of 2, :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` only uses the last 2 values to compute the reduced result.
+You can ``peek()`` at the currently reduced result through the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method:
+
+.. testcode::
+
+    # Peek at the current, reduced value.
+    # Note that in the underlying structure, the internal values list still
+    # contains all logged values: 0.01, 0.02, 0.03, 0.04, and 0.05.
+    print(logger.peek("loss"))  # Expect: 0.045, which is the average over the last 2 values
+
+The :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method allows you to
+check the current underlying reduced result for some key, without actually having to call
+:py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.reduce`.
+
+.. warning::
+
+    **Don't call the reduce() method yourself** on any
+    :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` from your custom code.
+    The only time RLlib invokes this API is at the end of a task cycle.
+    RLlib controls all of these "hand over" points entirely, so unless you write your own subcomponent that reports to a parent component, such as
+    :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`, refrain from calling the
+    :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.reduce` method.
+
+    To get the current reduced results, use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method instead,
+    which doesn't alter any underlying values.
+
+
+Instead of providing a flat key, you can also log a value under some nested key by passing in a tuple:
+
+.. testcode::
+
+    # Log a value under a deeper nested key.
+    logger.log_value(("some", "nested", "key"), -1.0)
+    print(logger.peek(("some", "nested", "key")))  # expect: -1.0
+
+
+To use reduce methods other than "mean", specify the ``reduce`` argument in
+:py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_value`:
+
+.. testcode::
+
+    # Log a maximum value.
+    logger.log_value(key="max_value", value=0.0, reduce="max")
+
+Because you didn't specify a ``window`` and are using ``reduce="max"``, RLlib uses the infinite window,
+meaning :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` reports the lifetime maximum value,
+whenever reduction takes place or you peek at the current value:
+
+.. testcode::
+
+    for i in range(1000, 0, -1):
+        logger.log_value(key="max_value", value=float(i))
+
+    logger.peek("max_value")  # Expect: 1000.0, which is the lifetime max (infinite window)
+
+
+You can also choose not to reduce at all, but to simply collect individual values, for example a set of images you receive
+from your environment over time and for which it doesn't make sense to reduce them in any way.
+
+Use the ``reduce=None`` argument to achieve this. However, you should then also
+set the ``clear_on_reduce=True`` flag, because ``reduce=None`` may otherwise cause memory leaks.
+This flag ensures that :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` clears out the underlying list of values after every
+``reduce()`` handover operation, for example from :py:class:`~ray.rllib.env.env_runner.EnvRunner`
+to :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`:
+
+.. testcode::
+
+    logger.log_value("some_items", value="a", reduce=None, clear_on_reduce=True)
+    logger.log_value("some_items", value="b")
+    logger.log_value("some_items", value="c")
+    logger.log_value("some_items", value="d")
+
+    logger.peek("some_items")  # expect a list: ["a", "b", "c", "d"]
+
+    logger.reduce()
+    logger.peek("some_items")  # expect an empty list: []
+
+
+Logging a set of nested scalar values
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you're logging a nested structure of values, for example
+``{"time_s": 0.1, "lives": 5, "rounds_played": {"player1": 10, "player2": 4}}``, and all values have the exact same log settings
+in terms of the ``reduce``, ``clear_on_reduce``, ``window``, etc. arguments, you can also call the shortcut
+:py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_dict` method to do so:
+
+
+.. testcode::
+
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    # Log a bunch of scalar values within a nested dict.
+    stats = {"player1": 100.0, "player2": 105.0}
+    logger.log_dict(stats, key="mean_scores", reduce="mean", window=10)
+
+    # Later, do the same again.
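+    # (With `window=10`, both values logged per key fit into the window, so the
+    # reduced mean for "player1" becomes (100.0 + 150.0) / 2 = 125.0.)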
+    stats = {"player1": 150.0, "player2": 110.0}
+    logger.log_dict(stats, key="mean_scores")
+
+    print(logger.peek(("mean_scores", "player1")))  # <- expect 125.0
+
+Logging non-scalar data
+~~~~~~~~~~~~~~~~~~~~~~~
+
+:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` isn't limited to scalar values.
+You can also use it to log images, videos, or any other complex data.
+
+Normally, you would choose the previously described ``reduce=None`` argument. For example, to
+log three consecutive image frames from a ``CartPole`` environment, do the following:
+
+.. testcode::
+
+    import gymnasium as gym
+
+    env = gym.make("CartPole-v1")
+
+    # Log three consecutive render frames from the env.
+    # Make sure to set ``clear_on_reduce=True`` to avoid memory leaks.
+    env.reset()
+    logger.log_value("some_images", value=env.render(), reduce=None, clear_on_reduce=True)
+    env.step(0)
+    logger.log_value("some_images", value=env.render())
+    env.step(1)
+    logger.log_value("some_images", value=env.render())
+
+Timers
+~~~~~~
+
+:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` can act as a context manager and offers the following
+simple API to log timer results.
+Notice that you can time all your code blocks of interest inside your custom code through a single ``with`` block:
+
+.. testcode::
+
+    import time
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    # First delta measurement:
+    with logger.log_time("my_block_to_be_timed", reduce="mean", ema_coeff=0.1):
+        time.sleep(1.0)
+
+    # EMA should be ~1sec.
+    assert 1.1 > logger.peek("my_block_to_be_timed") > 0.9
+
+    # Second delta measurement:
+    with logger.log_time("my_block_to_be_timed"):
+        time.sleep(2.0)
+
+    # EMA should be ~1.1sec.
+    assert 1.15 > logger.peek("my_block_to_be_timed") > 1.05
+
+.. note::
+    The default logging behavior is through an exponential moving average (EMA), with a default coefficient of 0.01.
+    This default is usually a good choice for averaging timer results over the course of the experiment.
+
+    .. TODO: add this paragraph once we properly support lifetime simple average:
+     If instead you want to reduce through averaging all logged values over the lifetime,
+     use `with logger.log_time([some key], reduce="mean", window=float("inf"))`, instead.
+
+
+Counters
+~~~~~~~~
+
+In case you want to count things, for example the number of environment steps taken in a sample phase, and add up those
+counts either over the lifetime or over some particular phase, use the ``reduce="sum"`` argument in the call to
+:py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.log_value`.
+
+Combine this with ``clear_on_reduce=True`` if you want the count to only accumulate until the next "reduce" event.
+Set ``clear_on_reduce=False``, which is the default, if you want the count to accumulate over the lifetime.
+
+.. testcode::
+
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    logger.log_value("my_counter", 50, reduce="sum", window=None)
+    logger.log_value("my_counter", 25)
+    logger.peek("my_counter")  # expect: 75
+
+    # Even if your logger gets "reduced" from time to time, the counter keeps increasing
+    # because we set clear_on_reduce=False (default behavior):
+    logger.reduce()
+    logger.peek("my_counter")  # still expect: 75
+
+    # To clear the sum after each "reduce" event, set `clear_on_reduce=True`:
+    logger.log_value("my_temp_counter", 50, reduce="sum", window=None, clear_on_reduce=True)
+    logger.log_value("my_temp_counter", 25)
+    logger.peek("my_temp_counter")  # expect: 75
+    logger.reduce()
+    logger.peek("my_temp_counter")  # expect: 0 (upon reduction, all values are cleared)
+
+
+Automatic throughput measurements
++++++++++++++++++++++++++++++++++
+
+A metric logged with the settings ``reduce="sum"`` and ``clear_on_reduce=False`` is considered
+a ``lifetime`` counter, accumulating counts over the entire course of the experiment without ever resetting
+the value back to 0. If you also add the ``with_throughput=True`` flag, the underlying metric automatically computes the throughput per second
+on each ``reduce()`` operation.
+
+The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` automatically compiles an extra key for each such metric, adding the suffix ``_throughput``
+to the original key and assigning it the value for the throughput per second.
+
+You can use the :py:meth:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger.peek` method with the call argument ``throughput=True``
+to access the throughput value. For example:
+
+.. testcode::
+
+    import time
+    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
+
+    logger = MetricsLogger()
+
+    for _ in range(3):
+        logger.log_value("lifetime_count", 5, reduce="sum", with_throughput=True)
+
+        # RLlib triggers a new throughput computation at each `reduce()` call.
+        logger.reduce()
+
+        # Expect the first call to return NaN because there's no proper start time
+        # for the time delta yet. From the second call on, expect a value of
+        # roughly 5/sec.
+        print(logger.peek("lifetime_count", throughput=True))
+
+        time.sleep(1.0)
+
+
+Example 1: How to use MetricsLogger in EnvRunner callbacks
+----------------------------------------------------------
+
+To demonstrate how to use the :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` on an :py:class:`~ray.rllib.env.env_runner.EnvRunner`, take a look at the following end-to-end example, which
+makes use of the :py:class:`~ray.rllib.callbacks.callbacks.RLlibCallback` API to inject custom code into the RL environment loop.
+
+The example computes the average "first-joint angle" of the
+`Acrobot-v1 RL environment `__
+and logs the results through the :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` API.
+
+Note that this example is :ref:`identical to the one described here `, but the focus has shifted to explain
+only the :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` aspects of the code.
+
+.. testcode::
+
+    import math
+    import numpy as np
+    from ray.rllib.algorithms.ppo import PPOConfig
+    from ray.rllib.callbacks.callbacks import RLlibCallback
+
+    # Define a custom RLlibCallback.
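+    # (Orientation, as added comments: RLlib calls `on_episode_step` on the
+    # EnvRunner after each `env.step()` and `on_episode_end` once an episode is
+    # done. The `metrics_logger` arg is the EnvRunner's own MetricsLogger instance.)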
+    class LogAcrobotAngle(RLlibCallback):
+        def on_episode_step(self, *, episode, env, **kwargs):
+            # Compute the angle at every episode step and store it temporarily in episode:
+            state = env.envs[0].unwrapped.state
+            deg_theta1 = math.degrees(math.atan2(state[1], state[0]))
+            episode.add_temporary_timestep_data("theta1", deg_theta1)
+
+        def on_episode_end(self, *, episode, metrics_logger, **kwargs):
+            theta1s = episode.get_temporary_timestep_data("theta1")
+            avg_theta1 = np.mean(theta1s)
+
+            # Log the resulting average angle - per episode - to the MetricsLogger.
+            # Report with a sliding window of 50.
+            metrics_logger.log_value("theta1_mean", avg_theta1, reduce="mean", window=50)
+
+    config = (
+        PPOConfig()
+        .environment("Acrobot-v1")
+        .callbacks(
+            callbacks_class=LogAcrobotAngle,
+        )
+    )
+    ppo = config.build()
+
+    # Train n times. Expect `theta1_mean` to be found in the results under:
+    # `env_runners/theta1_mean`
+    for i in range(10):
+        results = ppo.train()
+        print(
+            f"iter={i} "
+            f"theta1_mean={results['env_runners']['theta1_mean']} "
+            f"R={results['env_runners']['episode_return_mean']}"
+        )
+
+
+Also take a look at this more complex example on
+`how to generate and log a PacMan heatmap (image) to WandB `__.
+
+
+Example 2: How to use MetricsLogger in a custom loss function
+-------------------------------------------------------------
+
+You can log metrics inside your custom loss functions. Use the Learner's ``self.metrics`` attribute for this.
+
+.. note::
+
+    When logging loss values, the RLlib team recommends using ``window=1`` to always report the exact
+    current loss value, rather than a smoothed result over time. This way, you notice strange spikes or unstable
+    behavior in your loss math right away and can pinpoint problems to a particular iteration.
+
+
+.. code-block:: python
+
+    @override(TorchLearner)
+    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
+        ...
+
+        loss_xyz = ...
+
+        # Log a specific loss term.
+        self.metrics.log_value("special_loss_term", loss_xyz, window=1)
+
+        total_loss = loss_abc + loss_xyz
+
+        return total_loss
+
+
+Take a look at this running
+`end-to-end example for logging custom values inside a loss function `__.
+
+
+Example 3: How to use MetricsLogger in a custom Algorithm
+---------------------------------------------------------
+
+You can log metrics inside your custom Algorithm :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.training_step` method.
+Use the Algorithm's own ``self.metrics`` attribute for this.
+
+.. code-block:: python
+
+    @override(Algorithm)
+    def training_step(self) -> None:
+        ...
+
+        # Log some value.
+        self.metrics.log_value("some_mean_result", 1.5, window=5)
+
+        ...
+
+        with self.metrics.log_time(("timers", "some_code")):
+            ...  # time some code
+
+
+See this running
+`end-to-end example for logging inside training_step() `__.
diff --git a/doc/source/rllib/multi-agent-envs.rst b/doc/source/rllib/multi-agent-envs.rst
index 9944a1a619451..52312b98fefce 100644
--- a/doc/source/rllib/multi-agent-envs.rst
+++ b/doc/source/rllib/multi-agent-envs.rst
@@ -22,6 +22,7 @@ allowing for any policy to control more than one agent.
 .. figure:: images/envs/multi_agent_setup.svg
     :width: 600
+    :align: left
 
     **Multi-agent setup:** ``N`` agents live in the environment and take actions
     computed by ``M`` policy networks. The mapping from agent to policy is flexible
     and determined by a user-provided mapping function. Here, `agent_1`
@@ -177,6 +178,7 @@ their observations in the returned observations dict.

..
figure:: images/envs/multi_agent_episode_simultaneous.svg :width: 600 + :align: left **Env with simultaneously acting agents:** Both agents receive their observations at each time step, including right after `reset()`. Note that an agent must compute and send an action @@ -202,6 +204,7 @@ returned observation dict. .. figure:: images/envs/multi_agent_episode_turn_based.svg :width: 600 + :align: left **Env with agents taking turns:** The two agents act by taking alternating turns. `agent_1` receives the first observation after the `reset()` and thus has to compute and send an action first. Upon receiving @@ -214,6 +217,7 @@ environments where all agents always act simultaneously, to any arbitrarily comp .. figure:: images/envs/multi_agent_episode_complex_order.svg :width: 600 + :align: left **Env with a complex order of turns:** Three agents act in a seemingly chaotic order. `agent_1` and `agent_3` receive their initial observation after the `reset()` and thus has to compute and send actions first. Upon receiving diff --git a/doc/source/rllib/new-api-stack-migration-guide.rst b/doc/source/rllib/new-api-stack-migration-guide.rst index 9ba4e3e63f632..9eb426dcca935 100644 --- a/doc/source/rllib/new-api-stack-migration-guide.rst +++ b/doc/source/rllib/new-api-stack-migration-guide.rst @@ -240,7 +240,7 @@ It allows you to specify: #. the number of `Learner` workers through `.learners(num_learners=...)`. #. the resources per learner; use `.learners(num_gpus_per_learner=1)` for GPU training and `.learners(num_gpus_per_learner=0)` for CPU training. -#. the custom Learner class you want to use. See this `example `__ for more details. +#. the custom Learner class you want to use. See this `example `__ for more details. #. a config dict you would like to set for your custom learner: `.learners(learner_config_dict={...})`. Note that every `Learner` has access to the entire `AlgorithmConfig` object through `self.config`, but setting the @@ -530,7 +530,7 @@ customizations inside the old stack's Policy class, you need to move the logic i See :ref:`Learner ` for details on how to write a custom Learner . The following example scripts show how to write: -- `a simple custom loss function `__ +- `a simple custom loss function `__ - `a custom Learner with 2 optimizers and different learning rates for each `__. Note that the new API stack doesn't support the Policy class. In the old stack, this class holds a diff --git a/doc/source/rllib/package_ref/distributions.rst b/doc/source/rllib/package_ref/distributions.rst index f01fa27f92c28..401e54cc912e0 100644 --- a/doc/source/rllib/package_ref/distributions.rst +++ b/doc/source/rllib/package_ref/distributions.rst @@ -17,3 +17,8 @@ Base Distribution class :toctree: doc/ ~Distribution + ~Distribution.from_logits + ~Distribution.sample + ~Distribution.rsample + ~Distribution.logp + ~Distribution.kl diff --git a/doc/source/rllib/package_ref/rl_modules.rst b/doc/source/rllib/package_ref/rl_modules.rst index 4777a759c9a3f..f2c53cc162a10 100644 --- a/doc/source/rllib/package_ref/rl_modules.rst +++ b/doc/source/rllib/package_ref/rl_modules.rst @@ -21,6 +21,12 @@ Single RLModuleSpec RLModuleSpec RLModuleSpec.build + RLModuleSpec.module_class + RLModuleSpec.observation_space + RLModuleSpec.action_space + RLModuleSpec.inference_only + RLModuleSpec.learner_only + RLModuleSpec.model_config MultiRLModuleSpec +++++++++++++++++ @@ -34,6 +40,25 @@ MultiRLModuleSpec MultiRLModuleSpec MultiRLModuleSpec.build +.. 
autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.multi_rl_module_class
+    :no-index:
+
+.. autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.observation_space
+    :no-index:
+
+.. autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.action_space
+    :no-index:
+
+.. autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.inference_only
+    :no-index:
+
+.. autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.model_config
+    :no-index:
+
+.. autoattribute:: ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec.rl_module_specs
+    :no-index:
+
+
 DefaultModelConfig
 ++++++++++++++++++
@@ -152,3 +177,47 @@ Saving and restoring
     ~MultiRLModule.from_checkpoint
     ~MultiRLModule.get_state
     ~MultiRLModule.set_state
+
+
+Additional RLModule APIs
+------------------------
+
+.. currentmodule:: ray.rllib.core.rl_module.apis
+
+InferenceOnlyAPI
+++++++++++++++++
+
+.. autoclass:: ray.rllib.core.rl_module.apis.inference_only_api.InferenceOnlyAPI
+
+    .. automethod:: get_non_inference_attributes
+
+QNetAPI
++++++++
+
+.. autoclass:: ray.rllib.core.rl_module.apis.q_net_api.QNetAPI
+
+    .. automethod:: compute_q_values
+    .. automethod:: compute_advantage_distribution
+
+SelfSupervisedLossAPI
++++++++++++++++++++++
+
+.. autoclass:: ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI
+
+    .. automethod:: compute_self_supervised_loss
+
+TargetNetworkAPI
+++++++++++++++++
+
+.. autoclass:: ray.rllib.core.rl_module.apis.target_network_api.TargetNetworkAPI
+
+    .. automethod:: make_target_networks
+    .. automethod:: get_target_network_pairs
+    .. automethod:: forward_target
+
+ValueFunctionAPI
+++++++++++++++++
+
+.. autoclass:: ray.rllib.core.rl_module.apis.value_function_api.ValueFunctionAPI
+
+    .. automethod:: compute_values
diff --git a/doc/source/rllib/rllib-rlmodule.rst b/doc/source/rllib/rl-modules.rst
similarity index 77%
rename from doc/source/rllib/rllib-rlmodule.rst
rename to doc/source/rllib/rl-modules.rst
index 7af433ecfd8b1..49a7fc123fbd6 100644
--- a/doc/source/rllib/rllib-rlmodule.rst
+++ b/doc/source/rllib/rl-modules.rst
@@ -1,12 +1,12 @@
 .. include:: /_includes/rllib/we_are_hiring.rst
 
+.. include:: /_includes/rllib/new_api_stack.rst
+
 .. _rlmodule-guide:
 
 RL Modules
 ==========
 
-.. include:: /_includes/rllib/new_api_stack.rst
-
 The :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` class in RLlib's new API stack allows you to write custom models,
 including highly complex multi-network setups often found in multi-agent or model-based algorithms.
@@ -81,9 +81,9 @@ image observations (``[width] x [height] x [channels]``).
 
     config = (
         PPOConfig()
-        # FrozenLake has a discrete observation space, ...
+        # FrozenLake has a discrete observation space (ints).
         .environment("FrozenLake-v1")
-        # ... which `FlattenObservations` converts to one-hot.
+        # `FlattenObservations` converts int observations to one-hot.
         .env_runners(env_to_module_connector=lambda env: FlattenObservations())
     )
@@ -121,6 +121,14 @@ with PPO and the default RLModule, configure your experiment as follows:
     )
 
+.. testcode::
+    :hide:
+
+    test = config.build()
+    test.train()
+    test.stop()
+
+
 The following is the complete list of all supported ``fcnet_..`` options:

 .. literalinclude:: ../../../rllib/core/rl_module/default_model_config.py
@@ -136,19 +144,36 @@ For image-based environments like `Atari `
 ``conv_..`` fields in :py:class:`~ray.rllib.core.rl_module.default_model_config.DefaultModelConfig`
 to configure the convolutional neural network (CNN) stack.
 
-For example:
+You may have to check whether your CNN configuration works with the incoming observation image
+dimensions. For example, for an `Atari `__ environment, you can
+use RLlib's Atari wrapper utility, which performs resizing (default 64x64), grayscaling (default True),
+frame stacking (default None), frame skipping (default 4), and normalization (from uint8 to float32), and
+applies up to 30 "noop" actions after a reset, which aren't part of the episode:
 
 .. testcode::
 
+    import gymnasium as gym  # `pip install gymnasium[atari,accept-rom-license]`
+
+    from ray.rllib.algorithms.ppo import PPOConfig
+    from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
     from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+    from ray.tune import register_env
+
+    register_env(
+        "image_env",
+        lambda _: wrap_atari_for_new_api_stack(
+            gym.make("ale_py:ALE/Pong-v5"),
+            dim=64,  # resize original observation to 64x64x3
+            framestack=4,
+        )
+    )
 
     config = (
         PPOConfig()
-        .environment("ale_py:ALE/Pong-v5")  # `pip install gymnasium[atari]`
+        .environment("image_env")
         .rl_module(
             model_config=DefaultModelConfig(
-                # Use a DreamerV3-style CNN stack.
+                # Use a DreamerV3-style CNN stack for 64x64 images.
                 conv_filters=[
                     [16, 4, 2],  # 1st CNN layer: num_filters, kernel, stride(, padding)?
                     [32, 4, 2],  # 2nd CNN layer
                     [64, 4, 2],  # etc..
                     [128, 4, 2],
                 ],
                 conv_activation="silu",
-                # After the last CNN, the default model flattens, then adds an optional MLP.
                 head_fcnet_hiddens=[256],
             )
         )
     )
 
+.. testcode::
+    :hide:
+
+    test = config.build()
+    test.train()
+    test.stop()
+
 The following is the complete list of all supported ``conv_..`` options:
 
 .. literalinclude:: ../../../rllib/core/rl_module/default_model_config.py
@@ -177,6 +208,17 @@ Other default model settings
 For LSTM-based configurations and specific settings for continuous action output layers,
 see :py:class:`~ray.rllib.core.rl_module.default_model_config.DefaultModelConfig`.
 
+.. note::
+
+    To auto-wrap your default encoder with an extra LSTM layer and allow your model to learn in
+    non-Markovian, partially observable environments, you can try the convenience
+    ``DefaultModelConfig.use_lstm`` setting in combination with the
+    ``DefaultModelConfig.lstm_cell_size`` and ``DefaultModelConfig.max_seq_len`` settings.
+    See here for a tuned
+    `example that uses a default RLModule with an LSTM layer `__.
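+
+    The following is a minimal sketch of such an LSTM-wrapped default model config;
+    the concrete values are assumptions for illustration:
+
+    .. code-block:: python
+
+        from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+
+        model_config = DefaultModelConfig(
+            use_lstm=True,        # wrap the default encoder with an LSTM layer
+            lstm_cell_size=256,   # size of the LSTM's hidden state
+            max_seq_len=20,       # max sequence length for backprop through time
+        )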
+
+.. TODO: mention attention example once done
+
 Constructing RLModule instances
 -------------------------------
@@ -415,6 +457,8 @@ or any multi-model use cases, subclass the :py:class:`~ray.rllib.core.rl_module.
 See :ref:`Algorithm-specific RLModule APIs ` for how to determine which APIs your algorithm requires you to implement.
 
+.. _rllib-implementing-custom-rl-modules-setup:
+
 The setup() method
 ~~~~~~~~~~~~~~~~~~
@@ -437,7 +481,6 @@ You also have access to the following attributes anywhere in the class, includin
     from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
 
     class MyTorchPolicy(TorchRLModule):
-
         def setup(self):
             # You have access here to the following already set attributes:
             # self.observation_space
@@ -462,6 +505,7 @@ You also have access to the following attributes anywhere in the class, includin
                 torch.nn.Linear(hidden_dim, output_dim),
             )
 
+.. _rllib-implementing-custom-rl-modules-forward:
 
 Forward methods
 ~~~~~~~~~~~~~~~~~~~
@@ -505,7 +549,6 @@ If you don't return the ``actions`` key from your forward method:
     actions from the distribution logits or probabilities.
     If you return the "actions" key, RLlib skips that sampling step.
-
 .. tab-set::
 
     .. tab-item:: Returning "actions" key
@@ -601,13 +644,30 @@ To find out, what APIs your Algorithms require, do the following:
 .. note::
 
-    You don't need the preceding VPG example module to implement any APIs because
-    you haven't considered training it with any particular algorithm.
-    You can find examples of algorithm-ready :py:class:`~ray.rllib.algorithms.ppo.PPO` custom RLModules
-    in the `tiny_atari_cnn_rlm example `__
+
+    You didn't implement any APIs in the preceding example module, because
+    you hadn't considered training it with any particular algorithm yet.
+    You can find examples of custom :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` classes
+    implementing the :py:class:`~ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI` and thus
+    ready to train with :py:class:`~ray.rllib.algorithms.ppo.PPO` in the
+    `tiny_atari_cnn_rlm example `__
     and in the `lstm_containing_rlm example `__.
 
+You can mix supervised losses into any RLlib algorithm through the :py:class:`~ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI`.
+Your Learner actors automatically call the implemented
+:py:meth:`~ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI.compute_self_supervised_loss` method to compute the model's own loss,
+passing it the outputs of the :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_train` call.
+
+
+See here for an `example script utilizing a self-supervised loss RLModule `__.
+Losses can be defined over either policy evaluation inputs, or data read from `offline storage `__.
+Note that you may want to set the :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleSpec.learner_only` attribute to ``True`` in your custom
+:py:class:`~ray.rllib.rl_module.rl_module.RLModuleSpec` if you don't need the self-supervised model for collecting samples in your
+:py:class:`~ray.rllib.env.env_runner.EnvRunner` actors. You may also need an extra Learner connector piece in this case to make sure your
+:py:class:`~ray.rllib.rl_module.rl_module.RLModule` receives data to learn from.
+
+
 End-to-end example
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -643,6 +703,86 @@ override the following methods in the :py:class:`~ray.rllib.core.rl_module.rl_mo
 See `torch_distributions.py `__ for common distribution implementations.
 
+Auto-regressive action distributions
+++++++++++++++++++++++++++++++++++++
+
+In an action space with multiple components, for example ``Tuple(a1, a2)``, you may want to condition the sampling of ``a2`` on the sampled value
+of ``a1``, such that ``a2_sampled ~ P(a2 | a1_sampled, obs)``.
+Note that in the default, non-autoregressive case, RLlib would use a default
+model in combination with an independent :py:class:`~ray.rllib.models.torch.torch_distributions.TorchMultiDistribution` and thus
+sample ``a1`` and ``a2`` independently. This makes it impossible to learn in environments in which one action component
+should be sampled conditioned on another, already sampled action component.
+See this `example for a "correlated actions" environment `__.
+
+To write a custom :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` that samples the various
+action components as previously described, you need to carefully implement its forward logic.
+
+See this `example of such an autoregressive action model `__.
+
+You implement the main action sampling logic in the ``_forward_...()`` methods:
+
+.. literalinclude:: ../../../rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
+    :language: python
+    :dedent: 4
+    :start-after: __sphinx_begin__
+    :end-before: __sphinx_end__
+
+
+.. TODO: Move this parametric paragraph back in here, once we have the example translated to the new API stack
+    Variable-length / Parametric Action Spaces
+    ++++++++++++++++++++++++++++++++++++++++++
+    Custom models can be used to work with environments where (1) the set of valid actions `varies per step `__, and/or (2) the number of valid actions is `very large `__. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the ``a`` in ``Q(s, a)`` becomes just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows:
+    1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number:
+    .. code-block:: python
+        class MyParamActionEnv(gym.Env):
+            def __init__(self, max_avail_actions):
+                self.action_space = Discrete(max_avail_actions)
+                self.observation_space = Dict({
+                    "action_mask": Box(0, 1, shape=(max_avail_actions, )),
+                    "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)),
+                    "real_obs": ...,
+                })
+    2. A custom model can be defined that can interpret the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero:
+    .. code-block:: python
+        class ParametricActionsModel(TFModelV2):
+            def __init__(self,
+                         obs_space,
+                         action_space,
+                         num_outputs,
+                         model_config,
+                         name,
+                         true_obs_shape=(4,),
+                         action_embed_size=2):
+                super(ParametricActionsModel, self).__init__(
+                    obs_space, action_space, num_outputs, model_config, name)
+                self.action_embed_model = FullyConnectedNetwork(...)
+            def forward(self, input_dict, state, seq_lens):
+                # Extract the available actions tensor from the observation.
+                avail_actions = input_dict["obs"]["avail_actions"]
+                action_mask = input_dict["obs"]["action_mask"]
+                # Compute the predicted action embedding
+                action_embed, _ = self.action_embed_model({
+                    "obs": input_dict["obs"]["cart"]
+                })
+                # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
+                # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
+ intent_vector = tf.expand_dims(action_embed, 1) + # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS]. + action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + return action_logits + inf_mask, state + Depending on your use case it may make sense to use |just the masking|_, |just action embeddings|_, or |both|_. For a runnable example of "just action embeddings" in code, + check out `examples/parametric_actions_cartpole.py `__. + .. |just the masking| replace:: just the **masking** + .. _just the masking: https://github.com/ray-project/ray/blob/master/rllib/examples/_old_api_stack/models/action_mask_model.py + .. |just action embeddings| replace:: just action **embeddings** + .. _just action embeddings: https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py + .. |both| replace:: **both** + .. _both: https://github.com/ray-project/ray/blob/master/rllib/examples/_old_api_stack/models/parametric_actions_model.py + Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``model.vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview `__. + + + Implementing custom MultiRLModules ---------------------------------- diff --git a/doc/source/rllib/rllib-advanced-api.rst b/doc/source/rllib/rllib-advanced-api.rst index 703c4b85bcbf5..b80c4f8969768 100644 --- a/doc/source/rllib/rllib-advanced-api.rst +++ b/doc/source/rllib/rllib-advanced-api.rst @@ -14,7 +14,7 @@ Tune will call ``train()`` on your algorithm once per training iteration and rep the new training results. Sometimes, it's desirable to have full control over training, but still run inside Tune. Tune supports using :ref:`custom trainable functions ` to -implement `custom training workflows (example) `__. +implement `custom training workflows (example) `__. Curriculum learning ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst index 3b88c4c265007..90f8a17363141 100644 --- a/doc/source/rllib/rllib-algorithms.rst +++ b/doc/source/rllib/rllib-algorithms.rst @@ -59,6 +59,7 @@ Proximal Policy Optimization (PPO) .. figure:: images/algos/ppo-architecture.svg :width: 750 + :align: left **PPO architecture:** In a training iteration, PPO performs three major steps: 1. Sampling a set of episodes or episode fragments @@ -74,7 +75,7 @@ Proximal Policy Optimization (PPO) `Pendulum-v1 `__. -**PPO-specific configs** (see also `common configs `__): +**PPO-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.ppo.ppo.PPOConfig :members: training @@ -92,6 +93,7 @@ Deep Q Networks (DQN, Rainbow, Parametric DQN) .. figure:: images/algos/dqn-architecture.svg :width: 650 + :align: left **DQN architecture:** DQN uses a replay buffer to temporarily store episode samples that RLlib collects from the environment. 
Throughout different training iterations, these episodes and episode fragments are re-sampled from the buffer and re-used @@ -121,7 +123,7 @@ See also how to use `parametric-actions in DQN `__): +**DQN-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.dqn.dqn.DQNConfig :members: training @@ -137,6 +139,7 @@ Soft Actor Critic (SAC) .. figure:: images/algos/sac-architecture.svg :width: 750 + :align: left **SAC architecture:** SAC uses a replay buffer to temporarily store episode samples that RLlib collects from the environment. Throughout different training iterations, these episodes and episode fragments are re-sampled from the buffer and re-used @@ -148,9 +151,9 @@ Soft Actor Critic (SAC) **Tuned examples:** `Pendulum-v1 `__, -`HalfCheetah-v3 `__, +`HalfCheetah-v3 `__, -**SAC-specific configs** (see also `common configs `__): +**SAC-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.sac.sac.SACConfig :members: training @@ -173,6 +176,7 @@ Asynchronous Proximal Policy Optimization (APPO) .. figure:: images/algos/appo-architecture.svg :width: 750 + :align: left **APPO architecture:** APPO is an asynchronous variant of :ref:`Proximal Policy Optimization (PPO) ` based on the IMPALA architecture, but using a surrogate policy loss with clipping, allowing for multiple SGD passes per collected train batch. @@ -190,7 +194,7 @@ Asynchronous Proximal Policy Optimization (APPO) `Pong-v5 `__ `HalfCheetah-v4 `__ -**APPO-specific configs** (see also `common configs `__): +**APPO-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.appo.appo.APPOConfig :members: training @@ -205,6 +209,7 @@ Importance Weighted Actor-Learner Architecture (IMPALA) .. figure:: images/algos/impala-architecture.svg :width: 750 + :align: left **IMPALA architecture:** In a training iteration, IMPALA requests samples from all EnvRunners asynchronously and the collected episodes are returned to the main algorithm process as Ray references rather than actual objects available on the local process. @@ -229,7 +234,7 @@ Tuned examples: The maximum training throughput reached is ~30k transitions per second (~120k environment frames per second). -**IMPALA-specific configs** (see also `common configs `__): +**IMPALA-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.impala.impala.IMPALAConfig :members: training @@ -248,6 +253,7 @@ DreamerV3 .. figure:: images/algos/dreamerv3-architecture.svg :width: 850 + :align: left **DreamerV3 architecture:** DreamerV3 trains a recurrent WORLD_MODEL in supervised fashion using real environment interactions sampled from a replay buffer. The world model's objective @@ -308,6 +314,7 @@ Behavior Cloning (BC) .. figure:: images/algos/bc-architecture.svg :width: 750 + :align: left **BC architecture:** RLlib's behavioral cloning (BC) uses Ray Data to tap into its parallel data processing capabilities. In one training iteration, BC reads episodes in parallel from @@ -322,7 +329,7 @@ Behavior Cloning (BC) `CartPole-v1 `__ `Pendulum-v1 `__ -**BC-specific configs** (see also `common configs `__): +**BC-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.bc.bc.BCConfig :members: training @@ -337,6 +344,7 @@ Conservative Q-Learning (CQL) .. 
figure:: images/algos/cql-architecture.svg :width: 750 + :align: left **CQL architecture:** CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution through a conservative critic estimate. It adds a simple Q regularizer loss to the standard @@ -347,7 +355,7 @@ Conservative Q-Learning (CQL) **Tuned examples:** `Pendulum-v1 `__ -**CQL-specific configs** and `common configs `__): +**CQL-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.cql.cql.CQLConfig :members: training @@ -362,6 +370,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL) .. figure:: images/algos/marwil-architecture.svg :width: 750 + :align: left **MARWIL architecture:** MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data. When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to plain @@ -374,7 +383,7 @@ Monotonic Advantage Re-Weighted Imitation Learning (MARWIL) **Tuned examples:** `CartPole-v1 `__ -**MARWIL-specific configs** (see also `common configs `__): +**MARWIL-specific configs** (see also :ref:`generic algorithm settings `): .. autoclass:: ray.rllib.algorithms.marwil.marwil.MARWILConfig :members: training @@ -389,10 +398,11 @@ Algorithm Extensions- and Plugins Curiosity-driven Exploration by Self-supervised Prediction ---------------------------------------------------------- `[paper] `__ -`[implementation] `__ +`[implementation] `__ .. figure:: images/algos/curiosity-architecture.svg :width: 850 + :align: left **Intrinsic Curiosity Model (ICM) architecture:** The main idea behind ICM is to train a world-model (in parallel to the "main" policy) to predict the environment's dynamics. The loss of diff --git a/doc/source/rllib/rllib-callback.rst b/doc/source/rllib/rllib-callback.rst index 060b42dd1a486..c7f017e60e816 100644 --- a/doc/source/rllib/rllib-callback.rst +++ b/doc/source/rllib/rllib-callback.rst @@ -296,13 +296,11 @@ computing the average "first-joint angle" of the `theta1` is the angle of the first joint, where an angle of 0.0 indicates that the first link is pointing directly downwards. -This example utilizes RLlib's :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` -API to log the custom computations of the injected code of the Algorithm's main results system. +This example utilizes RLlib's :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` API to log the custom computations +of the injected code. See :ref:`rllib-metric-logger-docs` for more details about the MetricsLogger API. -.. todo: uncomment this once metrics-logger.rst page is online. - Read :ref:`more about the MetricsLogger API here `__ or also ... - -See this more complex example that `generates and logs a PacMan heatmap (image) to WandB `__. +Also, see this more complex example that +`generates and logs a PacMan heatmap (image) to WandB `__. .. testcode:: diff --git a/doc/source/rllib/rllib-env.rst b/doc/source/rllib/rllib-env.rst index d42576bcc734e..349af48a67612 100644 --- a/doc/source/rllib/rllib-env.rst +++ b/doc/source/rllib/rllib-env.rst @@ -63,6 +63,7 @@ action choices eventually maximize the cumulative reward over the agent's lifeti .. figure:: images/envs/single_agent_setup.svg :width: 600 + :align: left **Single-agent setup:** One agent lives in the environment and takes actions computed by a single policy.
The mapping from agent to policy is fixed ("default_agent" maps to "default_policy"). @@ -272,6 +273,7 @@ Performance and Scaling .. figure:: images/envs/env_runners.svg :width: 600 + :align: left **EnvRunner with gym.Env setup:** Environments in RLlib are located within the :py:class:`~ray.rllib.envs.env_runner.EnvRunner` actors, whose number (`n`) you can scale through the `config.env_runners(num_env_runners=..)` setting. Each :py:class:`~ray.rllib.envs.env_runner.EnvRunner` actor diff --git a/doc/source/rllib/rllib-examples.rst b/doc/source/rllib/rllib-examples.rst index d86a461430cdc..5787a525f3709 100644 --- a/doc/source/rllib/rllib-examples.rst +++ b/doc/source/rllib/rllib-examples.rst @@ -238,7 +238,7 @@ Inference of models or policies Learners ++++++++ -- `Custom loss function, simple `__: +- `Custom loss function, simple `__: Implements a custom loss function for training, demonstrating how users can define tailored loss objectives for specific environments or behaviors. diff --git a/doc/source/rllib/rllib-learner.rst b/doc/source/rllib/rllib-learner.rst index 5d3b5ca59dd96..7bf147fc11e97 100644 --- a/doc/source/rllib/rllib-learner.rst +++ b/doc/source/rllib/rllib-learner.rst @@ -242,13 +242,9 @@ Updates results = learner_group.update_from_batch( batch=DUMMY_BATCH, async_update=True, timesteps=TIMESTEPS ) - # `results` is a list of n items (where n is the number of async results collected). + # `results` is a list of n result dicts from various Learner actors. assert isinstance(results, list), results - # Each item in that list is another list of m items (where m is the number of Learner - # workers). - assert isinstance(results[0], list), results - # Each item in the inner list is a result dict from the Learner worker. - assert isinstance(results[0][0], dict), results + assert isinstance(results[0], dict), results When updating a :py:class:`~ray.rllib.core.learner.learner_group.LearnerGroup` you can perform blocking or async updates on batches of data. Async updates are necessary for implementing async algorithms such as APPO/IMPALA. diff --git a/doc/source/rllib/rllib-models.rst b/doc/source/rllib/rllib-models.rst deleted file mode 100644 index 8b475081af3bc..0000000000000 --- a/doc/source/rllib/rllib-models.rst +++ /dev/null @@ -1,639 +0,0 @@ -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - -.. _rllib-models-walkthrough: - -Models, Preprocessors, and Action Distributions -=============================================== - -The following diagram provides a conceptual overview of data flow between different components in RLlib. -We start with an ``Environment``, which - given an action - produces an observation. -The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization) -before being sent to a neural network ``Model``. The model output is in turn -interpreted by an ``ActionDistribution`` to determine the next action. - -.. image:: images/rllib-components.svg - -The components highlighted in green can be replaced with custom user-defined -implementations, as described in the next sections. The purple components are -RLlib internal, which means they can only be modified by changing the algorithm -source code. - -Default Behaviors ------------------ - -Built-in Preprocessors -~~~~~~~~~~~~~~~~~~~~~~ - -RLlib tries to pick one of its built-in preprocessors based on the environment's -observation space. 
The following simple rules apply: - -- Discrete observations are one-hot encoded, e.g. ``Discrete(3) and value=1 -> [0, 1, 0]``. - -- MultiDiscrete observations are encoded by one-hot encoding each discrete element - and then concatenating the respective one-hot encoded vectors. - e.g. ``MultiDiscrete([3, 4]) and value=[1, 3] -> [0 1 0 0 0 0 1]`` because - the first ``1`` is encoded as ``[0 1 0]`` and the second ``3`` is encoded as - ``[0 0 0 1]``; these two vectors are then concatenated to ``[0 1 0 0 0 0 1]``. - -- Tuple and Dict observations are flattened; Discrete and MultiDiscrete - sub-spaces are handled as described above. - Also, the original dict/tuple observations are still available inside a) the Model via the input - dict's "obs" key (the flattened observations are in "obs_flat"), as well as b) the Policy - via the following line of code (e.g. put this into your loss function to access the original - observations): ``dict_or_tuple_obs = restore_original_dimensions(input_dict["obs"], self.obs_space, "tf|torch")`` - -For Atari observation spaces, RLlib defaults to using the `DeepMind preprocessors `__ -(``preprocessor_pref=deepmind``). However, if the Algorithm's config key ``preprocessor_pref`` is set to "rllib", -the following mappings apply for Atari-type observation spaces: - -- Images of shape ``(210, 160, 3)`` are downscaled to ``dim x dim``, where - ``dim`` is a model config key (see default Model config below). Also, you can set - ``grayscale=True`` for reducing the color channel to 1, or ``zero_mean=True`` for - producing -1.0 to 1.0 values (instead of 0.0 to 1.0 values by default). - -- Atari RAM observations (1D space of shape ``(128, )``) are zero-averaged - (values between -1.0 and 1.0). - -In all other cases, no preprocessor will be used and the raw observations from the environment -will be sent directly into your model. - - -Default Model Config Settings -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In the following paragraphs, we will first describe RLlib's default behavior for automatically constructing -models (if you don't set up a custom one), then dive into how you can customize your models by changing these -settings or writing your own model classes. - -By default, RLlib will use the following config settings for your models. -These include options for the ``FullyConnectedNetworks`` (``fcnet_hiddens`` and ``fcnet_activation``), -``VisionNetworks`` (``conv_filters`` and ``conv_activation``), auto-RNN wrapping, auto-Attention (`GTrXL `__) wrapping, -and some special options for Atari environments: - -.. literalinclude:: ../../../rllib/models/catalog.py - :language: python - :start-after: __sphinx_doc_begin__ - :end-before: __sphinx_doc_end__ - -The dict above (or an overriding sub-set) is handed to the Algorithm via the ``model`` key within -the main config dict like so: - -.. code-block:: python - - algo_config = { - # All model-related settings go into this sub-dict. - "model": { - # By default, the MODEL_DEFAULTS dict above will be used. - - # Change individual keys in that dict by overriding them, e.g. - "fcnet_hiddens": [512, 512, 512], - "fcnet_activation": "relu", - }, - - # ... other Algorithm config keys, e.g. "lr" ... - "lr": 0.00001, - } - - -Built-in Models -~~~~~~~~~~~~~~~ - -After preprocessing (if applicable) the raw environment outputs, the processed observations are fed through the policy's model. 
-If no custom model is specified (see further below on how to customize models), RLlib will pick a default model -based on simple heuristics: - -- A vision network (`TF `__ or `Torch `__) - for observations that have a shape of length larger than 2, for example, ``(84 x 84 x 3)``. -- A fully connected network (`TF `__ or `Torch `__) - for everything else. - -These default model types can further be configured via the ``model`` config key inside your Algorithm config (as discussed above). -Available settings are `listed above <#default-model-config-settings>`__ and also documented in the `model catalog file `__. - -Note that for the vision network case, you'll probably have to configure ``conv_filters`` if your environment observations -have custom sizes. For example, ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations. -Always make sure that the last Conv2D output has an output shape of ``[B, 1, 1, X]`` (``[B, X, 1, 1]`` for PyTorch), where B=batch and -X=last Conv2D layer's number of filters, so that RLlib can flatten it. An informative error will be thrown if this isn't the case. - - -.. _auto_lstm_and_attention: - -Built-in auto-LSTM and auto-Attention Wrappers -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In addition, if you set ``"use_lstm": True`` or ``"use_attention": True`` in your model config, -your model's output will be further processed by an LSTM cell -(`TF `__ or `Torch `__), -or an attention (`GTrXL `__) network -(`TF `__ or -`Torch `__), respectively. -More generally, RLlib supports the use of recurrent/attention models for all -its policy-gradient algorithms (A3C, PPO, PG, IMPALA), and the necessary sequence processing support -is built into its policy evaluation utilities. - -See above for the additional config keys you can use to configure these two auto-wrappers in more detail -(e.g. you can specify the size of the LSTM layer by ``lstm_cell_size`` or the attention dim by ``attention_dim``). - -For fully customized RNN/LSTM/Attention-Net setups see the `Recurrent Models <#rnns>`_ and -`Attention Networks/Transformers <#attention>`_ sections below. - -.. note:: - It isn't possible to use both auto-wrappers (lstm and attention) at the same time. Doing so will create an error. - - -Customizing Preprocessors and Models ------------------------------------- - -Custom Preprocessors and Environment Filters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. warning:: - - Custom preprocessors have been fully deprecated, since they sometimes conflict with the built-in preprocessors for handling complex observation spaces. - Please use `wrapper classes `__ around your environment instead of preprocessors. - Note that the built-in **default** Preprocessors described above will still be used and won't be deprecated. - -Instead of using the deprecated custom Preprocessors, you should use ``gym.Wrappers`` to preprocess your environment's output (observations and rewards), -and also your Model's computed actions, before sending them back to the environment. - -For example, for manipulating your env's observations or rewards, do: - -.. code-block:: python - - import gymnasium as gym - import numpy as np - - from ray.rllib.utils.numpy import one_hot - - class OneHotEnv(gym.core.ObservationWrapper): - # Override `observation` to custom process the original observation - # coming from the env. - def observation(self, observation): - # E.g. one-hotting a float obs [0.0, 5.0). 
- return one_hot(observation, depth=5) - - - class ClipRewardEnv(gym.core.RewardWrapper): - def __init__(self, env, min_, max_): - super().__init__(env) - self.min = min_ - self.max = max_ - - # Override `reward` to custom process the original reward coming - # from the env. - def reward(self, reward): - # E.g. simple clipping between min and max. - return np.clip(reward, self.min, self.max) - - -Custom Models: Implementing your own Forward Logic -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you would like to provide your own model logic (instead of using RLlib's built-in defaults), you -can sub-class either ``TFModelV2`` (for TensorFlow) or ``TorchModelV2`` (for PyTorch) and then -register and specify your sub-class in the config as follows: - -.. _tensorflow-models: - -Custom TensorFlow Models -```````````````````````` - -Custom TensorFlow models should subclass `TFModelV2 `__ and implement the ``__init__()`` and ``forward()`` methods. -``forward()`` takes a dict of tensor inputs (mapping str to Tensor types), whose keys and values depend on -the view requirements of the model. -Normally, this input dict contains only the current observation ``obs`` and an ``is_training`` boolean flag, as well as an optional list of RNN states. -``forward()`` should return the model output (of size ``self.num_outputs``) and - if applicable - a new list of internal -states (in case of RNNs or attention nets). You can also override extra methods of the model such as ``value_function`` to implement -a custom value branch. - -Additional supervised/self-supervised losses can be added via the ``TFModelV2.custom_loss`` method: - -.. autoclass:: ray.rllib.models.tf.tf_modelv2.TFModelV2 - :members: - :noindex: - -Once implemented, your TF model can then be registered and used in place of a built-in default one: - -.. code-block:: python - - import ray - import ray.rllib.algorithms.ppo as ppo - from ray.rllib.models import ModelCatalog - from ray.rllib.models.tf.tf_modelv2 import TFModelV2 - - class MyModelClass(TFModelV2): - def __init__(self, obs_space, action_space, num_outputs, model_config, name): ... - def forward(self, input_dict, state, seq_lens): ... - def value_function(self): ... - - ModelCatalog.register_custom_model("my_tf_model", MyModelClass) - - ray.init() - algo = ppo.PPO(env="CartPole-v1", config={ - "model": { - "custom_model": "my_tf_model", - # Extra kwargs to be passed to your model's c'tor. - "custom_model_config": {}, - }, - }) - -See the `keras model example `__ for a full example of a TF custom model. - -More examples and explanations on how to implement custom Tuple/Dict processing models -(also check out `this test case here `__), -custom RNNs, custom model APIs (on top of default models) follow further below. - -.. _torch-models: - -Custom PyTorch Models -````````````````````` - -Similarly, you can create and register custom PyTorch models by subclassing -`TorchModelV2 `__ and implement the ``__init__()`` and ``forward()`` methods. -``forward()`` takes a dict of tensor inputs (mapping str to PyTorch tensor types), whose keys and values depend on -the view requirements of the model. -Usually, the dict contains only the current observation ``obs`` and an ``is_training`` boolean flag, as well as an optional list of RNN states. -``forward()`` should return the model output (of size ``self.num_outputs``) and - if applicable - a new list of internal -states (in case of RNNs or attention nets). 
You can also override extra methods of the model such as ``value_function`` to implement -a custom value branch. - -Additional supervised/self-supervised losses can be added via the ``TorchModelV2.custom_loss`` method: - -See these examples of `fully connected `__, `convolutional `__, and `recurrent `__ torch models. - -.. autoclass:: ray.rllib.models.torch.torch_modelv2.TorchModelV2 - :members: - :noindex: - -Once implemented, your PyTorch model can then be registered and used in place of a built-in model: - -.. code-block:: python - - import torch.nn as nn - - import ray - from ray.rllib.algorithms import ppo - from ray.rllib.models import ModelCatalog - from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 - - class CustomTorchModel(TorchModelV2): - def __init__(self, obs_space, action_space, num_outputs, model_config, name): ... - def forward(self, input_dict, state, seq_lens): ... - def value_function(self): ... - - ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel) - - ray.init() - algo = ppo.PPO(env="CartPole-v1", config={ - "framework": "torch", - "model": { - "custom_model": "my_torch_model", - # Extra kwargs to be passed to your model's c'tor. - "custom_model_config": {}, - }, - }) - -See the `torch model examples `__ for various examples on how to build a custom -PyTorch model (including recurrent ones). - -More examples and explanations on how to implement custom Tuple/Dict processing models (also check out `this test case here `__), -custom RNNs, custom model APIs (on top of default models) follow further below. - - -Implementing custom Recurrent Networks -`````````````````````````````````````` - -Instead of using the ``use_lstm: True`` option, it may be preferable to use a custom recurrent model. -This provides more control over postprocessing the LSTM's output and can also allow the use of multiple LSTM cells to process different portions of the input. -For an RNN model, it is recommended to subclass ``RecurrentNetwork`` (either the `TF `__ -or `PyTorch `__ versions) and then implement ``__init__()``, -``get_initial_state()``, and ``forward_rnn()``. - -.. autoclass:: ray.rllib.models.tf.recurrent_net.RecurrentNetwork - - .. automethod:: __init__ - .. automethod:: get_initial_state - .. automethod:: forward_rnn - -Note that the ``inputs`` arg entering ``forward_rnn`` is already a time-ranked single tensor (not an ``input_dict``!) with shape ``(B x T x ...)``. -If you further want to customize and need more direct access to the complete (non time-ranked) ``input_dict``, you can also override -your Model's ``forward`` method directly (as you would do with a non-RNN ModelV2). In that case, though, you are responsible for changing your inputs -and adding the time rank to the incoming data (usually you just have to reshape). - -You can check out the `rnn_model.py `__ models as examples to implement -your own (either TF or Torch). - - -.. _attention: - -Implementing custom Attention Networks -`````````````````````````````````````` - -Similar to the RNN case described above, you could also implement your own attention-based networks, instead of using the -``use_attention: True`` flag in your model config. - -See RLlib's `GTrXL (Attention Net) `__ implementations -(for `TF `__ and `PyTorch `__) -to get a better idea of how to write your own models of this type. These are the models we use -as wrappers when ``use_attention=True``. - -This `test case `__ confirms their learning capabilities in PPO and IMPALA.
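For contrast with the fully custom route, here is a minimal, hedged sketch of the auto-wrapper path: enabling the GTrXL wrapper purely through the model config. The specific ``attention_*`` values below are illustrative defaults taken from the default model config shown above, not tuned settings:

.. code-block:: python

    config = {
        "env": "CartPole-v1",
        "model": {
            # Wrap the default model with a GTrXL attention net.
            "use_attention": True,
            # Illustrative sizing only; see the default model config
            # above for the full set of attention_* keys.
            "attention_dim": 64,
            "attention_num_transformer_units": 1,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
        },
    }

Remember that the two auto-wrappers are mutually exclusive; combining this with ``"use_lstm": True`` produces an error, as noted above.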
- -Batch Normalization -``````````````````` - -You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model. -See this `code example `__. -RLlib will automatically run the update ops for the batch norm layers during optimization -(see `tf_policy.py `__ and -`multi_gpu_learner_thread.py `__ for the exact handling of these updates). - -In case RLlib does not properly detect the update ops for your custom model, you can override the ``update_ops()`` method to return the list of ops to run for updates. - - -Custom Model APIs (on Top of Default- or Custom Models) -``````````````````````````````````````````````````````` - -So far we talked about a) the default models that are built into RLlib and are being provided -automatically if you don't specify anything in your Algorithm's config and b) custom Models through -which you can define any arbitrary forward passes. - -Another typical situation in which you would have to customize a model would be to -add a new API that your algorithm needs in order to learn, for example a Q-value -calculating head on top of your policy model. In order to expand a Model's API, simply -define and implement a new method (e.g. ``get_q_values()``) in your TF- or TorchModelV2 sub-class. - -You can now wrap this new API either around RLlib's default models or around -your custom (``forward()``-overriding) model classes. - - -More examples for Building Custom Models -```````````````````````````````````````` - -**A multi-input capable model for Tuple observation spaces (for PPO)** - -RLlib's default preprocessor for Tuple and Dict spaces is to flatten incoming observations -into one flat **1D** array, and then pick a fully connected network (by default) to -process this flattened vector. This is usually ok if you have only 1D Box or -Discrete/MultiDiscrete sub-spaces in your observations. - -However, what if you had a complex observation space with one or more image components in -it (besides 1D Boxes and discrete spaces)? You would probably want to preprocess each of the -image components using some convolutional network, and then concatenate their outputs -with the remaining non-image (flat) inputs (the 1D Box and discrete/one-hot components). - -Take a look at this model example that does exactly that: - -.. literalinclude:: ../../../rllib/models/tf/complex_input_net.py - :language: python - :start-after: __sphinx_doc_begin__ - :end-before: __sphinx_doc_end__ - - - -Custom Action Distributions ---------------------------- - -Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the ``model``, which you can use to access ``model.model_config`` or other attributes of the model. This is commonly used to implement `autoregressive action outputs <#autoregressive-action-distributions>`__. - -.. code-block:: python - - import ray - import ray.rllib.algorithms.ppo as ppo - from ray.rllib.models import ModelCatalog - from ray.rllib.models.action_dist import ActionDistribution - - class MyActionDist(ActionDistribution): - @staticmethod - def required_model_output_shape(action_space, model_config): - return 7 # controls model output feature vector size - - def __init__(self, inputs, model): - super(MyActionDist, self).__init__(inputs, model) - assert model.num_outputs == 7 - - def sample(self): ... - def logp(self, actions): ... - def entropy(self): ...
- - ModelCatalog.register_custom_action_dist("my_dist", MyActionDist) - - ray.init() - algo = ppo.PPO(env="CartPole-v1", config={ - "model": { - "custom_action_dist": "my_dist", - }, - }) - -Supervised Model Losses ------------------------ - -You can mix supervised losses into any RLlib algorithm through custom models. For example, you can add an imitation learning loss on expert experiences, or a self-supervised autoencoder loss within the model. These losses can be defined over either policy evaluation inputs, or data read from `offline storage `__. - -**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``metrics()`` method. - -**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. - -Self-Supervised Model Losses ----------------------------- - -You can also use the ``custom_loss()`` API to add in self-supervised losses such as VAE reconstruction loss and L2-regularization. - -Variable-length / Complex Observation Spaces --------------------------------------------- - -RLlib supports complex and variable-length observation spaces, including ``gym.spaces.Tuple``, ``gym.spaces.Dict``, and ``rllib.utils.spaces.Repeated``. The handling of these spaces is transparent to the user. RLlib internally will insert preprocessors to insert padding for repeated elements, flatten complex observations into a fixed-size vector during transit, and unpack the vector into the structured tensor before sending it to the model. The flattened observation is available to the model as ``input_dict["obs_flat"]``, and the unpacked observation as ``input_dict["obs"]``. - -To enable batching of struct observations, RLlib unpacks them in a `StructTensor-like format `__. In summary, repeated fields are "pushed down" and become the outer dimensions of tensor batches, as illustrated in this figure from the StructTensor RFC. - -.. image:: images/struct-tensor.png - -For further information about complex observation spaces, see: - * A custom environment and model that uses `repeated struct fields `__. - * The pydoc of the `Repeated space `__. - * The pydoc of the batched `repeated values tensor `__. - * The `unit tests `__ for Tuple and Dict spaces. - -Variable-length / Parametric Action Spaces ------------------------------------------- - -Custom models can be used to work with environments where (1) the set of valid actions `varies per step `__, and/or (2) the number of valid actions is `very large `__. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the ``a`` in ``Q(s, a)`` becomes just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows: - -1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number: - -.. 
code-block:: python - - class MyParamActionEnv(gym.Env): - def __init__(self, max_avail_actions): - self.action_space = Discrete(max_avail_actions) - self.observation_space = Dict({ - "action_mask": Box(0, 1, shape=(max_avail_actions, )), - "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)), - "real_obs": ..., - }) - -2. A custom model can be defined that can interpret the ``action_mask`` and ``avail_actions`` portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero: - -.. code-block:: python - - class ParametricActionsModel(TFModelV2): - def __init__(self, - obs_space, - action_space, - num_outputs, - model_config, - name, - true_obs_shape=(4,), - action_embed_size=2): - super(ParametricActionsModel, self).__init__( - obs_space, action_space, num_outputs, model_config, name) - self.action_embed_model = FullyConnectedNetwork(...) - - def forward(self, input_dict, state, seq_lens): - # Extract the available actions tensor from the observation. - avail_actions = input_dict["obs"]["avail_actions"] - action_mask = input_dict["obs"]["action_mask"] - - # Compute the predicted action embedding - action_embed, _ = self.action_embed_model({ - "obs": input_dict["obs"]["cart"] - }) - - # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the - # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. - intent_vector = tf.expand_dims(action_embed, 1) - - # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS]. - action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2) - - # Mask out invalid actions (use tf.float32.min for stability) - inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) - return action_logits + inf_mask, state - - -Depending on your use case it may make sense to use |just the masking|_, |just action embeddings|_, or |both|_. For a runnable example of "just action embeddings" in code, -check out `examples/parametric_actions_cartpole.py `__. - -.. |just the masking| replace:: just the **masking** -.. _just the masking: https://github.com/ray-project/ray/blob/master/rllib/examples/_old_api_stack/models/action_mask_model.py -.. |just action embeddings| replace:: just action **embeddings** -.. _just action embeddings: https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py -.. |both| replace:: **both** -.. _both: https://github.com/ray-project/ray/blob/master/rllib/examples/_old_api_stack/models/parametric_actions_model.py - -Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``model.vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview `__. - - -Autoregressive Action Distributions ------------------------------------ - -In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy. 
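Written out, this autoregressive setup factorizes the joint action distribution with the chain rule (the block below simply restates the sampling relation above in distributional form):

.. math::

    P(a_1, a_2 \mid s) = P(a_1 \mid s) \cdot P(a_2 \mid a_1, s)

whereas the default, independent treatment corresponds to :math:`P(a_1 \mid s) \cdot P(a_2 \mid s)` and therefore can't capture correlations between the two action components.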
- -To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py `__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a `MADE `__ is recommended. Note that sampling an `N`-part action requires `N` forward passes through the model; computing the log probability of an action, however, can be done in a single pass: - -.. code-block:: python - - class BinaryAutoregressiveOutput(ActionDistribution): - """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)""" - - @staticmethod - def required_model_output_shape(action_space, model_config): - return 16 # controls model output feature vector size - - def sample(self): - # first, sample a1 - a1_dist = self._a1_distribution() - a1 = a1_dist.sample() - - # sample a2 conditioned on a1 - a2_dist = self._a2_distribution(a1) - a2 = a2_dist.sample() - - # return the action tuple - return TupleActions([a1, a2]) - - def logp(self, actions): - a1, a2 = actions[:, 0], actions[:, 1] - a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) - a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec]) - return (Categorical(a1_logits, None).logp(a1) + Categorical( - a2_logits, None).logp(a2)) - - def _a1_distribution(self): - BATCH = tf.shape(self.inputs)[0] - a1_logits, _ = self.model.action_model( - [self.inputs, tf.zeros((BATCH, 1))]) - a1_dist = Categorical(a1_logits, None) - return a1_dist - - def _a2_distribution(self, a1): - a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) - _, a2_logits = self.model.action_model([self.inputs, a1_vec]) - a2_dist = Categorical(a2_logits, None) - return a2_dist - - class AutoregressiveActionsModel(TFModelV2): - """Implements the `.action_model` branch required above.""" - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - super(AutoregressiveActionsModel, self).__init__( - obs_space, action_space, num_outputs, model_config, name) - if action_space != Tuple([Discrete(2), Discrete(2)]): - raise ValueError( - "This model only supports the [2, 2] action space") - - # Inputs - obs_input = tf.keras.layers.Input( - shape=obs_space.shape, name="obs_input") - a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input") - ctx_input = tf.keras.layers.Input( - shape=(num_outputs, ), name="ctx_input") - - # Output of the model (normally 'logits', but for an autoregressive - # dist this is more like a context/feature layer encoding the obs) - context = tf.keras.layers.Dense( - num_outputs, - name="hidden", - activation=tf.nn.tanh, - kernel_initializer=normc_initializer(1.0))(obs_input) - - # P(a1 | obs) - a1_logits = tf.keras.layers.Dense( - 2, - name="a1_logits", - activation=None, - kernel_initializer=normc_initializer(0.01))(ctx_input) - - # P(a2 | a1) - # --note: typically you'd want to implement P(a2 | a1, obs) as follows: - # a2_context = tf.keras.layers.Concatenate(axis=1)( - # [ctx_input, a1_input]) - a2_context = a1_input - a2_hidden = tf.keras.layers.Dense( - 16, - name="a2_hidden", - activation=tf.nn.tanh, - kernel_initializer=normc_initializer(1.0))(a2_context) - a2_logits = tf.keras.layers.Dense( - 2, - name="a2_logits", - activation=None, - kernel_initializer=normc_initializer(0.01))(a2_hidden) - - # Base layers - self.base_model = tf.keras.Model(obs_input, context) - self.register_variables(self.base_model.variables) - self.base_model.summary() - - # Autoregressive action sampler - 
self.action_model = tf.keras.Model([ctx_input, a1_input], - [a1_logits, a2_logits]) - self.action_model.summary() - self.register_variables(self.action_model.variables) - - - -.. note:: - - Not all algorithms support autoregressive action distributions; see the `algorithm overview table `__ for more information. diff --git a/doc/source/rllib/rllib-torch2x.rst b/doc/source/rllib/rllib-torch2x.rst index 06c7476e77bd8..11d9afc2bad49 100644 --- a/doc/source/rllib/rllib-torch2x.rst +++ b/doc/source/rllib/rllib-torch2x.rst @@ -6,7 +6,7 @@ Using RLlib with torch 2.x compile ================================== -torch 2.x comes with the ``torch.compile()`` `API `_, which can be used to JIT-compile wrapped code. We integrate ``torch.compile()`` with RLlib in the context of `RLModules `_ and Learners. +torch 2.x comes with the ``torch.compile()`` `API `_, which can be used to JIT-compile wrapped code. We integrate ``torch.compile()`` with RLlib in the context of `RLModules `_ and Learners. We have integrated this feature with RLModules. You can set the backend and mode via ``framework()`` API on an :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` object. Alternatively, you can compile the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` directly during stand-alone usage, such as inference. diff --git a/doc/source/rllib/rllib-training.rst b/doc/source/rllib/rllib-training.rst deleted file mode 100644 index b52f7a50816af..0000000000000 --- a/doc/source/rllib/rllib-training.rst +++ /dev/null @@ -1,248 +0,0 @@ -.. include:: /_includes/rllib/we_are_hiring.rst - -.. include:: /_includes/rllib/new_api_stack.rst - -.. _rllib-getting-started: - -Getting Started with RLlib -========================== - -All RLlib experiments are run using an ``Algorithm`` class which holds a policy for environment interaction. -Through the algorithm's interface, you can train the policy, compute actions, or store your algorithm's state (checkpointing). -In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once. - -.. image:: images/rllib-api.svg - -In this guide, we will explain in detail RLlib's Python API for running learning experiments. - - -.. _rllib-training-api: - -Using the Python API --------------------- - -The Python API provides all the flexibility required for applying RLlib to any type of problem. - -Let's start with an example of the API's basic usage. -We first create a `PPOConfig` instance and set some properties through the config class' various methods. -For example, we can set the RL environment we want to use by calling the config's `environment` method. -To scale our algorithm and define how many environment workers (EnvRunners) we want to leverage, we can call -the `env_runners` method. -After we `build` the `PPO` Algorithm from its configuration, we can `train` it for a number of -iterations (here `10`) and `save` the resulting policy periodically (here every `5` iterations). - -.. literalinclude:: ./doc_code/getting_started.py - :language: python - :start-after: rllib-first-config-begin - :end-before: rllib-first-config-end - - -All RLlib algorithms are compatible with the :ref:`Tune API `. -This enables them to be easily used in experiments with :ref:`Ray Tune `. -For example, the following code performs a simple hyper-parameter sweep of PPO. - - -.. 
literalinclude:: ./doc_code/getting_started.py - :dedent: 4 - :language: python - :start-after: rllib-tune-config-begin - :end-before: rllib-tune-config-end - -Tune will schedule the trials to run in parallel on your Ray cluster: - -:: - - == Status == - Using FIFO scheduling algorithm. - Resources requested: 4/4 CPUs, 0/0 GPUs - Result logdir: ~/ray_results/my_experiment - PENDING trials: - - PPO_CartPole-v1_2_lr=0.0001: PENDING - RUNNING trials: - - PPO_CartPole-v1_0_lr=0.01: RUNNING [pid=21940], 16 s, 4013 ts, 22 rew - - PPO_CartPole-v1_1_lr=0.001: RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew - -``Tuner.fit()`` returns a ``ResultGrid`` object that allows further analysis -of the training results and retrieval of the checkpoint(s) of the trained agent. - -.. literalinclude:: ./doc_code/getting_started.py - :dedent: 0 - :language: python - :start-after: rllib-tuner-begin - :end-before: rllib-tuner-end - -.. note:: - - You can find your checkpoint's version by - looking into the ``rllib_checkpoint.json`` file inside your checkpoint directory. - -Loading and restoring a trained algorithm from a checkpoint is simple. -Let's assume you have a local checkpoint directory called ``checkpoint_path``. -To load newer RLlib checkpoints (version >= 1.0), use the following code: - - -.. code-block:: python - - from ray.rllib.algorithms.algorithm import Algorithm - algo = Algorithm.from_checkpoint(checkpoint_path) - - -For older RLlib checkpoint versions (version < 1.0), you can -restore an algorithm through: - -.. code-block:: python - - from ray.rllib.algorithms.ppo import PPO - algo = PPO(config=config, env=env_class) - algo.restore(checkpoint_path) - - -Computing Actions -~~~~~~~~~~~~~~~~~ - -The simplest way to programmatically compute actions from a trained agent is to -use ``Algorithm.compute_single_action()``. -This method preprocesses and filters the observation before passing it to the agent -policy. -Here is a simple example of testing a trained agent for one episode: - -.. literalinclude:: ./doc_code/getting_started.py - :language: python - :start-after: rllib-compute-action-begin - :end-before: rllib-compute-action-end - -For more advanced usage on computing actions and other functionality, -you can consult the :ref:`RLlib Algorithm API documentation `. - -Accessing Policy State -~~~~~~~~~~~~~~~~~~~~~~ - -It is common to need to access an algorithm's internal state, for instance to set -or get model weights. - -In RLlib, algorithm state is replicated across multiple *rollout workers* (Ray actors) -in the cluster. -However, you can easily get and update this state between calls to ``train()`` -via ``Algorithm.env_runner_group.foreach_worker()`` -or ``Algorithm.env_runner_group.foreach_worker_with_index()``. -These functions take a lambda function that is applied with the worker as an argument, -and they return values for each worker as a list. - -You can also access just the "master" copy of the algorithm state through -``Algorithm.get_policy()`` or ``Algorithm.env_runner``, -but note that updates here may not be immediately reflected in -your rollout workers (if you have configured ``num_env_runners > 0``). -Here's a quick example of how to access the state of a model: - -.. literalinclude:: ./doc_code/getting_started.py - :language: python - :start-after: rllib-get-state-begin - :end-before: rllib-get-state-end - -Accessing Model State -~~~~~~~~~~~~~~~~~~~~~ - -Similar to accessing policy state, you may want to get a reference to the -underlying neural network model being trained.
For example, you may want to -pre-train it separately, or otherwise update its weights outside of RLlib. -This can be done by accessing the ``model`` of the policy. - -.. note:: - - To run these examples, you need to install a few extra dependencies, namely - `pip install "gym[atari]" "gym[accept-rom-license]" atari_py`. - -Below you find three examples showing how to access the model state of -an algorithm. - -.. dropdown:: **Example: Preprocessing observations for feeding into a model** - - .. literalinclude:: doc_code/training.py - :language: python - :start-after: __preprocessing_observations_start__ - :end-before: __preprocessing_observations_end__ - -.. dropdown:: **Example: Querying a policy's action distribution** - - .. literalinclude:: doc_code/training.py - :language: python - :start-after: __query_action_dist_start__ - :end-before: __query_action_dist_end__ - -.. dropdown:: **Example: Getting Q values from a DQN model** - - .. literalinclude:: doc_code/training.py - :language: python - :start-after: __get_q_values_dqn_start__ - :end-before: __get_q_values_dqn_end__ - - This is especially useful when used with - `custom model classes `__. - - -Debugging RLlib Experiments ---------------------------- - -Eager Mode -~~~~~~~~~~ - -Policies built with ``build_tf_policy`` (most of the reference algorithms are) -can be run in eager mode by setting the -``"framework": "tf2"`` / ``"eager_tracing": true`` config options. -This will tell RLlib to execute the model forward pass, action distribution, -loss, and stats functions in eager mode. - -Eager mode makes debugging much easier, since you can now use line-by-line -debugging with breakpoints or Python ``print()`` to inspect -intermediate tensor values. -However, eager can be slower than graph mode unless tracing is enabled. - - -Episode Traces -~~~~~~~~~~~~~~ - -You can use the `data output API `__ to save episode traces -for debugging. For example, the following command will run PPO while saving episode -traces to ``/tmp/debug``: - -.. code-block:: bash - - cd rllib/tuned_examples/ppo - python cartpole_ppo.py --output /tmp/debug - - # episode traces will be saved in /tmp/debug, for example - output-2019-02-23_12-02-03_worker-2_0.json - output-2019-02-23_12-02-04_worker-1_0.json - -Log Verbosity -~~~~~~~~~~~~~ - -You can control the log level via the ``"log_level"`` flag. Valid values are "DEBUG", -"INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the -verbosity of internal logging. -For example: - -.. code-block:: bash - - cd rllib/tuned_examples/ppo - - python atari_ppo.py --env ALE/Pong-v5 --log-level INFO - python atari_ppo.py --env ALE/Pong-v5 --log-level DEBUG - -The default log level is ``WARN``. We strongly recommend using at least ``INFO`` -level logging for development. - -Stack Traces -~~~~~~~~~~~~ - -You can use the ``ray stack`` command to dump the stack traces of all the -Python workers on a single node. This can be useful for debugging unexpected -hangs or performance issues. - -Next Steps ----------- - -- To check how your application is doing, you can use the :ref:`Ray dashboard `.
diff --git a/doc/source/rllib/scaling-guide.rst b/doc/source/rllib/scaling-guide.rst index e30770bd08113..3f80c824f5faa 100644 --- a/doc/source/rllib/scaling-guide.rst +++ b/doc/source/rllib/scaling-guide.rst @@ -158,7 +158,7 @@ Make sure to set the number of GPUs per :py:class:`~ray.rllib.core.learner.learn The number of GPUs may be fractional quantities, for example 0.5, to allocate only a fraction of a GPU per :py:class:`~ray.rllib.env.env_runner.EnvRunner`. For example, you can pack five :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` instances onto one GPU by setting ``num_learners=1, num_gpus_per_learner=0.2``. -See this `fractional GPU example `__ +See this `fractional GPU example `__ for details. .. note:: diff --git a/doc/source/rllib/user-guides.rst b/doc/source/rllib/user-guides.rst index 340be7ed93c6f..4b12fdf0fb3a4 100644 --- a/doc/source/rllib/user-guides.rst +++ b/doc/source/rllib/user-guides.rst @@ -12,12 +12,12 @@ User Guides rllib-advanced-api rllib-callback - rllib-models checkpoints + metrics-logger single-agent-episode rllib-replay-buffers rllib-offline - rllib-rlmodule + rl-modules rllib-learner rllib-torch2x rllib-fault-tolerance @@ -53,9 +53,9 @@ RLlib Feature Guides :img-top: /rllib/images/rllib-logo.svg :class-img-top: pt-2 w-75 d-block mx-auto fixed-height-img - .. button-ref:: rllib-models + .. button-ref:: metrics-logger - Working with models, preprocessors and action distributions + Logging metrics and statistics from custom code .. grid-item-card:: :img-top: /rllib/images/rllib-logo.svg diff --git a/doc/source/tune/api/syncing.rst b/doc/source/tune/api/syncing.rst index 8628fd6a25eaf..d3a1da18cdf0c 100644 --- a/doc/source/tune/api/syncing.rst +++ b/doc/source/tune/api/syncing.rst @@ -1,5 +1,5 @@ -Syncing in Tune (train.SyncConfig) -================================== +Syncing in Tune +=============== .. seealso:: @@ -13,6 +13,6 @@ Tune Syncing Configuration .. autosummary:: :nosignatures: + :toctree: doc/ - ray.train.SyncConfig - :noindex: + ~ray.tune.SyncConfig diff --git a/docker/retag-lambda/README.md b/docker/retag-lambda/README.md index 5a14279906a63..77b8a7ba41555 100644 --- a/docker/retag-lambda/README.md +++ b/docker/retag-lambda/README.md @@ -14,4 +14,3 @@ zip retag-lambda.zip * ``` 3. Head to the AWS Management console & select the `DockerTagLatest` function. Select `Upload from`, then `.zip file` and then select the zip file created in Step 2. 
- diff --git a/docker/retag-lambda/cuda_versions.txt b/docker/retag-lambda/cuda_versions.txt index 64e9c9e806d61..5e299664ff081 100644 --- a/docker/retag-lambda/cuda_versions.txt +++ b/docker/retag-lambda/cuda_versions.txt @@ -1,3 +1,4 @@ +cu125 cu124 cu123 cu121 diff --git a/java/generate_jni_header_files.sh b/java/generate_jni_header_files.sh index 64e4fa61d412d..97d86de83f24c 100755 --- a/java/generate_jni_header_files.sh +++ b/java/generate_jni_header_files.sh @@ -44,4 +44,3 @@ generate_one runtime/src/main/java/io/ray/runtime/metric/NativeMetric.java io.ra # Remove empty files rm -f io_ray_runtime_RayNativeRuntime_AsyncContext.h rm -f io_ray_runtime_task_NativeTaskExecutor_NativeActorContext.h - diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index fc6de702ab094..4394ff773570d 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -90,7 +90,7 @@ def py_func_pass_python_actor_handle(): @ray.remote def py_func_python_raise_exception(): - 1 / 0 + _ = 1 / 0 @ray.remote diff --git a/pyproject.toml b/pyproject.toml index 4e0f118d1c7a8..9769a99bc11af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,15 +30,12 @@ ignore = [ # TODO(MortalHappiness): Remove the following rules from the ignore list # The above are rules ignored originally in flake8 # The following are rules ignored in ruff - "F841", - "B018", "B023", "B024", "B026", "B027", "B035", "B904", - "C416", "C419", # Below are auto-fixable rules "I001", diff --git a/python/ray/_private/gcs_aio_client.py b/python/ray/_private/gcs_aio_client.py index 5ef7b64b1017c..79ac7766a2f59 100644 --- a/python/ray/_private/gcs_aio_client.py +++ b/python/ray/_private/gcs_aio_client.py @@ -45,3 +45,4 @@ def __init__( self.get_all_actor_info = self.inner.async_get_all_actor_info self.get_all_node_info = self.inner.async_get_all_node_info self.kill_actor = self.inner.async_kill_actor + self.get_cluster_status = self.inner.async_get_cluster_status diff --git a/python/ray/_private/ray_logging/constants.py b/python/ray/_private/ray_logging/constants.py index 54552bdfe1d75..6accad1200064 100644 --- a/python/ray/_private/ray_logging/constants.py +++ b/python/ray/_private/ray_logging/constants.py @@ -1,7 +1,8 @@ from enum import Enum # A set containing the standard attributes of a LogRecord. This is used to -# help us determine which attributes constitute Ray or user-provided context. +# help us determine which attributes constitute Ray or user-provided context. It is +# also used to determine whether an attribute is a standard Python logging attribute.
# http://docs.python.org/library/logging.html#logrecord-attributes LOGRECORD_STANDARD_ATTRS = { "args", diff --git a/python/ray/_private/ray_logging/formatters.py b/python/ray/_private/ray_logging/formatters.py index 6324659dc4f02..ec16a629b87ee 100644 --- a/python/ray/_private/ray_logging/formatters.py +++ b/python/ray/_private/ray_logging/formatters.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod import logging import json from ray._private.log import INTERNAL_TIMESTAMP_LOG_KEY @@ -7,7 +8,7 @@ LOGGER_FLATTEN_KEYS, ) from ray._private.ray_constants import LOGGER_FORMAT -from typing import Any, Dict +from typing import Any, Dict, List def _append_flatten_attributes(formatted_attrs: Dict[str, Any], key: str, value: Any): @@ -30,60 +31,81 @@ def _append_flatten_attributes(formatted_attrs: Dict[str, Any], key: str, value: formatted_attrs[key] = value -def generate_record_format_attrs( - formatter: logging.Formatter, - record: logging.LogRecord, - exclude_standard_attrs, -) -> dict: - record_format_attrs = {} - - # If `exclude_standard_attrs` is False, include the standard attributes. - # Otherwise, include only Ray and user-provided context. - if not exclude_standard_attrs: - record_format_attrs.update( - { - LogKey.ASCTIME.value: formatter.formatTime(record), - LogKey.LEVELNAME.value: record.levelname, - LogKey.MESSAGE.value: record.getMessage(), - LogKey.FILENAME.value: record.filename, - LogKey.LINENO.value: record.lineno, - } - ) - if record.exc_info: - if not record.exc_text: - record.exc_text = formatter.formatException(record.exc_info) - record_format_attrs[LogKey.EXC_TEXT.value] = record.exc_text - - for key, value in record.__dict__.items(): - # Both Ray and user-provided context are stored in `record_format`. - if key not in LOGRECORD_STANDARD_ATTRS: - _append_flatten_attributes(record_format_attrs, key, value) - - # Format the internal timestamp to the standardized `timestamp_ns` key. - if INTERNAL_TIMESTAMP_LOG_KEY in record_format_attrs: - record_format_attrs[LogKey.TIMESTAMP_NS.value] = record_format_attrs.pop( - INTERNAL_TIMESTAMP_LOG_KEY - ) +class AbstractFormatter(logging.Formatter, ABC): + def __init__(self, fmt=None, datefmt=None, style="%", validate=True) -> None: + super().__init__(fmt, datefmt, style, validate) + self._additional_log_standard_attrs = [] + + def set_additional_log_standard_attrs( + self, additional_log_standard_attrs: List[str] + ) -> None: + self._additional_log_standard_attrs = additional_log_standard_attrs + + def generate_record_format_attrs( + self, + record: logging.LogRecord, + exclude_default_standard_attrs, + ) -> dict: + record_format_attrs = {} + + # If `exclude_default_standard_attrs` is False, include the standard attributes. + # Otherwise, include only Ray and user-provided context. + if not exclude_default_standard_attrs: + record_format_attrs.update( + { + LogKey.ASCTIME.value: self.formatTime(record), + LogKey.LEVELNAME.value: record.levelname, + LogKey.MESSAGE.value: record.getMessage(), + LogKey.FILENAME.value: record.filename, + LogKey.LINENO.value: record.lineno, + } + ) + if record.exc_info: + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + record_format_attrs[LogKey.EXC_TEXT.value] = record.exc_text + + # Add the user specified additional standard attributes. 
+ for key in self._additional_log_standard_attrs: + _append_flatten_attributes( + record_format_attrs, key, getattr(record, key, None) + ) + + for key, value in record.__dict__.items(): + # Both Ray and user-provided context are stored in `record_format`. + if key not in LOGRECORD_STANDARD_ATTRS: + _append_flatten_attributes(record_format_attrs, key, value) - return record_format_attrs + # Format the internal timestamp to the standardized `timestamp_ns` key. + if INTERNAL_TIMESTAMP_LOG_KEY in record_format_attrs: + record_format_attrs[LogKey.TIMESTAMP_NS.value] = record_format_attrs.pop( + INTERNAL_TIMESTAMP_LOG_KEY + ) + return record_format_attrs -class JSONFormatter(logging.Formatter): - def format(self, record): - record_format_attrs = generate_record_format_attrs( - self, record, exclude_standard_attrs=False + @abstractmethod + def format(self, record: logging.LogRecord) -> str: + pass + + +class JSONFormatter(AbstractFormatter): + def format(self, record: logging.LogRecord) -> str: + record_format_attrs = self.generate_record_format_attrs( + record, exclude_default_standard_attrs=False ) return json.dumps(record_format_attrs) -class TextFormatter(logging.Formatter): - def __init__(self) -> None: +class TextFormatter(AbstractFormatter): + def __init__(self, fmt=None, datefmt=None, style="%", validate=True) -> None: + super().__init__(fmt, datefmt, style, validate) self._inner_formatter = logging.Formatter(LOGGER_FORMAT) def format(self, record: logging.LogRecord) -> str: s = self._inner_formatter.format(record) - record_format_attrs = generate_record_format_attrs( - self, record, exclude_standard_attrs=True + record_format_attrs = self.generate_record_format_attrs( + record, exclude_default_standard_attrs=True ) additional_attrs = " ".join( diff --git a/python/ray/_private/ray_logging/logging_config.py b/python/ray/_private/ray_logging/logging_config.py index 3318c19afabbe..babdc94e18ac4 100644 --- a/python/ray/_private/ray_logging/logging_config.py +++ b/python/ray/_private/ray_logging/logging_config.py @@ -2,11 +2,12 @@ from typing import Set from ray._private.ray_logging import default_impl +from ray._private.ray_logging.constants import LOGRECORD_STANDARD_ATTRS from ray._private.ray_logging.formatters import TextFormatter from ray._private.ray_logging.filters import CoreContextFilter from ray.util.annotations import PublicAPI -from dataclasses import dataclass +from dataclasses import dataclass, field import logging @@ -17,7 +18,7 @@ def get_supported_encodings(self) -> Set[str]: raise NotImplementedError @abstractmethod - def configure_logging(self, encoding: str, log_level: str): + def configure(self, logging_config: "LoggingConfig"): raise NotImplementedError @@ -30,20 +31,24 @@ def __init__(self): def get_supported_encodings(self) -> Set[str]: return self._encoding_to_formatter.keys() - def configure_logging(self, encoding: str, log_level: str): - formatter = self._encoding_to_formatter[encoding] + def configure(self, logging_config: "LoggingConfig"): + formatter = self._encoding_to_formatter[logging_config.encoding] + formatter.set_additional_log_standard_attrs( + logging_config.additional_log_standard_attrs + ) + core_context_filter = CoreContextFilter() handler = logging.StreamHandler() - handler.setLevel(log_level) + handler.setLevel(logging_config.log_level) handler.setFormatter(formatter) handler.addFilter(core_context_filter) root_logger = logging.getLogger() - root_logger.setLevel(log_level) + root_logger.setLevel(logging_config.log_level) 
root_logger.addHandler(handler) ray_logger = logging.getLogger("ray") - ray_logger.setLevel(log_level) + ray_logger.setLevel(logging_config.log_level) # Remove all existing handlers added by `ray/__init__.py`. for h in ray_logger.handlers[:]: ray_logger.removeHandler(h) @@ -54,11 +59,19 @@ def configure_logging(self, encoding: str, log_level: str): _logging_configurator: LoggingConfigurator = default_impl.get_logging_configurator() +# This class defines the logging configuration for a Ray job. +# To add a new logging configuration: (1) add a new field to this class; (2) update the +# __post_init__ method in this class to validate the new field; +# (3) update the configure method in the DefaultLoggingConfigurator +# class to use the new field. @PublicAPI(stability="alpha") @dataclass class LoggingConfig: encoding: str = "TEXT" log_level: str = "INFO" + # The list of valid attributes is defined as LOGRECORD_STANDARD_ATTRS in + # constants.py. + additional_log_standard_attrs: list = field(default_factory=list) def __post_init__(self): if self.encoding not in _logging_configurator.get_supported_encodings(): @@ -68,9 +81,17 @@ def __post_init__(self): f"{list(_logging_configurator.get_supported_encodings())}" ) + for attr in self.additional_log_standard_attrs: + if attr not in LOGRECORD_STANDARD_ATTRS: + raise ValueError( + f"Unknown Python logging standard attribute: {attr}. " + "The valid attributes are: " + f"{LOGRECORD_STANDARD_ATTRS}" + ) + def _configure_logging(self): """Set up the logging configuration for the current process.""" - _logging_configurator.configure_logging(self.encoding, self.log_level) + _logging_configurator.configure(self) def _apply(self): """Set up the logging configuration.""" @@ -89,7 +110,7 @@ def _apply(self): import logging ray.init( - logging_config=ray.LoggingConfig(encoding="TEXT", log_level="INFO") + logging_config=ray.LoggingConfig(encoding="TEXT", log_level="INFO", additional_log_standard_attrs=['name']) ) @ray.remote def f(): .. testoutput:: :options: +MOCK - 2024-06-03 07:53:50,815 INFO test.py:11 -- This is a Ray task job_id=01000000 worker_id=0dbbbd0f17d5343bbeee8228fa5ff675fe442445a1bc06ec899120a8 node_id=577706f1040ea8ebd76f7cf5a32338d79fe442e01455b9e7110cddfc task_id=c8ef45ccd0112571ffffffffffffffffffffffff01000000 + 2024-06-03 07:53:50,815 INFO test.py:11 -- This is a Ray task name=__main__ job_id=01000000 worker_id=0dbbbd0f17d5343bbeee8228fa5ff675fe442445a1bc06ec899120a8 node_id=577706f1040ea8ebd76f7cf5a32338d79fe442e01455b9e7110cddfc task_id=c8ef45ccd0112571ffffffffffffffffffffffff01000000 Args: encoding: Encoding type for the logs. The valid values are {list(_logging_configurator.get_supported_encodings())} log_level: Log level for the logs. Defaults to 'INFO'. You can set it to 'DEBUG' to receive more detailed debug logs. + additional_log_standard_attrs: List of additional standard Python logger attributes to + include in the log. Defaults to an empty list. The list of already + included standard attributes is: "asctime", "levelname", "message", + "filename", "lineno", "exc_text".
The list of valid attributes is specified + here: http://docs.python.org/library/logging.html#logrecord-attributes """ # noqa: E501 diff --git a/python/ray/_private/runtime_env/agent/runtime_env_agent.py b/python/ray/_private/runtime_env/agent/runtime_env_agent.py index d76a0e06cbac8..c08dfbc203784 100644 --- a/python/ray/_private/runtime_env/agent/runtime_env_agent.py +++ b/python/ray/_private/runtime_env/agent/runtime_env_agent.py @@ -117,6 +117,7 @@ def _increase_reference_for_runtime_env(self, serialized_env: str): self._runtime_env_reference[serialized_env] += 1 def _decrease_reference_for_runtime_env(self, serialized_env: str): + """Decrease the reference count for the given [serialized_env]. Throws an exception if the reference count cannot be decremented.""" default_logger.debug(f"Decrease reference for runtime env {serialized_env}.") unused = False if self._runtime_env_reference[serialized_env] > 0: @@ -126,10 +127,12 @@ def _decrease_reference_for_runtime_env(self, serialized_env: str): del self._runtime_env_reference[serialized_env] else: default_logger.warning(f"Runtime env {serialized_env} does not exist.") + raise ValueError( + f"Cannot decrement reference for runtime env {serialized_env} since the reference count is 0" + ) if unused: default_logger.info(f"Unused runtime env {serialized_env}.") self._unused_runtime_env_callback(serialized_env) - return unused def increase_reference( self, runtime_env: RuntimeEnv, serialized_env: str, source_process: str @@ -143,8 +146,9 @@ def increase_reference( def decrease_reference( self, runtime_env: RuntimeEnv, serialized_env: str, source_process: str ) -> None: + """Decrease the reference count for the runtime env and its URIs. Throws an exception if decrementing the reference count fails.""" if source_process in self._reference_exclude_sources: - return list() + return self._decrease_reference_for_runtime_env(serialized_env) uris = self._uris_parser(runtime_env) self._decrease_reference_for_uris(uris) @@ -543,9 +547,15 @@ async def DeleteRuntimeEnvIfPossible(self, request): ), ) - self._reference_table.decrease_reference( - runtime_env, request.serialized_runtime_env, request.source_process - ) + try: + self._reference_table.decrease_reference( + runtime_env, request.serialized_runtime_env, request.source_process + ) + except Exception as e: + return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply( + status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED, + error_message=f"Failed to decrement reference for runtime env: {str(e)}", + ) return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply( status=agent_manager_pb2.AGENT_RPC_STATUS_OK diff --git a/python/ray/_private/thirdparty/dacite/LICENSE b/python/ray/_private/thirdparty/dacite/LICENSE index 4be5be76271ce..5aff43f8f59d0 100644 --- a/python/ray/_private/thirdparty/dacite/LICENSE +++ b/python/ray/_private/thirdparty/dacite/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE.
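For context on the reference-counting change above: `_decrease_reference_for_runtime_env` now raises on an over-decrement instead of silently returning a flag, and `DeleteRuntimeEnvIfPossible` converts that into a failed RPC reply. A minimal, self-contained sketch of that contract (illustrative class, not the agent's actual implementation):

from collections import defaultdict

class RuntimeEnvRefCounter:
    """Simplified model of the agent's reference table (hypothetical)."""

    def __init__(self):
        self._counts = defaultdict(int)

    def increase(self, serialized_env: str) -> None:
        self._counts[serialized_env] += 1

    def decrease(self, serialized_env: str) -> None:
        # Mirrors the new contract: decrementing past zero raises instead of
        # returning an `unused` flag; the RPC handler catches the error and
        # replies with AGENT_RPC_STATUS_FAILED.
        if self._counts[serialized_env] == 0:
            raise ValueError(
                f"Cannot decrement reference for runtime env {serialized_env} "
                "since the reference count is 0"
            )
        self._counts[serialized_env] -= 1
        if self._counts[serialized_env] == 0:
            del self._counts[serialized_env]
            # The real agent fires its unused-runtime-env callback here.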
diff --git a/python/ray/_private/thirdparty/pyamdsmi/pyamdsmi.py b/python/ray/_private/thirdparty/pyamdsmi/pyamdsmi.py index b88853eb4b6f1..7edc0955c7308 100644 --- a/python/ray/_private/thirdparty/pyamdsmi/pyamdsmi.py +++ b/python/ray/_private/thirdparty/pyamdsmi/pyamdsmi.py @@ -549,4 +549,3 @@ def smi_get_device_xgmi_hive_id(dev): hive_id = c_uint64() ret = rocm_lib.rsmi_dev_xgmi_hive_id_get(dev, byref(hive_id)) return hive_id.value if rsmi_ret_ok(ret) else -1 - diff --git a/python/ray/_private/thirdparty/pynvml/__init__.py b/python/ray/_private/thirdparty/pynvml/__init__.py index 2cc389aba708d..1b674aebf667e 100644 --- a/python/ray/_private/thirdparty/pynvml/__init__.py +++ b/python/ray/_private/thirdparty/pynvml/__init__.py @@ -1,4 +1,4 @@ from ray._private.thirdparty.pynvml.pynvml import * # nvdia-ml-py version # Note: we pick this version to use the V2 API which is supported by older drivers -__version__ = "11.495.46" \ No newline at end of file +__version__ = "11.495.46" diff --git a/python/ray/_private/thirdparty/pynvml/pynvml.py b/python/ray/_private/thirdparty/pynvml/pynvml.py index 893db0787e7a7..e0092f8d3c2a0 100644 --- a/python/ray/_private/thirdparty/pynvml/pynvml.py +++ b/python/ray/_private/thirdparty/pynvml/pynvml.py @@ -3668,4 +3668,3 @@ def nvmlDeviceGetIrqNum(device): ret = fn(device, byref(c_irqNum)) _nvmlCheckReturn(ret) return c_irqNum.value - diff --git a/python/ray/air/config.py b/python/ray/air/config.py index 85e6fac7d7168..af34c3aeafa56 100644 --- a/python/ray/air/config.py +++ b/python/ray/air/config.py @@ -695,6 +695,8 @@ def __post_init__(self): if not self.checkpoint_config: self.checkpoint_config = CheckpointConfig() + # Save the original verbose value to check for deprecations + self._verbose = self.verbose if self.verbose is None: # Default `verbose` value. For new output engine, # this is AirVerbosity.DEFAULT. diff --git a/python/ray/autoscaler/_private/_azure/azure-config-template.json b/python/ray/autoscaler/_private/_azure/azure-config-template.json index bedb063d4447a..558273f58a637 100644 --- a/python/ray/autoscaler/_private/_azure/azure-config-template.json +++ b/python/ray/autoscaler/_private/_azure/azure-config-template.json @@ -127,4 +127,4 @@ "value": "[resourceId(parameters('msiResourceGroup'), 'Microsoft.ManagedIdentity/userAssignedIdentities', parameters('msiName'))]" } } -} \ No newline at end of file +} diff --git a/python/ray/autoscaler/_private/vsphere/ARCHITECTURE.md b/python/ray/autoscaler/_private/vsphere/ARCHITECTURE.md index 6e81cc7680e86..3606c65807f24 100644 --- a/python/ray/autoscaler/_private/vsphere/ARCHITECTURE.md +++ b/python/ray/autoscaler/_private/vsphere/ARCHITECTURE.md @@ -79,4 +79,4 @@ The autoscaler can find the currently running nodes with `non_terminated_nodes` The autoscaler can use `external_ip` or `internal_ip` function to fetch a node's IP. ## Cluster tear down ([node_provider.py](./node_provider.py)) -`terminate_nodes` function gets called on ray down command's execution. It deletes all the nodes except the frozen VM. \ No newline at end of file +`terminate_nodes` function gets called on ray down command's execution. It deletes all the nodes except the frozen VM. 
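To make the vSphere teardown contract described in ARCHITECTURE.md concrete, here is a minimal sketch (hypothetical class and state; the real logic lives in node_provider.py) of a terminate_nodes implementation that spares the frozen VM:

from typing import Dict, List

class FrozenVMAwareProvider:
    """Illustrative model only; not the actual vSphere node provider."""

    def __init__(self, nodes: Dict[str, dict], frozen_vm_id: str):
        self._nodes = nodes  # node_id -> VM metadata
        self._frozen_vm_id = frozen_vm_id

    def non_terminated_nodes(self) -> List[str]:
        # The autoscaler calls this to discover currently running nodes.
        return list(self._nodes)

    def terminate_nodes(self, node_ids: List[str]) -> None:
        # Invoked on `ray down`: delete every node except the frozen VM,
        # which is kept so future workers can be cloned from it.
        for node_id in node_ids:
            if node_id == self._frozen_vm_id:
                continue
            self._nodes.pop(node_id, None)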
diff --git a/python/ray/autoscaler/_private/vsphere/data/userdata.yaml b/python/ray/autoscaler/_private/vsphere/data/userdata.yaml index 2da58e64c1fcc..ef4931a3d2b0f 100644 --- a/python/ray/autoscaler/_private/vsphere/data/userdata.yaml +++ b/python/ray/autoscaler/_private/vsphere/data/userdata.yaml @@ -7,4 +7,4 @@ users: name: ray passwd: AdminRay primary_group: sudo - ssh_authorized_keys: \ No newline at end of file + ssh_authorized_keys: diff --git a/python/ray/autoscaler/aws/cloudwatch/example-cloudwatch-dashboard-config.json b/python/ray/autoscaler/aws/cloudwatch/example-cloudwatch-dashboard-config.json index b30047a807f0e..b7976df916a51 100644 --- a/python/ray/autoscaler/aws/cloudwatch/example-cloudwatch-dashboard-config.json +++ b/python/ray/autoscaler/aws/cloudwatch/example-cloudwatch-dashboard-config.json @@ -235,4 +235,3 @@ } } ] - diff --git a/python/ray/autoscaler/aws/example-subnets.yaml b/python/ray/autoscaler/aws/example-subnets.yaml index 4058327420dc0..4c0c920290b8d 100644 --- a/python/ray/autoscaler/aws/example-subnets.yaml +++ b/python/ray/autoscaler/aws/example-subnets.yaml @@ -30,4 +30,3 @@ available_node_types: node_config: SubnetIds: - subnet-fffffff # Replace with your actual Worker Node Subnet ID. - diff --git a/python/ray/autoscaler/gcp/example-tpu-pod-topology.yaml b/python/ray/autoscaler/gcp/example-tpu-pod-topology.yaml index 521cae976d806..f4a6b07b336a6 100644 --- a/python/ray/autoscaler/gcp/example-tpu-pod-topology.yaml +++ b/python/ray/autoscaler/gcp/example-tpu-pod-topology.yaml @@ -55,4 +55,4 @@ head_setup_commands: - pip install google-api-python-client # Specify the node type of the head node (as configured above). -head_node_type: ray_head_default \ No newline at end of file +head_node_type: ray_head_default diff --git a/python/ray/autoscaler/gcp/example-tpu-pod.yaml b/python/ray/autoscaler/gcp/example-tpu-pod.yaml index f23d4486c0319..3c5690a6bb892 100644 --- a/python/ray/autoscaler/gcp/example-tpu-pod.yaml +++ b/python/ray/autoscaler/gcp/example-tpu-pod.yaml @@ -51,4 +51,4 @@ head_setup_commands: - pip install google-api-python-client # Specify the node type of the head node (as configured above). -head_node_type: ray_head_default \ No newline at end of file +head_node_type: ray_head_default diff --git a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py index c1b8ddc2a31b9..dc667194be7f8 100644 --- a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py +++ b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py @@ -110,9 +110,7 @@ class ScaleRequest: def get_non_terminated(self) -> Dict[CloudInstanceId, CloudInstance]: self._sync_with_api_server() - return copy.deepcopy( - {id: instance for id, instance in self._cached_instances.items()} - ) + return copy.deepcopy(dict(self._cached_instances)) def terminate(self, ids: List[CloudInstanceId], request_id: str) -> None: if request_id in self._requests: diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 122a1921af2c9..15ec19a2a9c4a 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -852,7 +852,7 @@ def __init__( # We conservatively set num_shm_buffers to _max_inflight_executions. # It means that the DAG can be underutilized, but it guarantees there's # no false positive timeouts. 
- num_shm_buffers=1, + num_shm_buffers=self._max_inflight_executions, ) if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0: raise ValueError( @@ -1930,7 +1930,7 @@ def _detect_deadlock(self) -> bool: Returns: True if a deadlock is detected; otherwise, False. """ - logger.warning("Deadlock detection has not been implemented yet.") + logger.debug("Deadlock detection has not been implemented yet.") return False def _monitor_failures(self): diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 43235a97e0991..676a2219d6f49 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -225,6 +225,7 @@ def test_get_ref_after_destructed_ref(self, ray_start_regular): ref = compiled_dag.execute(2) ref2 = compiled_dag.execute(2) ref3 = compiled_dag.execute(2) + del ref del ref2 # Test that ray.get() works correctly if preceding ref was destructed assert ray.get(ref3) == 6 @@ -241,6 +242,7 @@ def test_release_buffer_on_execute(self, ray_start_regular): del ref3 ray.get(ref) ref4 = compiled_dag.execute(3) + del ref4 # Test that max_inflight error is not raised as ref2 and ref3 # should be destructed and not counted in the inflight executions ref5 = compiled_dag.execute(3) @@ -2407,7 +2409,6 @@ def test_driver_and_intraprocess_read(ray_start_cluster): assert ray.get(dag.execute(1)) == [1, 2] -@pytest.mark.skip("Currently buffer size is set to 1 because of regression.") @pytest.mark.parametrize("temporary_change_timeout", [1], indirect=True) def test_buffered_inputs(shutdown_only, temporary_change_timeout): ray.init() @@ -2431,7 +2432,7 @@ def fwd(self, x): actor1 = Actor1.remote() # Since the timeout is 1 second, if buffering is not working, - # it will timeout (0.12s for each dag * MAX_INFLIGHT_EXECUTIONS). + # it will timeout (0.2s for each dag * MAX_INFLIGHT_EXECUTIONS). with InputNode() as input_node: dag = actor1.fwd.bind(input_node) @@ -2445,7 +2446,7 @@ def fwd(self, x): for i, ref in enumerate(output_refs): assert ray.get(ref) == i - # Test there are more items than max bufcfered inputs. + # Test there are more items than max buffered inputs. 
output_refs = [] for i in range(MAX_INFLIGHT_EXECUTIONS): output_refs.append(dag.execute(i)) @@ -2537,14 +2538,14 @@ async def main(): match=(expected_error_message), ): _ = await async_compiled_dag.execute_async(1) - (ref1, ref2) + _ = (ref1, ref2) loop = get_or_create_event_loop() loop.run_until_complete(main()) # to show variables are being used and avoid destruction since # CompiledDagRef __del__ will release buffers and # increment _max_finished_execution_index - (ref1, ref2) + _ = (ref1, ref2) def test_result_buffer_exceeds_capacity(ray_start_regular): @@ -2585,12 +2586,12 @@ async def main(): match=(expected_error_message), ): _ = await async_compiled_dag.execute_async(4) - (ref1, ref3) + _ = (ref1, ref3) loop = get_or_create_event_loop() loop.run_until_complete(main()) # same reason as comment for test_inflight_requests_exceed_capacity - (ref1, ref3) + _ = (ref1, ref3) def test_event_profiling(ray_start_regular, monkeypatch): diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 98bf318cd94e4..a76f83362c6f3 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -1168,12 +1168,12 @@ def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular): with pytest.raises(RayChannelError): ray.get(ref) - # The DAG will be torn down after any task throws an application-level - # exception, such as when the task returns torch.Tensors of the wrong - # shape or dtype. Check that we can no longer submit to the DAG. + # Since we have buffered channels, the execution should not error, but the + # get should error, as the dag should no longer work after the application- + # level exception. ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers]) with pytest.raises(RayChannelError): - ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers]) + ray.get(ref) @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) diff --git a/python/ray/dashboard/datacenter.py b/python/ray/dashboard/datacenter.py index 2a2c660ecd440..dd2d5edef8590 100644 --- a/python/ray/dashboard/datacenter.py +++ b/python/ray/dashboard/datacenter.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Optional +from typing import List, Optional import ray.dashboard.consts as dashboard_consts from ray._private.utils import ( @@ -26,9 +26,6 @@ class DataSource: # {actor id hex(str): actor table data(dict of ActorTableData # in gcs.proto)} actors = MutableNotificationDict() - # {job id hex(str): job table data(dict of JobTableData in gcs.proto)} - # {node id hex(str): dashboard agent [http port(int), grpc port(int)]} - agents = Dict() # {node id hex(str): gcs node info(dict of GcsNodeInfo in gcs.proto)} nodes = Dict() # {node id hex(str): worker list} @@ -188,41 +185,6 @@ async def get_all_node_summary(cls): for node_id in DataSource.nodes.keys() ] - @classmethod - async def get_agent_infos( - cls, target_node_ids: Optional[List[str]] = None - ) -> Dict[str, Dict[str, Any]]: - """Fetches running Agent (like HTTP/gRPC ports, IP, etc) running on every node - - :param target_node_ids: Target node ids to fetch agent info for. 
If omitted will - fetch the info for all agents - """ - - # Return all available agent infos in case no target node-ids were provided - target_node_ids = target_node_ids or DataSource.agents.keys() - - missing_node_ids = [ - node_id for node_id in target_node_ids if node_id not in DataSource.agents - ] - if missing_node_ids: - logger.warning( - f"Agent info was not found for {missing_node_ids}" - f" (having agent infos for {list(DataSource.agents.keys())})" - ) - return {} - - def _create_agent_info(node_id: str): - (node_ip, http_port, grpc_port) = DataSource.agents[node_id] - - return dict( - ipAddress=node_ip, - httpPort=int(http_port or -1), - grpcPort=int(grpc_port or -1), - httpAddress=f"{node_ip}:{http_port}", - ) - - return {node_id: _create_agent_info(node_id) for node_id in target_node_ids} - @classmethod async def get_actor_infos(cls, actor_ids: Optional[List[str]] = None): target_actor_table_entries: dict[str, Optional[dict]] diff --git a/python/ray/dashboard/modules/job/common.py b/python/ray/dashboard/modules/job/common.py index 8b308ded25d27..c93cfcaac34a0 100644 --- a/python/ray/dashboard/modules/job/common.py +++ b/python/ray/dashboard/modules/job/common.py @@ -373,12 +373,7 @@ async def get_job_info(job_id: str): job_info = await self.get_info(job_id, timeout) return job_id, job_info - return { - job_id: job_info - for job_id, job_info in await asyncio.gather( - *[get_job_info(job_id) for job_id in job_ids] - ) - } + return dict(await asyncio.gather(*[get_job_info(job_id) for job_id in job_ids])) def uri_to_http_components(package_uri: str) -> Tuple[str, str]: diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index 185c8fc94983a..78c3801d9347e 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -3,25 +3,31 @@ import json import logging import traceback -from random import sample -from typing import AsyncIterator, List, Optional +from random import choice +from typing import AsyncIterator, Dict, List, Optional, Tuple import aiohttp.web from aiohttp.client import ClientResponse from aiohttp.web import Request, Response import ray +from ray import NodeID import ray.dashboard.consts as dashboard_consts +from ray.dashboard.consts import ( + GCS_RPC_TIMEOUT_SECONDS, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS, + WAIT_AVAILABLE_AGENT_TIMEOUT, +) import ray.dashboard.optional_utils as optional_utils import ray.dashboard.utils as dashboard_utils -from ray._private.ray_constants import env_bool +from ray._private.ray_constants import env_bool, KV_NAMESPACE_DASHBOARD from ray._private.runtime_env.packaging import ( package_exists, pin_runtime_env_uri, upload_package_to_gcs, ) from ray._private.utils import get_or_create_event_loop -from ray.dashboard.datacenter import DataOrganizer from ray.dashboard.modules.job.common import ( JobDeleteResponse, JobInfoStorageClient, @@ -166,9 +172,10 @@ def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): # `JobHead` has ever used, and will not be deleted # from it unless `JobAgentSubmissionClient` is no # longer available (the corresponding agent process is dead) - self._agents = dict() + # {node_id: JobAgentSubmissionClient} + self._agents: Dict[NodeID, JobAgentSubmissionClient] = dict() - async def get_target_agent(self) -> Optional[JobAgentSubmissionClient]: + async def get_target_agent(self) -> JobAgentSubmissionClient: if RAY_JOB_AGENT_USE_HEAD_NODE_ONLY: return await 
self._get_head_node_agent() @@ -188,79 +195,118 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: 2. if not, randomly select one agent from all available agents, it is possible that the selected one already exists in `self._agents`. + + If no agent is available at all, or an exception occurs, it will retry every + `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. + """ + while True: + try: + return await self._pick_random_agent_once() + except Exception: + logger.exception( + f"Failed to pick a random agent, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + ) + await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + + async def _pick_random_agent_once(self) -> JobAgentSubmissionClient: + """ + Query the internal KV for all agent infos, and pick one agent at random. May raise an + exception if no agent is available at all or there's a network error. + """ # NOTE: Following call will block until there's at least 1 agent info # being populated from GCS - agent_infos = await self._fetch_agent_infos() + agent_node_ids = await self._fetch_all_agent_node_ids() # delete dead agents. - for dead_node in set(self._agents) - set(agent_infos): + for dead_node in set(self._agents) - set(agent_node_ids): client = self._agents.pop(dead_node) await client.close() if len(self._agents) >= dashboard_consts.CANDIDATE_AGENT_NUMBER: - node_id = sample(list(set(self._agents)), 1)[0] + node_id = choice(list(self._agents)) return self._agents[node_id] else: # Randomly select one from among all agents, it is possible that # the selected one already exists in `self._agents` - node_id = sample(sorted(agent_infos), 1)[0] - agent_info = agent_infos[node_id] + node_id = choice(list(agent_node_ids)) if node_id not in self._agents: - node_ip = agent_info["ipAddress"] - http_port = agent_info["httpPort"] - agent_http_address = f"http://{node_ip}:{http_port}" + # Fetch agent info from InternalKV, and create a new + # JobAgentSubmissionClient. May raise if the node_id is removed from + # InternalKV after the _fetch_all_agent_node_ids call, though this is unlikely.
+ ip, http_port, grpc_port = await self._fetch_agent_info(node_id) + agent_http_address = f"http://{ip}:{http_port}" self._agents[node_id] = JobAgentSubmissionClient(agent_http_address) return self._agents[node_id] - async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: - """Retrieves HTTP client for `JobAgent` running on the Head node""" + async def _get_head_node_agent_once(self) -> JobAgentSubmissionClient: + head_node_id_hex = await get_head_node_id(self.gcs_aio_client) - head_node_id = await get_head_node_id(self.gcs_aio_client) + if not head_node_id_hex: + raise Exception("Head node id has not yet been persisted in GCS") - if not head_node_id: - logger.warning("Head node id has not yet been persisted in GCS") - return None + head_node_id = NodeID.from_hex(head_node_id_hex) if head_node_id not in self._agents: - agent_infos = await self._fetch_agent_infos(target_node_ids=[head_node_id]) - if head_node_id not in agent_infos: - logger.error("Head node agent's information was not found") - return None - - agent_info = agent_infos[head_node_id] - - node_ip = agent_info["ipAddress"] - http_port = agent_info["httpPort"] - agent_http_address = f"http://{node_ip}:{http_port}" - + ip, http_port, grpc_port = await self._fetch_agent_info(head_node_id) + agent_http_address = f"http://{ip}:{http_port}" self._agents[head_node_id] = JobAgentSubmissionClient(agent_http_address) return self._agents[head_node_id] - @staticmethod - async def _fetch_agent_infos(target_node_ids: Optional[List[str]] = None): - """Fetches agent infos for nodes identified by provided node-ids (for all - nodes if not provided) - - NOTE: This call will block until there's at least 1 valid agent info populated + async def _get_head_node_agent(self) -> JobAgentSubmissionClient: + """Retrieves HTTP client for `JobAgent` running on the Head node. If the head + node does not have an agent, it will retry every + `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. """ - while True: - raw_agent_infos = await DataOrganizer.get_agent_infos(target_node_ids) - # Filter out invalid agent infos with unset HTTP port - agent_infos = { - key: value - for key, value in raw_agent_infos.items() - if value.get("httpPort", -1) > 0 - } + while True: + try: + return await self._get_head_node_agent_once() + except Exception: + logger.exception( + f"Failed to get head node agent, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + ) + await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + + async def _fetch_all_agent_node_ids(self) -> List[NodeID]: + """ + Fetches all NodeIDs with agent infos in the cluster. - if len(agent_infos) > 0: - return agent_infos + May raise an exception if no agent is available at all or there's a network error. + Returns: List[NodeID] + """ + keys = await self.gcs_aio_client.internal_kv_keys( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}".encode(), + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not keys: + # No agent keys found, retry + raise Exception("No agents found in InternalKV.") + return [ + NodeID.from_hex(key[len(DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX) :].decode()) + for key in keys + ] - await asyncio.sleep(dashboard_consts.TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + async def _fetch_agent_info(self, target_node_id: NodeID) -> Tuple[str, int, int]: + """ + Fetches agent info by Node ID. May raise an exception if there's a network error or the + agent info is not found.
+ + Returns: (ip, http_port, grpc_port) + """ + key = f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{target_node_id.hex()}" + value = await self.gcs_aio_client.internal_kv_get( + key, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not value: + raise KeyError( + f"Agent info not found in internal kv for node {target_node_id}" + ) + return json.loads(value.decode()) @routes.get("/api/version") async def get_version(self, req: Request) -> Response: @@ -337,7 +383,7 @@ async def submit_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.submit_job_internal(submit_request) except asyncio.TimeoutError: @@ -384,7 +430,7 @@ async def stop_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.stop_job_internal(job.submission_id) except Exception: @@ -419,7 +465,7 @@ async def delete_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.delete_job_internal(job.submission_id) except Exception: @@ -456,9 +502,10 @@ async def get_job_info(self, req: Request) -> Response: # that). @routes.get("/api/jobs/") async def list_jobs(self, req: Request) -> Response: - driver_jobs, submission_job_drivers = await get_driver_jobs(self.gcs_aio_client) + (driver_jobs, submission_job_drivers), submission_jobs = await asyncio.gather( + get_driver_jobs(self.gcs_aio_client), self._job_info_client.get_all_jobs() + ) - submission_jobs = await self._job_info_client.get_all_jobs() submission_jobs = [ JobDetails( **dataclasses.asdict(job), diff --git a/python/ray/dashboard/modules/job/sdk.py b/python/ray/dashboard/modules/job/sdk.py index b3b25e936fa0c..b9ff753593fc4 100644 --- a/python/ray/dashboard/modules/job/sdk.py +++ b/python/ray/dashboard/modules/job/sdk.py @@ -331,17 +331,19 @@ def get_job_info( >>> from ray.job_submission import JobSubmissionClient >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP >>> submission_id = client.submit_job(entrypoint="sleep 1") # doctest: +SKIP - >>> job_submission_client.get_job_info(submission_id) # doctest: +SKIP - JobInfo(status='SUCCEEDED', message='Job finished successfully.', - error_type=None, start_time=1647388711, end_time=1647388712, - metadata={}, runtime_env={}) + >>> client.get_job_info(submission_id) # doctest: +SKIP + JobDetails(status='SUCCEEDED', + job_id='03000000', type='submission', + submission_id='raysubmit_4LamXRuQpYdSMg7J', + message='Job finished successfully.', error_type=None, + start_time=1647388711, end_time=1647388712, metadata={}, runtime_env={}) Args: job_id: The job ID or submission ID of the job whose information is being requested. Returns: - The JobInfo for the job. + The JobDetails for the job. Raises: RuntimeError: If the job does not exist or if the request to the @@ -379,7 +381,7 @@ def list_jobs(self) -> List[JobDetails]: start_time=1647454832, end_time=None, metadata={}, runtime_env={})] Returns: - A dictionary mapping job_ids to their information. + A list of JobDetails containing the job status and other information. 
Raises: RuntimeError: If the request to the job server fails. diff --git a/python/ray/dashboard/modules/job/tests/test_http_job_server.py b/python/ray/dashboard/modules/job/tests/test_http_job_server.py index 1441d89bae1b1..b4ce1ab616907 100644 --- a/python/ray/dashboard/modules/job/tests/test_http_job_server.py +++ b/python/ray/dashboard/modules/job/tests/test_http_job_server.py @@ -8,13 +8,14 @@ import tempfile import time from pathlib import Path -from typing import Optional +from typing import Optional, List, Union, Dict from unittest.mock import patch import pytest import yaml import ray +from ray import NodeID from ray._private.test_utils import ( chdir, format_web_url, @@ -22,6 +23,10 @@ wait_for_condition, wait_until_server_available, ) +from ray.dashboard.consts import ( + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + DASHBOARD_AGENT_ADDR_IP_PREFIX, +) from ray.dashboard.modules.dashboard_sdk import ClusterInfo, parse_cluster_info from ray.dashboard.modules.job.job_head import JobHead from ray.dashboard.modules.job.pydantic_models import JobDetails @@ -736,30 +741,85 @@ async def test_job_head_pick_random_job_agent(monkeypatch): importlib.reload(ray.dashboard.consts) - from ray.dashboard.datacenter import DataSource + # Fake GCS client + class _FakeGcsClient: + def __init__(self): + self._kv: Dict[bytes, bytes] = {} + + @staticmethod + def ensure_bytes(key: Union[bytes, str]) -> bytes: + return key.encode() if isinstance(key, str) else key + + async def internal_kv_put( + self, key: Union[bytes, str], value: bytes, **kwargs + ): + key = self.ensure_bytes(key) + self._kv[key] = value + + async def internal_kv_get(self, key: Union[bytes, str], **kwargs): + key = self.ensure_bytes(key) + return self._kv.get(key, None) + + async def internal_kv_multi_get( + self, keys: List[Union[bytes, str]], **kwargs + ): + return {key: await self.internal_kv_get(key) for key in keys} + + async def internal_kv_del(self, key: Union[bytes, str], **kwargs): + key = self.ensure_bytes(key) + self._kv.pop(key) + + async def internal_kv_keys(self, prefix: Union[bytes, str], **kwargs): + prefix = self.ensure_bytes(prefix) + return [key for key in self._kv.keys() if key.startswith(prefix)] class MockJobHead(JobHead): def __init__(self): self._agents = dict() + self._gcs_aio_client = _FakeGcsClient() + + @property + def gcs_aio_client(self): + # Overrides JobHead.gcs_aio_client + return self._gcs_aio_client - DataSource.agents = {} - DataSource.nodes = {} job_head = MockJobHead() - def add_agent(agent): + async def add_agent(agent): node_id = agent[0] node_ip = agent[1]["ipAddress"] http_port = agent[1]["httpPort"] grpc_port = agent[1]["grpcPort"] - DataSource.nodes[node_id] = {"nodeManagerAddress": node_ip} - DataSource.agents[node_id] = (node_ip, http_port, grpc_port) - def del_agent(agent): + await job_head._gcs_aio_client.internal_kv_put( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + json.dumps([node_ip, http_port, grpc_port]).encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + await job_head._gcs_aio_client.internal_kv_put( + f"{DASHBOARD_AGENT_ADDR_IP_PREFIX}{node_ip}".encode(), + json.dumps([node_id.hex(), http_port, grpc_port]).encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + async def del_agent(agent): node_id = agent[0] - DataSource.nodes.pop(node_id) - DataSource.agents.pop(node_id) - - head_node_id = "node1" + node_ip = agent[1]["ipAddress"] + await job_head._gcs_aio_client.internal_kv_del(
f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + await job_head._gcs_aio_client.internal_kv_del( + f"{DASHBOARD_AGENT_ADDR_IP_PREFIX}{node_ip}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + head_node_id = NodeID.from_random() + await job_head._gcs_aio_client.internal_kv_put( + ray_constants.KV_HEAD_NODE_ID_KEY, + head_node_id.hex().encode(), + namespace=ray_constants.KV_NAMESPACE_JOB, + ) agent_1 = ( head_node_id, @@ -771,7 +831,7 @@ def del_agent(agent): ), ) agent_2 = ( - "node2", + NodeID.from_random(), dict( ipAddress="2.2.2.2", httpPort=2, @@ -780,7 +840,7 @@ def del_agent(agent): ), ) agent_3 = ( - "node3", + NodeID.from_random(), dict( ipAddress="3.3.3.3", httpPort=3, @@ -796,12 +856,12 @@ def del_agent(agent): ) # Check only 1 agent present, only agent being returned - add_agent(agent_1) + await add_agent(agent_1) job_agent_client = await job_head.get_target_agent() assert job_agent_client._agent_address == "http://1.1.1.1:1" # Remove only agent, no agents present, should time out - del_agent(agent_1) + await del_agent(agent_1) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(job_head.get_target_agent(), timeout=3) @@ -812,19 +872,9 @@ def del_agent(agent): ) # Add 3 agents - add_agent(agent_1) - add_agent(agent_2) - add_agent(agent_3) - - # Mock GCS client - class _MockedGCSClient: - async def internal_kv_get(self, key: bytes, **kwargs): - if key == ray_constants.KV_HEAD_NODE_ID_KEY: - return head_node_id.encode() - - return None - - job_head._gcs_aio_client = _MockedGCSClient() + await add_agent(agent_1) + await add_agent(agent_2) + await add_agent(agent_3) # Make sure returned agent is a head-node # NOTE: We run 3 times to make sure we're not hitting the branch probabilistically @@ -853,7 +903,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_2: break - del_agent(agent) + await del_agent(agent) # Theoretically, the probability of failure is 1/2^100 addresses_3 = set() @@ -871,7 +921,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_4: break - del_agent(agent) + await del_agent(agent) address = None for _ in range(3): job_agent_client = await job_head.get_target_agent() diff --git a/python/ray/dashboard/modules/job/tests/test_sdk.py b/python/ray/dashboard/modules/job/tests/test_sdk.py index e440cc2efb917..0e1500bc9ce58 100644 --- a/python/ray/dashboard/modules/job/tests/test_sdk.py +++ b/python/ray/dashboard/modules/job/tests/test_sdk.py @@ -7,17 +7,23 @@ from unittest.mock import Mock, patch import pytest -import requests import ray import ray.experimental.internal_kv as kv -from ray._private.ray_constants import DEFAULT_DASHBOARD_AGENT_LISTEN_PORT +from ray._private.ray_constants import ( + DEFAULT_DASHBOARD_AGENT_LISTEN_PORT, + KV_NAMESPACE_DASHBOARD, +) from ray._private.test_utils import ( format_web_url, wait_for_condition, wait_until_server_available, ) -from ray.dashboard.consts import RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR +from ray.dashboard.consts import ( + RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + GCS_RPC_TIMEOUT_SECONDS, +) from ray.dashboard.modules.dashboard_sdk import ( DEFAULT_DASHBOARD_ADDRESS, ClusterInfo, @@ -28,7 +34,7 @@ from ray.dashboard.tests.conftest import * # noqa from ray.tests.conftest
import _ray_start from ray.util.state import list_nodes - +from ray._raylet import GcsClient import psutil @@ -165,12 +171,13 @@ def mock_candidate_number(): os.environ.pop("CANDIDATE_AGENT_NUMBER", None) -def get_register_agents_number(webui_url): - response = requests.get(webui_url + "/internal/node_module") - response.raise_for_status() - result = response.json() - data = result["data"] - return data["registeredAgents"] +def get_register_agents_number(gcs_client): + keys = gcs_client.internal_kv_keys( + prefix=DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + return len(keys) @pytest.mark.parametrize( @@ -195,6 +202,7 @@ def test_job_head_choose_job_agent_E2E(ray_start_cluster_head_with_env_vars): webui_url = cluster.webui_url webui_url = format_web_url(webui_url) client = JobSubmissionClient(webui_url) + gcs_client = GcsClient(address=cluster.gcs_address) def submit_job_and_wait_finish(): submission_id = client.submit_job(entrypoint="echo hello") @@ -206,7 +214,7 @@ def submit_job_and_wait_finish(): head_http_port = DEFAULT_DASHBOARD_AGENT_LISTEN_PORT worker_1_http_port = 52366 cluster.add_node(dashboard_agent_listen_port=worker_1_http_port) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 2, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 2, timeout=20) assert len(cluster.worker_nodes) == 1 node_try_to_kill = list(cluster.worker_nodes)[0] @@ -250,7 +258,7 @@ def _kill_all_driver(): worker_2_http_port = 52367 cluster.add_node(dashboard_agent_listen_port=worker_2_http_port) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 3, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 3, timeout=20) # The third `JobAgent` will not be called here. submit_job_and_wait_finish() @@ -281,7 +289,7 @@ def get_all_new_supervisor_actor_info(old_supervisor_actor_ids): node_try_to_kill.kill_raylet() # make sure the head updates the info of the dead node. - wait_for_condition(lambda: get_register_agents_number(webui_url) == 2, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 2, timeout=20) # Make sure the third JobAgent will be called here. wait_for_condition( @@ -324,6 +332,7 @@ def test_jobs_run_on_head_by_default_E2E(ray_start_cluster_head_with_env_vars): webui_url = cluster.webui_url webui_url = format_web_url(webui_url) client = JobSubmissionClient(webui_url) + gcs_client = GcsClient(address=cluster.gcs_address) def _check_nodes(num_nodes): try: @@ -334,7 +343,7 @@ def _check_nodes(num_nodes): return False wait_for_condition(lambda: _check_nodes(num_nodes=3), timeout=15) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 3, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 3, timeout=20) # Submit 20 simple jobs. 
for i in range(20): diff --git a/python/ray/dashboard/modules/job/utils.py b/python/ray/dashboard/modules/job/utils.py index 6c36bb807be3a..a426868a79903 100644 --- a/python/ray/dashboard/modules/job/utils.py +++ b/python/ray/dashboard/modules/job/utils.py @@ -36,13 +36,14 @@ async def get_head_node_id(gcs_aio_client: GcsAioClient) -> Optional[str]: """Fetches Head node id persisted in GCS""" - head_node_id_bytes = await gcs_aio_client.internal_kv_get( + head_node_id_hex_bytes = await gcs_aio_client.internal_kv_get( ray_constants.KV_HEAD_NODE_ID_KEY, namespace=ray_constants.KV_NAMESPACE_JOB, timeout=30, ) - - return head_node_id_bytes.decode() if head_node_id_bytes is not None else None + if head_node_id_hex_bytes is None: + return None + return head_node_id_hex_bytes.decode() def strip_keys_with_value_none(d: Dict[str, Any]) -> Dict[str, Any]: diff --git a/python/ray/dashboard/modules/log/log_manager.py b/python/ray/dashboard/modules/log/log_manager.py index bb21446f15f62..a05b09c8f9d4b 100644 --- a/python/ray/dashboard/modules/log/log_manager.py +++ b/python/ray/dashboard/modules/log/log_manager.py @@ -12,7 +12,6 @@ GetLogOptions, protobuf_to_task_state_dict, ) -from ray.util.state.exception import DataSourceUnavailable from ray.util.state.state_manager import StateDataSourceClient if BaseModel is None: @@ -74,9 +73,8 @@ async def list_logs( Dictionary of {component_name -> list of log files} Raises: - DataSourceUnavailable: If a source is unresponsive. + ValueError: If a source is unresponsive. """ - self._verify_node_registered(node_id) reply = await self.client.list_logs(node_id, glob_filter, timeout=timeout) return self._categorize_log_files(reply.log_files) @@ -126,18 +124,6 @@ async def stream_logs( async for streamed_log in stream: yield streamed_log.data - def _verify_node_registered(self, node_id: str): - if node_id not in self.client.get_all_registered_log_agent_ids(): - raise DataSourceUnavailable( - f"Given node id {node_id} is not available. " - "It's either the node is dead, or it is not registered. " - "Use `ray list nodes` " - "to see the node status. If the node is registered, " - "it is highly likely " - "a transient issue. Try again." - ) - assert node_id is not None - async def _resolve_job_filename(self, sub_job_id: str) -> Tuple[str, str]: """Return the log file name and node id for a given job submission id. @@ -249,7 +235,6 @@ async def _resolve_actor_filename( "Actor is not scheduled yet." 
) node_id = NodeID(node_id_binary) - self._verify_node_registered(node_id.hex()) log_filename = await self._resolve_worker_file( node_id_hex=node_id.hex(), worker_id_hex=worker_id.hex(), @@ -415,7 +400,6 @@ async def resolve_filename( "Node id needs to be specified for resolving" f" filenames of pid {pid}" ) - self._verify_node_registered(node_id) log_filename = await self._resolve_worker_file( node_id_hex=node_id, worker_id_hex=None, diff --git a/python/ray/dashboard/modules/metrics/metrics_head.py b/python/ray/dashboard/modules/metrics/metrics_head.py index 58a9b31a9cc40..93e0f7f891ca8 100644 --- a/python/ray/dashboard/modules/metrics/metrics_head.py +++ b/python/ray/dashboard/modules/metrics/metrics_head.py @@ -283,7 +283,7 @@ def _create_default_grafana_configs(self): if isinstance(prometheus_headers, list): prometheus_header_pairs = prometheus_headers elif isinstance(prometheus_headers, dict): - prometheus_header_pairs = [(k, v) for k, v in prometheus_headers.items()] + prometheus_header_pairs = list(prometheus_headers.items()) data_sources_path = os.path.join(grafana_provisioning_folder, "datasources") os.makedirs( diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 8707c6abae196..fd4dfe8b1a815 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -11,7 +11,6 @@ import grpc import ray._private.utils -import ray.dashboard.consts as dashboard_consts import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils from ray._private import ray_constants @@ -30,7 +29,11 @@ parse_usage, ) from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc -from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS +from ray.dashboard.consts import ( + GCS_RPC_TIMEOUT_SECONDS, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + DASHBOARD_AGENT_ADDR_IP_PREFIX, +) from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.node import node_consts from ray.dashboard.modules.node.node_consts import ( @@ -125,7 +128,6 @@ def get_internal_states(self): return { "head_node_registration_time_s": self._head_node_registration_time_s, "registered_nodes": len(DataSource.nodes), - "registered_agents": len(DataSource.agents), "module_lifetime_s": time.time() - self._module_start_time, } @@ -195,48 +197,27 @@ async def _update_node(self, node: dict): ) assert node["state"] in ["ALIVE", "DEAD"] is_alive = node["state"] == "ALIVE" - # Prepare agents for alive node, and pop agents for dead node. - if is_alive: - if node_id not in DataSource.agents: - # Agent port is read from internal KV, which is only populated - # upon Agent startup. In case this update received before agent - # fully started up, we schedule a task to asynchronously update - # DataSource with appropriate agent port. - asyncio.create_task(self._update_agent(node_id)) - else: - DataSource.agents.pop(node_id, None) - self._dead_node_queue.append(node_id) - if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: - DataSource.nodes.pop(self._dead_node_queue.popleft(), None) - DataSource.nodes[node_id] = node - - async def _update_agent(self, node_id): - """ - Given a node, update the agent_port in DataSource.agents. Problem is it's not - present until agent.py starts, so we need to loop waiting for agent.py writes - its port to internal kv. 
- """ - key = ( - f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode() - ) - while True: - try: - agent_addr = await self.gcs_aio_client.internal_kv_get( + if not is_alive: + # Remove the agent address from the internal KV. + keys = [ + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}", + f"{DASHBOARD_AGENT_ADDR_IP_PREFIX}{node['nodeManagerAddress']}", + ] + tasks = [ + self.gcs_aio_client.internal_kv_del( key, + del_by_prefix=False, namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=None, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) - # The node may be dead already. Only update DataSource.agents if the - # node is still alive. - if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": - return - if agent_addr: - DataSource.agents[node_id] = json.loads(agent_addr) - return - except Exception: - logger.exception(f"Error getting agent port for node {node_id}.") + for key in keys + ] + await asyncio.gather(*tasks) - await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S) + self._dead_node_queue.append(node_id) + if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: + DataSource.nodes.pop(self._dead_node_queue.popleft(), None) + DataSource.nodes[node_id] = node async def _update_nodes(self): """ @@ -263,23 +244,25 @@ async def _update_nodes(self): ) warning_shown = True - @routes.get("/internal/node_module") - async def get_node_module_internal_state(self, req) -> aiohttp.web.Response: - return dashboard_optional_utils.rest_response( - success=True, - message="", - **self.get_internal_states(), - ) - async def get_nodes_logical_resources(self) -> dict: from ray.autoscaler.v2.utils import is_autoscaler_v2 if is_autoscaler_v2(): - from ray.autoscaler.v2.sdk import get_cluster_status + from ray.autoscaler.v2.sdk import ClusterStatusParser + from ray.autoscaler.v2.schema import Stats try: - cluster_status = get_cluster_status(self.gcs_address) + # here we have a sync request + req_time = time.time() + cluster_status = await self.gcs_aio_client.get_cluster_status() + reply_time = time.time() + cluster_status = ClusterStatusParser.from_get_cluster_status_reply( + cluster_status, + stats=Stats( + gcs_request_time_s=reply_time - req_time, request_ts_s=req_time + ), + ) except Exception: logger.exception("Error getting cluster status") return {} diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index 93d6cbd600991..e8c62618f8247 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -43,22 +43,10 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard): assert dump_info["result"] is True dump_data = dump_info["data"] assert len(dump_data["nodes"]) == 1 - assert len(dump_data["agents"]) == 1 - - response = requests.get(webui_url + "/test/notified_agents") - response.raise_for_status() - try: - notified_agents = response.json() - except Exception as ex: - logger.info("failed response: %s", response.text) - raise ex - assert notified_agents["result"] is True - notified_agents = notified_agents["data"] - assert len(notified_agents) == 1 - assert notified_agents == dump_data["agents"] break - except (AssertionError, requests.exceptions.ConnectionError) as e: - logger.info("Retry because of %s", e) + + except (AssertionError, requests.exceptions.ConnectionError): + logger.exception("Retry") finally: if time.time() > start_time + timeout_seconds: raise Exception("Timed out while testing.") @@ -190,10 
+178,6 @@ def _check_nodes(): else: assert detail["raylet"]["state"] == "DEAD" assert detail["raylet"].get("objectStoreAvailableMemory", 0) == 0 - response = requests.get(webui_url + "/test/dump?key=agents") - response.raise_for_status() - agents = response.json() - assert len(agents["data"]["agents"]) == 3 return True except Exception as ex: logger.info(ex) diff --git a/python/ray/dashboard/modules/reporter/reporter_head.py b/python/ray/dashboard/modules/reporter/reporter_head.py index 30295c805f0b3..9e66ba6d78801 100644 --- a/python/ray/dashboard/modules/reporter/reporter_head.py +++ b/python/ray/dashboard/modules/reporter/reporter_head.py @@ -686,7 +686,7 @@ async def _get_stub_address_by_ip( if not agent_addr_json: return None node_id, http_port, grpc_port = json.loads(agent_addr_json) - return node_id, ip, http_port, grpc_port + return NodeID.from_hex(node_id), ip, http_port, grpc_port def _make_stub( self, ip_port: str diff --git a/python/ray/dashboard/modules/serve/tests/test_serve_dashboard.py b/python/ray/dashboard/modules/serve/tests/test_serve_dashboard.py index dc4afb5fb2e2e..22eb3522c211b 100644 --- a/python/ray/dashboard/modules/serve/tests/test_serve_dashboard.py +++ b/python/ray/dashboard/modules/serve/tests/test_serve_dashboard.py @@ -358,12 +358,21 @@ def test_get_serve_instance_details(ray_start_stop, f_deployment_options, url): "docs_path": None, "deployments": {"f", "BasicDriver"}, "source": "declarative", + "required_resources": { + "f": { + "CPU": f_deployment_options.get("ray_actor_options", {}).get( + "num_cpus", 0.1 + ) + }, + "BasicDriver": {"CPU": 0.1}, + }, }, "app2": { "route_prefix": "/banana", "docs_path": "/my_docs", "deployments": {"FastAPIDeployment"}, "source": "declarative", + "required_resources": {"FastAPIDeployment": {"CPU": 1}}, }, } @@ -443,6 +452,10 @@ def applications_running(): == deployment.deployment_config.num_replicas ) assert len(deployment.replicas) == deployment.target_num_replicas + assert ( + deployment.required_resources + == expected_values[app]["required_resources"][deployment.name] + ) for replica in deployment.replicas: assert replica.replica_id diff --git a/python/ray/dashboard/modules/state/state_head.py b/python/ray/dashboard/modules/state/state_head.py index 824fe30265251..1f32eda3574ad 100644 --- a/python/ray/dashboard/modules/state/state_head.py +++ b/python/ray/dashboard/modules/state/state_head.py @@ -75,7 +75,6 @@ def __init__( ) DataSource.nodes.signal.append(self._update_raylet_stubs) - DataSource.agents.signal.append(self._update_agent_stubs) async def limit_handler_(self): return do_reply( @@ -119,20 +118,6 @@ async def _update_raylet_stubs(self, change: Change): int(node_info["runtimeEnvAgentPort"]), ) - async def _update_agent_stubs(self, change: Change): - """Callback that's called when a new agent is added to Datasource.""" - if change.old: - node_id, _ = change.old - self._state_api_data_source_client.unregister_agent_client(node_id) - if change.new: - # When a new node information is written to DataSource. 
- node_id, (node_ip, http_port, grpc_port) = change.new - self._state_api_data_source_client.register_agent_client( - node_id, - node_ip, - grpc_port, - ) - @routes.get("/api/v0/actors") @RateLimitedModule.enforce_max_concurrent_calls async def list_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response: diff --git a/python/ray/dashboard/modules/tests/test_head.py b/python/ray/dashboard/modules/tests/test_head.py index 98e46f2fa9828..4258052facb87 100644 --- a/python/ray/dashboard/modules/tests/test_head.py +++ b/python/ray/dashboard/modules/tests/test_head.py @@ -20,16 +20,6 @@ class TestHead(dashboard_utils.DashboardHeadModule): def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): super().__init__(config) - self._notified_agents = {} - DataSource.agents.signal.append(self._update_notified_agents) - - async def _update_notified_agents(self, change): - if change.old: - node_id, _ = change.old - self._notified_agents.pop(node_id) - if change.new: - node_id, (node_ip, http_port, grpc_port) = change.new - self._notified_agents[node_id] = (node_ip, http_port, grpc_port) @staticmethod def is_minimal_module(): @@ -73,14 +63,6 @@ async def dump(self, req) -> aiohttp.web.Response: **{key: data}, ) - @routes.get("/test/notified_agents") - async def get_notified_agents(self, req) -> aiohttp.web.Response: - return dashboard_optional_utils.rest_response( - success=True, - message="Fetch notified agents success.", - **self._notified_agents, - ) - @routes.get("/test/http_get") async def get_url(self, req) -> aiohttp.web.Response: url = req.query.get("url") diff --git a/python/ray/dashboard/subprocesses/handle.py b/python/ray/dashboard/subprocesses/handle.py index c94e9f8516f29..469ce8398a49b 100644 --- a/python/ray/dashboard/subprocesses/handle.py +++ b/python/ray/dashboard/subprocesses/handle.py @@ -392,7 +392,7 @@ def dispatch_parent_bound_messages(self): except ValueError: # queue is closed. break - except Exception as e: + except Exception: logger.exception( f"Error unpickling parent bound message from {self_str}." " This may result in a http request never being responded to." @@ -400,7 +400,7 @@ def dispatch_parent_bound_messages(self): continue try: self.handle_parent_bound_message(message) - except Exception as e: + except Exception: logger.exception( f"Error handling parent bound message from {self_str}." " This may result in a http request never being responded to." diff --git a/python/ray/dashboard/subprocesses/module.py b/python/ray/dashboard/subprocesses/module.py index 70af520c33ebf..8ca4c3304a1b5 100644 --- a/python/ray/dashboard/subprocesses/module.py +++ b/python/ray/dashboard/subprocesses/module.py @@ -119,7 +119,7 @@ def dispatch_child_bound_messages( message = self._child_bound_queue.get() try: self.handle_child_bound_message(loop, message) - except Exception as e: + except Exception: logger.exception( f"Error handling child bound message {message}. This request will hang forever." 
) diff --git a/python/ray/dashboard/tests/test_dashboard.py b/python/ray/dashboard/tests/test_dashboard.py index fd652926ea96d..82b6be6ae4393 100644 --- a/python/ray/dashboard/tests/test_dashboard.py +++ b/python/ray/dashboard/tests/test_dashboard.py @@ -381,11 +381,15 @@ def test_http_get(enable_test_module, ray_start_with_dashboard): logger.info("failed response: %s", response.text) raise ex assert dump_info["result"] is True - dump_data = dump_info["data"] - assert len(dump_data["agents"]) == 1 - node_id, (node_ip, http_port, grpc_port) = next( - iter(dump_data["agents"].items()) + + # Get agent ip and http port + node_id_hex = ray_start_with_dashboard["node_id"] + agent_addr = ray.experimental.internal_kv._internal_kv_get( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id_hex}", + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, ) + assert agent_addr is not None + node_ip, http_port, _ = json.loads(agent_addr) response = requests.get( f"http://{node_ip}:{http_port}" diff --git a/python/ray/data/BUILD b/python/ray/data/BUILD index f992b45ec1d78..294402b349bba 100644 --- a/python/ray/data/BUILD +++ b/python/ray/data/BUILD @@ -101,7 +101,7 @@ py_test( name = "test_arrow_block", size = "medium", srcs = ["tests/test_arrow_block.py"], - tags = ["team:data", "exclusive"], + tags = ["team:data", "exclusive", "data_non_parallel"], deps = ["//:ray_lib", ":conftest"], ) @@ -145,6 +145,14 @@ py_test( deps = ["//:ray_lib", ":conftest"], ) +py_test( + name = "test_audio", + size = "small", + srcs = ["tests/test_audio.py"], + tags = ["team:data", "exclusive"], + deps = ["//:ray_lib", ":conftest"], +) + py_test( name = "test_avro", size = "small", @@ -565,6 +573,14 @@ py_test( deps = ["//:ray_lib", ":conftest"], ) +py_test( + name = "test_video", + size = "small", + srcs = ["tests/test_video.py"], + tags = ["team:data", "exclusive"], + deps = ["//:ray_lib", ":conftest"], +) + py_test( name = "test_webdataset", size = "medium", diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index 5e3747be38757..087d41ab38d64 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -42,6 +42,7 @@ from_torch, range, range_tensor, + read_audio, read_avro, read_bigquery, read_binary_files, @@ -62,6 +63,7 @@ read_sql, read_text, read_tfrecords, + read_videos, read_webdataset, ) @@ -137,6 +139,7 @@ "from_huggingface", "range", "range_tensor", + "read_audio", "read_avro", "read_text", "read_binary_files", @@ -155,6 +158,7 @@ "read_parquet_bulk", "read_sql", "read_tfrecords", + "read_videos", "read_webdataset", "Preprocessor", "TFXReadOptions", diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 9674d8e94c9b6..9fe99e4eb1b18 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List, Union, Dict import numpy as np from packaging.version import parse as parse_version @@ -217,6 +217,164 @@ def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.Chunke return pyarrow.chunked_array(chunks, type=type_) +def _extract_unified_struct_types( + schema: "pyarrow.Schema", +) -> Dict[str, "pyarrow.StructType"]: + """ + Extract all struct fields from a schema and map their names to types. + + Args: + schema: Arrow schema to extract struct types from. 
+ + Returns: + Dict[str, pa.StructType]: Mapping of struct field names to their types. + """ + import pyarrow as pa + + return { + field.name: field.type for field in schema if pa.types.is_struct(field.type) + } + + +def _backfill_missing_fields( + column: "pyarrow.ChunkedArray", + unified_struct_type: "pyarrow.StructType", + block_length: int, +) -> "pyarrow.StructArray": + """ + Align a struct column's fields to match the unified schema's struct type. + + Args: + column: The column data to align. + unified_struct_type: The unified struct type to align to. + block_length: The number of rows in the block. + + Returns: + pa.StructArray: The aligned struct array. + """ + import pyarrow as pa + + # Flatten chunked arrays into a single array if necessary + if isinstance(column, pa.ChunkedArray): + column = pa.concat_arrays(column.chunks) + + # Extract the current struct field names and their corresponding data + current_fields = { + field.name: column.field(i) for i, field in enumerate(column.type) + } + + # Assert that the current fields are a subset of the unified struct type's field names + unified_field_names = {field.name for field in unified_struct_type} + assert set(current_fields.keys()).issubset( + unified_field_names + ), f"Fields {set(current_fields.keys())} are not a subset of unified struct fields {unified_field_names}." + + # Early exit if no fields are missing in the schema + if column.type == unified_struct_type: + return column + + aligned_fields = [] + + # Iterate over the fields in the unified struct type schema + for field in unified_struct_type: + field_name = field.name + field_type = field.type + + if field_name in current_fields: + # If the field exists in the current column, align it + current_array = current_fields[field_name] + if pa.types.is_struct(field_type): + # Recursively align nested struct fields + current_array = _backfill_missing_fields( + column=current_array, + unified_struct_type=field_type, + block_length=block_length, + ) + aligned_fields.append(current_array) + else: + # If the field is missing, fill with nulls + aligned_fields.append(pa.nulls(block_length, type=field_type)) + + # Reconstruct the struct column with aligned fields + return pa.StructArray.from_arrays( + aligned_fields, + fields=unified_struct_type, + ) + + +def _align_struct_fields( + blocks: List["pyarrow.Table"], schema: "pyarrow.Schema" +) -> List["pyarrow.Table"]: + """ + Align struct columns across blocks to match the provided schema. + + Args: + blocks: List of Arrow tables to align. + schema: Unified schema with desired struct column alignment. + + Returns: + List[pa.Table]: List of aligned Arrow tables. 
+ """ + import pyarrow as pa + + # Check if all block schemas are already aligned + if all(block.schema == schema for block in blocks): + return blocks + + # Extract all struct column types from the provided schema + unified_struct_types = _extract_unified_struct_types(schema) + + # If there are no struct columns in the schema, return blocks as is + if not unified_struct_types: + return blocks + + aligned_blocks = [] + + # Iterate over each block (table) in the list + for block in blocks: + # Store aligned struct columns + aligned_columns = {} + + # Get the number of rows in the block + block_length = len(block) + + # Process each struct column defined in the unified schema + for column_name, unified_struct_type in unified_struct_types.items(): + # If the column exists in the block, align its fields + if column_name in block.schema.names: + column = block[column_name] + + # Check if the column type matches a struct type + if isinstance(column.type, pa.StructType): + aligned_columns[column_name] = _backfill_missing_fields( + column, unified_struct_type, block_length + ) + else: + # If the column is not a struct, simply keep the original column + aligned_columns[column_name] = column + else: + # If the column is missing, create a null-filled column with the same + # length as the block + aligned_columns[column_name] = pa.array( + [None] * block_length, type=unified_struct_type + ) + + # Create a new aligned block with the updated columns and the unified schema. + new_columns = [] + for column_name in schema.names: + if column_name in aligned_columns: + # Use the aligned column if available + new_columns.append(aligned_columns[column_name]) + else: + # Use the original column if not aligned + assert column_name in block.schema.names + new_columns.append(block[column_name]) + aligned_blocks.append(pa.table(new_columns, schema=schema)) + + # Return the list of aligned blocks + return aligned_blocks + + def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": """Concatenate provided Arrow Tables into a single Arrow Table. This has special handling for extension types that pyarrow.concat_tables does not yet support. @@ -240,6 +398,16 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": if len(blocks) == 1: return blocks[0] + # If the result contains pyarrow schemas, unify them + schemas_to_unify = [b.schema for b in blocks] + try: + schema = unify_schemas(schemas_to_unify) + except Exception as e: + raise ArrowConversionError(str(blocks)) from e + + # Handle alignment of struct type columns. + blocks = _align_struct_fields(blocks, schema) + # Rollup columns with opaque (null-typed) lists, to process in following for-loop. 
cols_with_null_list = set() for b in blocks: @@ -248,13 +416,6 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": if pa.types.is_list(col_type) and pa.types.is_null(col_type.value_type): cols_with_null_list.add(col_name) - # If the result contains pyarrow schemas, unify them - schemas_to_unify = [b.schema for b in blocks] - try: - schema = unify_schemas(schemas_to_unify) - except Exception as e: - raise ArrowConversionError(str(blocks)) from e - if ( any(isinstance(type_, pa.ExtensionType) for type_ in schema.types) or cols_with_null_list diff --git a/python/ray/data/_internal/datasource/audio_datasource.py b/python/ray/data/_internal/datasource/audio_datasource.py new file mode 100644 index 0000000000000..3a3a28d26f45f --- /dev/null +++ b/python/ray/data/_internal/datasource/audio_datasource.py @@ -0,0 +1,57 @@ +import io +from typing import TYPE_CHECKING, Iterator, List, Union + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.util import _check_import +from ray.data.block import Block +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + + +class AudioDatasource(FileBasedDatasource): + _FILE_EXTENSIONS = [ + "mp3", + "wav", + "aac", + "flac", + "ogg", + "m4a", + "wma", + "alac", + "aiff", + "pcm", + "amr", + "opus", + "ra", + "rm", + "au", + "mid", + "midi", + "caf", + ] + + def __init__( + self, + paths: Union[str, List[str]], + **file_based_datasource_kwargs, + ): + super().__init__(paths, **file_based_datasource_kwargs) + + _check_import(self, module="soundfile", package="soundfile") + + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: + import soundfile + + # `soundfile` doesn't support reading from a `pyarrow.NativeFile` directly, so + # we need to read the file into memory first. 
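On the audio datasource above: each file is buffered into memory (as the comment notes, `soundfile` can't read a `pyarrow.NativeFile` directly) and decoded into an `amplitude` matrix of shape `(channels, samples)` plus a `sample_rate`. A hedged usage sketch of the matching public API, assuming `read_audio` accepts paths like the other `read_*` functions and that a local WAV file exists:

```python
import ray

# Hypothetical local file; any of the listed audio extensions should work.
ds = ray.data.read_audio("/tmp/example.wav")
row = ds.take(1)[0]
print(row["amplitude"].shape, row["sample_rate"])  # e.g. (2, 44100) 44100
```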
+ stream = io.BytesIO(f.read()) + amplitude, sample_rate = soundfile.read(stream, always_2d=True, dtype="float32") + + # (amplitude, channels) -> (channels, amplitude) + amplitude = amplitude.transpose((1, 0)) + + builder = DelegatingBlockBuilder() + builder.add({"amplitude": amplitude, "sample_rate": sample_rate}) + yield builder.build() diff --git a/python/ray/data/_internal/datasource/bigquery_datasink.py b/python/ray/data/_internal/datasource/bigquery_datasink.py index 92178996f3eaf..5bc3f39a218de 100644 --- a/python/ray/data/_internal/datasource/bigquery_datasink.py +++ b/python/ray/data/_internal/datasource/bigquery_datasink.py @@ -8,8 +8,8 @@ import pyarrow.parquet as pq import ray -from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.datasource import bigquery_datasource +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _check_import from ray.data.block import Block, BlockAccessor diff --git a/python/ray/data/_internal/datasource/parquet_datasink.py b/python/ray/data/_internal/datasource/parquet_datasink.py index 3f4706ccbb0db..00f3551950792 100644 --- a/python/ray/data/_internal/datasource/parquet_datasink.py +++ b/python/ray/data/_internal/datasource/parquet_datasink.py @@ -6,7 +6,6 @@ from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.util import call_with_retry from ray.data.block import Block, BlockAccessor -from ray.data.context import DataContext from ray.data.datasource.file_based_datasource import _resolve_kwargs from ray.data.datasource.file_datasink import _FileDatasink from ray.data.datasource.filename_provider import FilenameProvider @@ -28,7 +27,7 @@ def __init__( partition_cols: Optional[List[str]] = None, arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, arrow_parquet_args: Optional[Dict[str, Any]] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, filesystem: Optional["pyarrow.fs.FileSystem"] = None, try_create_dir: bool = True, open_stream_args: Optional[Dict[str, Any]] = None, @@ -43,7 +42,7 @@ def __init__( self.arrow_parquet_args_fn = arrow_parquet_args_fn self.arrow_parquet_args = arrow_parquet_args - self.num_rows_per_file = num_rows_per_file + self.min_rows_per_file = min_rows_per_file self.partition_cols = partition_cols super().__init__( @@ -95,7 +94,7 @@ def write_blocks_to_path(): call_with_retry( write_blocks_to_path, description=f"write '{filename}' to '{self.path}'", - match=DataContext.get_current().retried_io_errors, + match=self._data_context.retried_io_errors, max_attempts=WRITE_FILE_MAX_ATTEMPTS, max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, ) @@ -168,5 +167,5 @@ def _write_partition_files( writer.write_table(group_table) @property - def num_rows_per_write(self) -> Optional[int]: - return self.num_rows_per_file + def min_rows_per_write(self) -> Optional[int]: + return self.min_rows_per_file diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index b15d27baa2baf..7dec7eff46074 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -21,6 +21,7 @@ from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import ( + RetryingPyFileSystem, 
_check_pyarrow_version, _is_local_scheme, call_with_retry, @@ -196,6 +197,9 @@ def __init__( ) paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem) + filesystem = RetryingPyFileSystem.wrap( + filesystem, context=DataContext.get_current() + ) # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet # files. To avoid this, we expand the input paths with the default metadata diff --git a/python/ray/data/_internal/datasource/sql_datasource.py b/python/ray/data/_internal/datasource/sql_datasource.py index c38f5bdd6a4a6..4d69022c5f47c 100644 --- a/python/ray/data/_internal/datasource/sql_datasource.py +++ b/python/ray/data/_internal/datasource/sql_datasource.py @@ -1,3 +1,5 @@ +import logging +import math from contextlib import contextmanager from typing import Any, Callable, Iterable, Iterator, List, Optional @@ -7,6 +9,8 @@ Connection = Any # A Python DB API2-compliant `Connection` object. Cursor = Any # A Python DB API2-compliant `Cursor` object. +logger = logging.getLogger(__name__) + def _cursor_to_block(cursor) -> Block: import pyarrow as pa @@ -71,19 +75,113 @@ def _connect(connection_factory: Callable[[], Connection]) -> Iterator[Cursor]: class SQLDatasource(Datasource): - def __init__(self, sql: str, connection_factory: Callable[[], Connection]): + MIN_ROWS_PER_READ_TASK = 50 + + def __init__( + self, + sql: str, + connection_factory: Callable[[], Connection], + shard_hash_fn: str, + shard_keys: Optional[List[str]] = None, + ): self.sql = sql + if shard_keys and len(shard_keys) > 1: + self.shard_keys = f"CONCAT({','.join(shard_keys)})" + elif shard_keys and len(shard_keys) == 1: + self.shard_keys = f"{shard_keys[0]}" + else: + self.shard_keys = None + self.shard_hash_fn = shard_hash_fn self.connection_factory = connection_factory def estimate_inmemory_data_size(self) -> Optional[int]: return None + def supports_sharding(self, parallelism: int) -> bool: + """Check if database supports sharding with MOD/ABS/CONCAT operations. + + Returns: + bool: True if sharding is supported, False otherwise. + """ + if parallelism <= 1 or self.shard_keys is None: + return False + + # Test if database supports required operations (MOD, ABS, MD5, CONCAT) + # by executing a sample query + hash_fn = self.shard_hash_fn + query = ( + f"SELECT COUNT(1) FROM ({self.sql}) as T" + f" WHERE MOD(ABS({hash_fn}({self.shard_keys})), {parallelism}) = 0" + ) + try: + with _connect(self.connection_factory) as cursor: + cursor.execute(query) + return True + except Exception as e: + logger.info(f"Database does not support sharding: {str(e)}.") + return False + def get_read_tasks(self, parallelism: int) -> List[ReadTask]: - def read_fn() -> Iterable[Block]: + def fallback_read_fn() -> Iterable[Block]: + """Read all data in a single block when sharding is not supported.""" with _connect(self.connection_factory) as cursor: cursor.execute(self.sql) + return [_cursor_to_block(cursor)] + + num_rows_total = self._get_num_rows() + + if num_rows_total == 0: + return [] + + parallelism = min( + parallelism, math.ceil(num_rows_total / self.MIN_ROWS_PER_READ_TASK) + ) + num_rows_per_block = num_rows_total // parallelism + num_blocks_with_extra_row = num_rows_total % parallelism + + # Check if sharding is supported by the database + # If not, fall back to reading all data in a single task + if not self.supports_sharding(parallelism): + logger.info( + "Sharding is not supported. " + "Falling back to reading all data in a single task." 
+ ) + metadata = BlockMetadata(None, None, None, None, None) + return [ReadTask(fallback_read_fn, metadata)] + + tasks = [] + for i in range(parallelism): + num_rows = num_rows_per_block + if i < num_blocks_with_extra_row: + num_rows += 1 + read_fn = self._create_parallel_read_fn(i, parallelism) + metadata = BlockMetadata( + num_rows=num_rows, + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ) + tasks.append(ReadTask(read_fn, metadata)) + + return tasks + + def _get_num_rows(self) -> int: + with _connect(self.connection_factory) as cursor: + cursor.execute(f"SELECT COUNT(*) FROM ({self.sql}) as T") + return cursor.fetchone()[0] + + def _create_parallel_read_fn(self, task_id: int, parallelism: int): + hash_fn = self.shard_hash_fn + query = ( + f"SELECT * FROM ({self.sql}) as T " + f"WHERE MOD(ABS({hash_fn}({self.shard_keys})), {parallelism}) = {task_id}" + ) + + def read_fn() -> Iterable[Block]: + with _connect(self.connection_factory) as cursor: + cursor.execute(query) block = _cursor_to_block(cursor) return [block] - metadata = BlockMetadata(None, None, None, None, None) - return [ReadTask(read_fn, metadata)] + return read_fn diff --git a/python/ray/data/_internal/datasource/tfrecords_datasource.py b/python/ray/data/_internal/datasource/tfrecords_datasource.py index 076325903374d..925f45657e1d2 100644 --- a/python/ray/data/_internal/datasource/tfrecords_datasource.py +++ b/python/ray/data/_internal/datasource/tfrecords_datasource.py @@ -127,14 +127,19 @@ def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block raise RuntimeError(f"Failed to read TFRecord file {full_path}.") def _resolve_full_path(self, relative_path): - if isinstance(self._filesystem, pyarrow.fs.S3FileSystem): + from ray.data._internal.util import RetryingPyFileSystem + + filesystem = self._filesystem + if isinstance(filesystem, RetryingPyFileSystem): + filesystem = filesystem.unwrap() + if isinstance(filesystem, pyarrow.fs.S3FileSystem): return f"s3://{relative_path}" - if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem): + if isinstance(filesystem, pyarrow.fs.GcsFileSystem): return f"gs://{relative_path}" - if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem): + if isinstance(filesystem, pyarrow.fs.HadoopFileSystem): return f"hdfs:///{relative_path}" - if isinstance(self._filesystem, pyarrow.fs.PyFileSystem): - protocol = self._filesystem.handler.fs.protocol + if isinstance(filesystem, pyarrow.fs.PyFileSystem): + protocol = filesystem.handler.fs.protocol if isinstance(protocol, list) or isinstance(protocol, tuple): protocol = protocol[0] if protocol == "gcs": diff --git a/python/ray/data/_internal/datasource/video_datasource.py b/python/ray/data/_internal/datasource/video_datasource.py new file mode 100644 index 0000000000000..2a4a06e876e21 --- /dev/null +++ b/python/ray/data/_internal/datasource/video_datasource.py @@ -0,0 +1,59 @@ +import logging +from typing import TYPE_CHECKING, List, Union + +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.util import _check_import +from ray.data.datasource.file_based_datasource import FileBasedDatasource + +if TYPE_CHECKING: + import pyarrow + +logger = logging.getLogger(__name__) + + +class VideoDatasource(FileBasedDatasource): + _FILE_EXTENSIONS = [ + "mp4", + "mkv", + "mov", + "avi", + "wmv", + "flv", + "webm", + "m4v", + "3gp", + "mpeg", + "mpg", + "ts", + "ogv", + "rm", + "rmvb", + "vob", + "asf", + "f4v", + "m2ts", + "mts", + "divx", + "xvid", + "mxf", + ] 
+
+    def __init__(
+        self,
+        paths: Union[str, List[str]],
+        **file_based_datasource_kwargs,
+    ):
+        super().__init__(paths, **file_based_datasource_kwargs)
+
+        _check_import(self, module="decord", package="decord")
+
+    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
+        from decord import VideoReader
+
+        reader = VideoReader(f)
+
+        for frame_index, frame in enumerate(reader):
+            item = {"frame": frame.asnumpy(), "frame_index": frame_index}
+            builder = DelegatingBlockBuilder()
+            builder.add(item)
+            yield builder.build()
diff --git a/python/ray/data/_internal/datasource/webdataset_datasource.py b/python/ray/data/_internal/datasource/webdataset_datasource.py
index bc5661647b047..0289b5ac4838b 100644
--- a/python/ray/data/_internal/datasource/webdataset_datasource.py
+++ b/python/ray/data/_internal/datasource/webdataset_datasource.py
@@ -9,7 +9,6 @@
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

-import ray
 from ray.data._internal.util import iterate_with_retry
 from ray.data.block import BlockAccessor
 from ray.data.datasource.file_based_datasource import FileBasedDatasource
@@ -353,9 +352,10 @@ def get_tar_file_iterator():
         )

         # S3 can raise transient errors during iteration
-        ctx = ray.data.DataContext.get_current()
         files = iterate_with_retry(
-            get_tar_file_iterator, "iterate tar file", match=ctx.retried_io_errors
+            get_tar_file_iterator,
+            "iterate tar file",
+            match=self._data_context.retried_io_errors,
         )

         samples = _group_by_keys(files, meta=dict(__url__=path), suffixes=self.suffixes)
diff --git a/python/ray/data/_internal/execution/interfaces/executor.py b/python/ray/data/_internal/execution/interfaces/executor.py
index 007346b60f294..003583e9fb58e 100644
--- a/python/ray/data/_internal/execution/interfaces/executor.py
+++ b/python/ray/data/_internal/execution/interfaces/executor.py
@@ -61,10 +61,14 @@ def execute(
         """
         raise NotImplementedError

-    def shutdown(self):
+    def shutdown(self, exception: Optional[Exception] = None):
         """Shutdown an executor, which may still be running.

         This should interrupt execution and clean up any used resources.
+
+        Args:
+            exception: The exception that causes the executor to shut down, or None if
+                the executor finishes successfully.
         """
         pass

diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
index 678ff6c0d5bbd..3cbfe5a69ab2a 100644
--- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
+++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
@@ -275,7 +275,7 @@ def shutdown(self):
         # parallelization across the actor pool. We only know this information after
         # execution has completed.
         min_workers = self._actor_pool.min_size()
-        if len(self._output_metadata) < min_workers:
+        if len(self._output_blocks_stats) < min_workers:
             # The user created a stream that has too few blocks to begin with.
             logger.warning(
                 "To ensure full parallelization across an actor pool of size "
@@ -416,6 +416,18 @@ def submit(
     def __repr__(self):
         return f"MapWorker({self.src_fn_name})"

+    def on_exit(self):
+        """Called when the actor is about to exit.
+        This enables performing cleanup operations via `UDF.__del__`.
+
+        Note: this only ensures cleanup is performed when the job exits gracefully.
+        If the driver or the actor is forcefully killed, `__del__` will not be called.
+        """
+        # `_map_actor_context` is a global variable that references the UDF object.
+ # Delete it to trigger `UDF.__del__`. + del ray.data._map_actor_context + ray.data._map_actor_context = None + @dataclass class _ActorState: @@ -746,6 +758,9 @@ def _remove_actor(self, actor: ray.actor.ActorHandle): # garbage collect the actor, instead of using ray.kill. # Because otherwise the actor cannot be restarted upon lineage reconstruction. if actor in self._running_actors: + # Call `on_exit` to trigger `UDF.__del__` which may perform + # cleanup operations. + actor.on_exit.remote() del self._running_actors[actor] def _get_location(self, bundle: RefBundle) -> Optional[NodeIdStr]: diff --git a/python/ray/data/_internal/execution/operators/limit_operator.py b/python/ray/data/_internal/execution/operators/limit_operator.py index 47e93ae3f22c4..ffc1b42555aac 100644 --- a/python/ray/data/_internal/execution/operators/limit_operator.py +++ b/python/ray/data/_internal/execution/operators/limit_operator.py @@ -9,7 +9,7 @@ ) from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.stats import StatsDict -from ray.data.block import Block, BlockAccessor, BlockMetadata +from ray.data.block import Block, BlockAccessor, BlockMetadata, BlockStats from ray.data.context import DataContext from ray.types import ObjectRef @@ -27,7 +27,7 @@ def __init__( self._consumed_rows = 0 self._buffer: Deque[RefBundle] = deque() self._name = f"limit={limit}" - self._output_metadata: List[BlockMetadata] = [] + self._output_blocks_stats: List[BlockStats] = [] self._cur_output_bundles = 0 super().__init__(self._name, input_op, data_context, target_max_block_size=None) if self._limit <= 0: @@ -49,7 +49,7 @@ def _add_input_inner(self, refs: RefBundle, input_index: int) -> None: if self._consumed_rows + num_rows <= self._limit: out_blocks.append(block) out_metadata.append(metadata) - self._output_metadata.append(metadata) + self._output_blocks_stats.append(metadata.to_stats()) self._consumed_rows += num_rows else: # Slice the last block. @@ -70,7 +70,7 @@ def slice_fn(block, metadata, num_rows) -> Tuple[Block, BlockMetadata]: out_blocks.append(block) metadata = ray.get(metadata_ref) out_metadata.append(metadata) - self._output_metadata.append(metadata) + self._output_blocks_stats.append(metadata.to_stats()) self._consumed_rows = self._limit break self._cur_output_bundles += 1 @@ -109,7 +109,7 @@ def _get_next_inner(self) -> RefBundle: return output def get_stats(self) -> StatsDict: - return {self._name: self._output_metadata} + return {self._name: self._output_blocks_stats} def num_outputs_total(self) -> Optional[int]: # Before execution is completed, we don't know how many output diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 49169ca750ca7..e4a20c9fd5b5f 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -45,7 +45,14 @@ MapTransformer, ) from ray.data._internal.stats import StatsDict -from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata +from ray.data.block import ( + Block, + BlockAccessor, + BlockExecStats, + BlockMetadata, + BlockStats, + to_stats, +) from ray.data.context import DataContext from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -88,7 +95,7 @@ def __init__( # Queue for task outputs, either ordered or unordered (this is set by start()). self._output_queue: _OutputQueue = None # Output metadata, added to on get_next(). 
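The new `on_exit` hook above clears the global `_map_actor_context` so a UDF's `__del__` runs before the pool drops its actor handle. A hedged sketch of the user-side pattern this enables; the resource being cleaned up is illustrative:

```python
import ray


class Embedder:
    """Callable-class UDF whose __del__ releases a resource on graceful shutdown."""

    def __init__(self):
        # Stand-in for an expensive handle (model, connection, file, ...).
        self.log = open("/tmp/embedder.log", "w")

    def __call__(self, batch):
        return batch  # no-op transform, just for the sketch

    def __del__(self):
        self.log.close()


# concurrency=2 runs the UDF in an actor pool; on graceful shutdown each actor's
# on_exit drops _map_actor_context, which triggers Embedder.__del__.
ray.data.range(100).map_batches(Embedder, concurrency=2).materialize()
```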
- self._output_metadata: List[BlockMetadata] = [] + self._output_blocks_stats: List[BlockStats] = [] # All active `DataOpTask`s. self._data_tasks: Dict[int, DataOpTask] = {} self._next_data_task_idx = 0 @@ -428,7 +435,7 @@ def _get_next_inner(self) -> RefBundle: assert self._started bundle = self._output_queue.get_next() self._metrics.on_output_dequeued(bundle) - self._output_metadata.extend(bundle.metadata) + self._output_blocks_stats.extend(to_stats(bundle.metadata)) return bundle @abstractmethod @@ -439,7 +446,7 @@ def _extra_metrics(self) -> Dict[str, Any]: return {"ray_remote_args": dict(sorted(self._remote_args_for_metrics.items()))} def get_stats(self) -> StatsDict: - return {self._name: self._output_metadata} + return {self._name: self._output_blocks_stats} def get_map_transformer(self) -> MapTransformer: return self._map_transformer diff --git a/python/ray/data/_internal/execution/operators/map_transformer.py b/python/ray/data/_internal/execution/operators/map_transformer.py index d3135fce59a47..61546b24e7746 100644 --- a/python/ray/data/_internal/execution/operators/map_transformer.py +++ b/python/ray/data/_internal/execution/operators/map_transformer.py @@ -27,6 +27,21 @@ class MapTransformFnDataType(Enum): Batch = 2 +class MapTransformFnCategory(Enum): + """An enum that represents the PreProcess/DataProcess/PostProcess category of a + MapTransformFn. + """ + + # Data format conversion before the actual data processing, i.e. converting input blocks to rows or batches. + PreProcess = 0 + + # Actual Data processing/transformation. + DataProcess = 1 + + # Data format conversion after the actual data processing, i.e., converting rows or batches to output blocks. + PostProcess = 2 + + class MapTransformFn: """Represents a single transform function in a MapTransformer.""" @@ -34,6 +49,7 @@ def __init__( self, input_type: MapTransformFnDataType, output_type: MapTransformFnDataType, + category: MapTransformFnCategory, is_udf: bool = False, ): """ @@ -45,6 +61,7 @@ def __init__( self._callable = callable self._input_type = input_type self._output_type = output_type + self._category = category self._target_max_block_size = None self._is_udf = is_udf @@ -64,6 +81,10 @@ def input_type(self) -> MapTransformFnDataType: def output_type(self) -> MapTransformFnDataType: return self._output_type + @property + def category(self) -> MapTransformFnCategory: + return self._category + def set_target_max_block_size(self, target_max_block_size: int): self._target_max_block_size = target_max_block_size @@ -90,6 +111,7 @@ def __init__( Used for the actor-based map operator. 
""" self.set_transform_fns(transform_fns) + self._init_fn = init_fn if init_fn is not None else lambda: None self._target_max_block_size = None self._udf_time = 0 @@ -209,7 +231,10 @@ class RowMapTransformFn(MapTransformFn): def __init__(self, row_fn: MapTransformCallable[Row, Row], is_udf: bool = False): self._row_fn = row_fn super().__init__( - MapTransformFnDataType.Row, MapTransformFnDataType.Row, is_udf=is_udf + MapTransformFnDataType.Row, + MapTransformFnDataType.Row, + category=MapTransformFnCategory.DataProcess, + is_udf=is_udf, ) def __call__(self, input: Iterable[Row], ctx: TaskContext) -> Iterable[Row]: @@ -218,6 +243,9 @@ def __call__(self, input: Iterable[Row], ctx: TaskContext) -> Iterable[Row]: def __repr__(self) -> str: return f"RowMapTransformFn({self._row_fn})" + def __eq__(self, other): + return isinstance(other, RowMapTransformFn) and self._row_fn == other._row_fn + class BatchMapTransformFn(MapTransformFn): """A batch-to-batch MapTransformFn.""" @@ -227,7 +255,10 @@ def __init__( ): self._batch_fn = batch_fn super().__init__( - MapTransformFnDataType.Batch, MapTransformFnDataType.Batch, is_udf=is_udf + MapTransformFnDataType.Batch, + MapTransformFnDataType.Batch, + category=MapTransformFnCategory.DataProcess, + is_udf=is_udf, ) def __call__( @@ -238,6 +269,11 @@ def __call__( def __repr__(self) -> str: return f"BatchMapTransformFn({self._batch_fn})" + def __eq__(self, other): + return ( + isinstance(other, BatchMapTransformFn) and self._batch_fn == other._batch_fn + ) + class BlockMapTransformFn(MapTransformFn): """A block-to-block MapTransformFn.""" @@ -247,6 +283,7 @@ def __init__(self, block_fn: MapTransformCallable[Block, Block]): super().__init__( MapTransformFnDataType.Block, MapTransformFnDataType.Block, + category=MapTransformFnCategory.DataProcess, ) def __call__(self, input: Iterable[Block], ctx: TaskContext) -> Iterable[Block]: @@ -255,6 +292,11 @@ def __call__(self, input: Iterable[Block], ctx: TaskContext) -> Iterable[Block]: def __repr__(self) -> str: return f"BlockMapTransformFn({self._block_fn})" + def __eq__(self, other): + return ( + isinstance(other, BlockMapTransformFn) and self._block_fn == other._block_fn + ) + class BlocksToRowsMapTransformFn(MapTransformFn): """A MapTransformFn that converts input blocks to rows.""" @@ -263,6 +305,7 @@ def __init__(self): super().__init__( MapTransformFnDataType.Block, MapTransformFnDataType.Row, + category=MapTransformFnCategory.PreProcess, ) def __call__(self, blocks: Iterable[Block], _: TaskContext) -> Iterable[Row]: @@ -281,6 +324,9 @@ def instance(cls) -> "BlocksToRowsMapTransformFn": def __repr__(self) -> str: return "BlocksToRowsMapTransformFn()" + def __eq__(self, other): + return isinstance(other, BlocksToRowsMapTransformFn) + class BlocksToBatchesMapTransformFn(MapTransformFn): """A MapTransformFn that converts input blocks to batches.""" @@ -297,6 +343,7 @@ def __init__( super().__init__( MapTransformFnDataType.Block, MapTransformFnDataType.Batch, + category=MapTransformFnCategory.PreProcess, ) def __call__( @@ -352,6 +399,14 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other): + return ( + isinstance(other, BlocksToBatchesMapTransformFn) + and self.batch_format == other.batch_format + and self.batch_size == other.batch_size + and self.zero_copy_batch == other.zero_copy_batch + ) + class BuildOutputBlocksMapTransformFn(MapTransformFn): """A MapTransformFn that converts UDF-returned data to output blocks.""" @@ -365,6 +420,7 @@ def __init__(self, input_type: MapTransformFnDataType): 
super().__init__( input_type, MapTransformFnDataType.Block, + category=MapTransformFnCategory.PostProcess, ) def __call__( @@ -415,6 +471,12 @@ def for_blocks(cls) -> "BuildOutputBlocksMapTransformFn": def __repr__(self) -> str: return f"BuildOutputBlocksMapTransformFn(input_type={self._input_type})" + def __eq__(self, other): + return ( + isinstance(other, BuildOutputBlocksMapTransformFn) + and self.input_type == other.input_type + ) + def _splitrange(n, k): """Calculates array lens of np.array_split(). @@ -445,7 +507,11 @@ def __init__(self, additional_split_factor: int): """ assert additional_split_factor > 1 self._additional_split_factor = additional_split_factor - super().__init__(MapTransformFnDataType.Block, MapTransformFnDataType.Block) + super().__init__( + MapTransformFnDataType.Block, + MapTransformFnDataType.Block, + category=MapTransformFnCategory.PostProcess, + ) def __call__(self, blocks: Iterable[Block], ctx: TaskContext) -> Iterable[Block]: for block in blocks: diff --git a/python/ray/data/_internal/execution/operators/zip_operator.py b/python/ray/data/_internal/execution/operators/zip_operator.py index 552639ef97ccb..7f5df35c00f44 100644 --- a/python/ray/data/_internal/execution/operators/zip_operator.py +++ b/python/ray/data/_internal/execution/operators/zip_operator.py @@ -13,6 +13,7 @@ BlockExecStats, BlockMetadata, BlockPartition, + to_stats, ) from ray.data.context import DataContext @@ -205,7 +206,7 @@ def _zip( owns_blocks=input_owned, ) ) - stats = {self._name: output_metadata} + stats = {self._name: to_stats(output_metadata)} # Clean up inputs. for ref in left_input: diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 9342bebc098c8..04c1f93a23014 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -165,7 +165,9 @@ def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle: # Needs to be BaseException to catch KeyboardInterrupt. Otherwise we # can leave dangling progress bars by skipping shutdown. except BaseException as e: - self._outer.shutdown(isinstance(e, StopIteration)) + self._outer.shutdown( + e if not isinstance(e, StopIteration) else None + ) raise def __del__(self): @@ -176,7 +178,7 @@ def __del__(self): def __del__(self): self.shutdown() - def shutdown(self, execution_completed: bool = True): + def shutdown(self, exception: Optional[Exception] = None): global _num_shutdown with self._shutdown_lock: @@ -188,7 +190,7 @@ def shutdown(self, execution_completed: bool = True): # Give the scheduling loop some time to finish processing. self.join(timeout=2.0) self._update_stats_metrics( - state="FINISHED" if execution_completed else "FAILED", + state="FINISHED" if exception is None else "FAILED", force_update=True, ) # Once Dataset execution completes, mark it as complete @@ -206,7 +208,7 @@ def shutdown(self, execution_completed: bool = True): if self._global_info: # Set the appropriate description that summarizes # the result of dataset execution. 
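With `shutdown(exception=...)` above, the executor records FINISHED versus FAILED from a single signal, and, as the following hunk shows, fires the execution callbacks exactly once from `shutdown()` rather than from the scheduling loop. A hedged sketch of such a callback; the `ExecutionCallback` base class and `add_execution_callback` registration helper are assumptions inferred from the `get_execution_callbacks` usage in the diff:

```python
import ray

# Assumed module path; the diff only shows get_execution_callbacks being called.
from ray.data._internal.execution.execution_callback import (
    ExecutionCallback,
    add_execution_callback,
)


class OutcomeLogger(ExecutionCallback):
    def after_execution_succeeds(self):
        print("dataset execution finished cleanly")

    def after_execution_fails(self, error: Exception):
        print(f"dataset execution failed: {error!r}")


add_execution_callback(OutcomeLogger(), ray.data.DataContext.get_current())
ray.data.range(10).materialize()
```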
- if execution_completed: + if exception is None: prog_bar_msg = ( f"{OK_PREFIX} Dataset execution finished in " f"{self._final_stats.time_total_s:.2f} seconds" @@ -218,6 +220,12 @@ def shutdown(self, execution_completed: bool = True): for op, state in self._topology.items(): op.shutdown() state.close_progress_bars() + if exception is None: + for callback in get_execution_callbacks(self._data_context): + callback.after_execution_succeeds() + else: + for callback in get_execution_callbacks(self._data_context): + callback.after_execution_fails(exception) self._autoscaler.on_executor_shutdown() def run(self): @@ -237,13 +245,9 @@ def run(self): ) if not continue_sched or self._shutdown: break - for callback in get_execution_callbacks(self._data_context): - callback.after_execution_succeeds() except Exception as e: # Propagate it to the result iterator. self._output_node.mark_finished(e) - for callback in get_execution_callbacks(self._data_context): - callback.after_execution_fails(e) finally: # Signal end of results. self._output_node.mark_finished() diff --git a/python/ray/data/_internal/logical/operators/write_operator.py b/python/ray/data/_internal/logical/operators/write_operator.py index cee1930b788f9..bbf0159fa4f36 100644 --- a/python/ray/data/_internal/logical/operators/write_operator.py +++ b/python/ray/data/_internal/logical/operators/write_operator.py @@ -19,7 +19,7 @@ def __init__( ): if isinstance(datasink_or_legacy_datasource, Datasink): min_rows_per_bundled_input = ( - datasink_or_legacy_datasource.num_rows_per_write + datasink_or_legacy_datasource.min_rows_per_write ) else: min_rows_per_bundled_input = None diff --git a/python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py b/python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py index 6495f64f10a49..2ad70ff9ef847 100644 --- a/python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py +++ b/python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py @@ -6,6 +6,7 @@ BuildOutputBlocksMapTransformFn, MapTransformFn, MapTransformFnDataType, + MapTransformFnCategory, ) from ray.data._internal.logical.interfaces.optimizer import Rule from ray.data._internal.logical.interfaces.physical_plan import PhysicalPlan @@ -53,8 +54,9 @@ def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn] class EliminateBuildOutputBlocks(ZeroCopyMapFusionRule): - """This rule eliminates unnecessary BuildOutputBlocksMapTransformFn, - if the previous fn already outputs blocks. + """This rule eliminates unnecessary BuildOutputBlocksMapTransformFn + (which is of category MapTransformFnCategory.PostProcess), if the previous fn + already outputs blocks. This happens for the "Read -> Map/Write" fusion. 
""" @@ -75,12 +77,14 @@ def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn] and i < len(transform_fns) - 1 and isinstance(cur_fn, BuildOutputBlocksMapTransformFn) ): + assert cur_fn.category == MapTransformFnCategory.PostProcess prev_fn = transform_fns[i - 1] next_fn = transform_fns[i + 1] if ( prev_fn.output_type == MapTransformFnDataType.Block and next_fn.input_type == MapTransformFnDataType.Block ): + assert prev_fn.category == MapTransformFnCategory.DataProcess drop = True if not drop: new_transform_fns.append(cur_fn) diff --git a/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py b/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py index b2cf448f030d2..8fdbe6fbaa67e 100644 --- a/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py +++ b/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py @@ -10,6 +10,7 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.stats import StatsDict from ray.data._internal.util import convert_bytes_to_human_readable_str +from ray.data.block import to_stats logger = logging.getLogger(__name__) @@ -142,8 +143,8 @@ def execute( ) ) stats = { - "map": shuffle_map_metadata, - "reduce": new_metadata, + "map": to_stats(shuffle_map_metadata), + "reduce": to_stats(new_metadata), } return (output, stats) diff --git a/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py b/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py index c409d5ce4bee2..b7f2c9e3bb894 100644 --- a/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py +++ b/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py @@ -13,7 +13,7 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.stats import StatsDict from ray.data._internal.util import convert_bytes_to_human_readable_str -from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata +from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata, to_stats from ray.data.context import DataContext from ray.types import ObjectRef from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -643,10 +643,11 @@ def merge(*args, **kwargs): owns_blocks=input_owned, ) ) + stats = { - "map": map_stage_metadata, - "merge": merge_stage_metadata, - "reduce": reduce_stage_metadata, + "map": to_stats(map_stage_metadata), + "merge": to_stats(merge_stage_metadata), + "reduce": to_stats(reduce_stage_metadata), } return (output, stats) diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 4af5146ae3231..9413767c7a395 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -1,9 +1,10 @@ import ast import logging -from typing import Any, List, Union import pyarrow as pa import pyarrow.compute as pc +import pyarrow.dataset as ds + logger = logging.getLogger(__name__) @@ -16,7 +17,8 @@ class ExpressionEvaluator: - def get_filters(self, expression: str) -> pc.Expression: + @staticmethod + def get_filters(expression: str) -> ds.Expression: """Parse and evaluate the expression to generate a filter condition. 
Args: @@ -28,29 +30,16 @@ def get_filters(self, expression: str) -> pc.Expression: """ try: tree = ast.parse(expression, mode="eval") - return self._build_filter_condition(tree.body) + return _ConvertToArrowExpressionVisitor().visit(tree.body) except SyntaxError as e: raise ValueError(f"Invalid syntax in the expression: {expression}") from e except Exception as e: logger.exception(f"Error processing expression: {e}") raise - def _build_filter_condition(self, node) -> Union[pc.Expression, List[Any], str]: - """Recursively evaluate an AST node to build the filter condition. - - Args: - node: The AST node to evaluate, representing part of the expression. - - Returns: - The evaluated result for the node, which could be a - filter condition, list, or field name. - """ - visitor = _ConvertToArrowExpressionVisitor() - return visitor.visit(node) - class _ConvertToArrowExpressionVisitor(ast.NodeVisitor): - def visit_Compare(self, node: ast.Compare) -> pc.Expression: + def visit_Compare(self, node: ast.Compare) -> ds.Expression: """Handle comparison operations (e.g., a == b, a < b, a in b). Args: @@ -62,11 +51,14 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression: # Handle left operand # TODO Validate columns if isinstance(node.left, ast.Attribute): - left_expr = self.visit(node.left) # Visit and handle attributes + # Visit and handle attributes + left_expr = self.visit(node.left) elif isinstance(node.left, ast.Name): - left_expr = self.visit(node.left) # Treat as a simple field + # Treat as a simple field + left_expr = self.visit(node.left) elif isinstance(node.left, ast.Constant): - left_expr = node.left.value # Constant values are used directly + # Constant values are used directly + left_expr = node.left.value else: raise ValueError(f"Unsupported left operand type: {type(node.left)}") @@ -92,7 +84,7 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression: else: raise ValueError(f"Unsupported operator type: {op}") - def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression: + def visit_BoolOp(self, node: ast.BoolOp) -> ds.Expression: """Handle logical operations (e.g., a and b, a or b). Args: @@ -118,8 +110,8 @@ def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression: return combined_expr - def visit_Name(self, node: ast.Name) -> pc.Expression: - """Handle variable (name) nodes and return them as pc.Expression. + def visit_Name(self, node: ast.Name) -> ds.Expression: + """Handle variable (name) nodes and return them as pa.dataset.Expression. Even if the name contains periods, it's treated as a single string. @@ -127,11 +119,10 @@ def visit_Name(self, node: ast.Name) -> pc.Expression: node: The AST node representing a variable. Returns: - The variable wrapped as a pc.Expression. + The variable wrapped as a pa.dataset.Expression. """ - field_name = ( - node.id - ) # Directly use the field name as a string (even if it contains periods) + # Directly use the field name as a string (even if it contains periods) + field_name = node.id return pc.field(field_name) def visit_Attribute(self, node: ast.Attribute) -> object: @@ -159,21 +150,21 @@ def visit_Attribute(self, node: ast.Attribute) -> object: raise ValueError(f"Unsupported attribute: {node.attr}") - def visit_List(self, node: ast.List) -> pc.Expression: + def visit_List(self, node: ast.List) -> ds.Expression: """Handle list literals. Args: node: The AST node representing a list. Returns: - The list of elements wrapped as a pc.Expression. + The list of elements wrapped as a pa.dataset.Expression. 
""" elements = [self.visit(elt) for elt in node.elts] return pa.array(elements) - # TODO (srinathk) Note that visit_Constant does not return pc.Expression + # TODO (srinathk) Note that visit_Constant does not return pa.dataset.Expression # because to support function in() which takes in a List, the elements in the List - # needs to values instead of pc.Expression per pyarrow.dataset.Expression + # needs to values instead of pa.dataset.Expression per pyarrow.dataset.Expression # specification. May be down the road, we can update it as Arrow relaxes this # constraint. def visit_Constant(self, node: ast.Constant) -> object: @@ -187,7 +178,7 @@ def visit_Constant(self, node: ast.Constant) -> object: """ return node.value # Return the constant value directly. - def visit_Call(self, node: ast.Call) -> pc.Expression: + def visit_Call(self, node: ast.Call) -> ds.Expression: """Handle function calls (e.g., is_nan(a), is_valid(b)). Args: diff --git a/python/ray/data/_internal/planner/randomize_blocks.py b/python/ray/data/_internal/planner/randomize_blocks.py index 835017f2cafd5..6211a96a42020 100644 --- a/python/ray/data/_internal/planner/randomize_blocks.py +++ b/python/ray/data/_internal/planner/randomize_blocks.py @@ -32,9 +32,9 @@ def fn( input_owned = all(b.owns_blocks for b in refs) random.shuffle(blocks_with_metadata) output = [] - meta_list = [] + stats_list = [] for block, meta in blocks_with_metadata: - meta_list.append(meta) + stats_list.append(meta.to_stats()) output.append( RefBundle( [ @@ -46,6 +46,6 @@ def fn( owns_blocks=input_owned, ) ) - return output, {op._name: meta_list} + return output, {op._name: stats_list} return fn diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index fc6903cd92e2c..b39caa6ff9b36 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -17,7 +17,7 @@ OpRuntimeMetrics, ) from ray.data._internal.util import capfirst -from ray.data.block import BlockMetadata +from ray.data.block import BlockMetadata, BlockStats from ray.data.context import DataContext from ray.util.annotations import DeveloperAPI from ray.util.metrics import Gauge @@ -29,7 +29,7 @@ STATS_ACTOR_NAMESPACE = "_dataset_stats_actor" -StatsDict = Dict[str, List[BlockMetadata]] +StatsDict = Dict[str, List[BlockStats]] def fmt(seconds: float) -> str: @@ -717,11 +717,11 @@ def to_summary(self) -> "DatasetStatsSummary": operators_stats = [] is_sub_operator = len(self.metadata) > 1 - for name, meta in self.metadata.items(): + for name, stats in self.metadata.items(): operators_stats.append( OperatorStatsSummary.from_block_metadata( name, - meta, + stats, is_sub_operator=is_sub_operator, ) ) @@ -1055,20 +1055,20 @@ class OperatorStatsSummary: def from_block_metadata( cls, operator_name: str, - block_metas: List[BlockMetadata], + block_stats: List[BlockStats], is_sub_operator: bool, ) -> "OperatorStatsSummary": """Calculate the stats for a operator from a given list of blocks, and generates a `OperatorStatsSummary` object with the results. Args: - block_metas: List of `BlockMetadata` to calculate stats of + block_stats: List of `BlockStats` to calculate stats of operator_name: Name of operator associated with `blocks` is_sub_operator: Whether this set of blocks belongs to a sub operator. 
        Returns:
            A `OperatorStatsSummary` object initialized with the calculated statistics
        """
-        exec_stats = [m.exec_stats for m in block_metas if m.exec_stats is not None]
+        exec_stats = [m.exec_stats for m in block_stats if m.exec_stats is not None]
         rounded_total = 0
         time_total_s = 0
         earliest_start_time, latest_end_time = 0, 0
@@ -1097,7 +1097,7 @@ def from_block_metadata(
             exec_summary_str += "\n"

         task_rows = collections.defaultdict(int)
-        for meta in block_metas:
+        for meta in block_stats:
             if meta.num_rows is not None and meta.exec_stats is not None:
                 task_rows[meta.exec_stats.task_idx] += meta.num_rows
         task_rows_stats = None
@@ -1144,7 +1144,7 @@ def from_block_metadata(
         }

         output_num_rows_stats = None
-        output_num_rows = [m.num_rows for m in block_metas if m.num_rows is not None]
+        output_num_rows = [m.num_rows for m in block_stats if m.num_rows is not None]
         if output_num_rows:
             output_num_rows_stats = {
                 "min": min(output_num_rows),
@@ -1155,7 +1155,7 @@ def from_block_metadata(

         output_size_bytes_stats = None
         output_size_bytes = [
-            m.size_bytes for m in block_metas if m.size_bytes is not None
+            m.size_bytes for m in block_stats if m.size_bytes is not None
         ]
         if output_size_bytes:
             output_size_bytes_stats = {
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index a8b63720ac628..537940c5a9b21 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -24,6 +24,7 @@
 )

 import numpy as np
+import pyarrow

 import ray
 from ray._private.utils import _get_pyarrow_version
@@ -31,7 +32,6 @@

 if TYPE_CHECKING:
     import pandas
-    import pyarrow

     from ray.data._internal.compute import ComputeStrategy
     from ray.data._internal.planner.exchange.sort_task_spec import SortKey
@@ -1108,6 +1108,251 @@ def _run_transforming_worker(worker_id: int):
             interrupted_event.set()


+class RetryingContextManager:
+    def __init__(
+        self,
+        f: pyarrow.NativeFile,
+        context: DataContext,
+        max_attempts: int = 10,
+        max_backoff_s: int = 32,
+    ):
+        self._f = f
+        self._data_context = context
+        self._max_attempts = max_attempts
+        self._max_backoff_s = max_backoff_s
+
+    def _retry_operation(self, operation: Callable, description: str):
+        """Execute an operation with retries."""
+        return call_with_retry(
+            operation,
+            description=description,
+            match=self._data_context.retried_io_errors,
+            max_attempts=self._max_attempts,
+            max_backoff_s=self._max_backoff_s,
+        )
+
+    def __enter__(self):
+        return self._retry_operation(self._f.__enter__, "enter file context")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._retry_operation(
+            lambda: self._f.__exit__(exc_type, exc_value, traceback),
+            "exit file context",
+        )
+
+
+class RetryingPyFileSystem(pyarrow.fs.PyFileSystem):
+    def __init__(self, handler: "RetryingPyFileSystemHandler"):
+        if not isinstance(handler, RetryingPyFileSystemHandler):
+            raise ValueError("handler must be a RetryingPyFileSystemHandler")
+        super().__init__(handler)
+
+    @property
+    def data_context(self):
+        return self.handler.data_context
+
+    def unwrap(self):
+        return self.handler.unwrap()
+
+    @classmethod
+    def wrap(
+        cls,
+        fs: "pyarrow.fs.FileSystem",
+        context: DataContext,
+        max_attempts: int = 10,
+        max_backoff_s: int = 32,
+    ):
+        if isinstance(fs, RetryingPyFileSystem):
+            return fs
+        handler = RetryingPyFileSystemHandler(fs, context, max_attempts, max_backoff_s)
+        return cls(handler)
+
+    def __reduce__(self):
+        # Serialization of this class breaks for some reason without this
+        return (self.__class__, (self.handler,))
+
+    @classmethod
+    def 
__setstate__(cls, state): + # Serialization of this class breaks for some reason without this + return cls(*state) + + +class RetryingPyFileSystemHandler(pyarrow.fs.FileSystemHandler): + """Wrapper for filesystem objects that adds retry functionality for file operations. + + This class wraps any filesystem object and adds automatic retries for common + file operations that may fail transiently. + """ + + def __init__( + self, + fs: "pyarrow.fs.FileSystem", + context: DataContext, + max_attempts: int = 10, + max_backoff_s: int = 32, + ): + """Initialize the retrying filesystem wrapper. + + Args: + fs: The underlying filesystem to wrap + context: DataContext for retry settings + max_attempts: Maximum number of retry attempts + max_backoff_s: Maximum backoff time in seconds + """ + assert not isinstance( + fs, RetryingPyFileSystem + ), "Cannot wrap a RetryingPyFileSystem" + self._fs = fs + self._data_context = context + self._max_attempts = max_attempts + self._max_backoff_s = max_backoff_s + + @property + def data_context(self): + return self._data_context + + def _retry_operation(self, operation: Callable, description: str): + """Execute an operation with retries.""" + return call_with_retry( + operation, + description=description, + match=self._data_context.retried_io_errors, + max_attempts=self._max_attempts, + max_backoff_s=self._max_backoff_s, + ) + + def unwrap(self): + return self._fs + + def copy_file(self, src: str, dest: str): + """Copy a file.""" + return self._retry_operation( + lambda: self._fs.copy_file(src, dest), f"copy file from {src} to {dest}" + ) + + def create_dir(self, path: str, recursive: bool): + """Create a directory and subdirectories.""" + return self._retry_operation( + lambda: self._fs.create_dir(path, recursive=recursive), + f"create directory {path}", + ) + + def delete_dir(self, path: str): + """Delete a directory and its contents, recursively.""" + return self._retry_operation( + lambda: self._fs.delete_dir(path), f"delete directory {path}" + ) + + def delete_dir_contents(self, path: str, missing_dir_ok: bool = False): + """Delete a directory's contents, recursively.""" + return self._retry_operation( + lambda: self._fs.delete_dir_contents(path, missing_dir_ok=missing_dir_ok), + f"delete directory contents {path}", + ) + + def delete_file(self, path: str): + """Delete a file.""" + return self._retry_operation( + lambda: self._fs.delete_file(path), f"delete file {path}" + ) + + def delete_root_dir_contents(self): + return self._retry_operation( + lambda: self._fs.delete_dir_contents("/", accept_root_dir=True), + "delete root dir contents", + ) + + def equals(self, other: "pyarrow.fs.FileSystem") -> bool: + """Test if this filesystem equals another.""" + return self._fs.equals(other) + + def get_file_info(self, paths: List[str]): + """Get info for the given files.""" + return self._retry_operation( + lambda: self._fs.get_file_info(paths), + f"get file info for {paths}", + ) + + def get_file_info_selector(self, selector): + return self._retry_operation( + lambda: self._fs.get_file_info(selector), + f"get file info for {selector}", + ) + + def get_type_name(self): + return "RetryingPyFileSystem" + + def move(self, src: str, dest: str): + """Move / rename a file or directory.""" + return self._retry_operation( + lambda: self._fs.move(src, dest), f"move from {src} to {dest}" + ) + + def normalize_path(self, path: str) -> str: + """Normalize filesystem path.""" + return self._retry_operation( + lambda: self._fs.normalize_path(path), f"normalize path {path}" + ) + + 
def open_append_stream(
+        self,
+        path: str,
+        metadata=None,
+    ) -> "pyarrow.NativeFile":
+        """Open an output stream for appending.
+
+        Compression is disabled in this method because it is handled in the
+        PyFileSystem abstract class.
+        """
+        return self._retry_operation(
+            lambda: self._fs.open_append_stream(
+                path,
+                compression=None,
+                metadata=metadata,
+            ),
+            f"open append stream for {path}",
+        )
+
+    def open_input_stream(
+        self,
+        path: str,
+    ) -> "pyarrow.NativeFile":
+        """Open an input stream for sequential reading.
+
+        Compression is disabled in this method because it is handled in the
+        PyFileSystem abstract class.
+        """
+        return self._retry_operation(
+            lambda: self._fs.open_input_stream(path, compression=None),
+            f"open input stream for {path}",
+        )
+
+    def open_output_stream(
+        self,
+        path: str,
+        metadata=None,
+    ) -> "pyarrow.NativeFile":
+        """Open an output stream for sequential writing.
+
+        Compression is disabled in this method because it is handled in the
+        PyFileSystem abstract class.
+        """
+        return self._retry_operation(
+            lambda: self._fs.open_output_stream(
+                path,
+                compression=None,
+                metadata=metadata,
+            ),
+            f"open output stream for {path}",
+        )
+
+    def open_input_file(self, path: str) -> "pyarrow.NativeFile":
+        """Open an input file for random access reading."""
+        return self._retry_operation(
+            lambda: self._fs.open_input_file(path), f"open input file {path}"
+        )
+
+
 def call_with_retry(
     f: Callable[[], Any],
     description: str,
@@ -1133,17 +1378,18 @@ def call_with_retry(
         try:
             return f()
         except Exception as e:
-            is_retryable = match is None or any(
-                [pattern in str(e) for pattern in match]
-            )
+            is_retryable = match is None or any(pattern in str(e) for pattern in match)
             if is_retryable and i + 1 < max_attempts:
                 # Retry with binary expoential backoff with random jitter.
-                backoff = min((2 ** (i + 1)), max_backoff_s) * random.random()
+                backoff = min((2 ** (i + 1)), max_backoff_s) * (random.random())
                 logger.debug(
                     f"Retrying {i+1} attempts to {description} after {backoff} seconds."
                 )
                 time.sleep(backoff)
             else:
+                logger.debug(
+                    f"Did not find a match for {str(e)}. Raising after {i+1} attempts."
+                )
                 raise e from None


@@ -1184,9 +1430,7 @@ def iterate_with_retry(
             yield item
             return
         except Exception as e:
-            is_retryable = match is None or any(
-                [pattern in str(e) for pattern in match]
-            )
+            is_retryable = match is None or any(pattern in str(e) for pattern in match)
             if is_retryable and attempt + 1 < max_attempts:
                 # Retry with binary expoential backoff with random jitter.
                 backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
@@ -1216,6 +1460,36 @@ def convert_bytes_to_human_readable_str(num_bytes: int) -> str:
     return num_bytes_str


+def _validate_rows_per_file_args(
+    *, num_rows_per_file: Optional[int] = None, min_rows_per_file: Optional[int] = None
+) -> Optional[int]:
+    """Helper method to validate and handle rows per file arguments.
+
+    Args:
+        num_rows_per_file: Deprecated parameter for number of rows per file
+        min_rows_per_file: New parameter for minimum rows per file
+
+    Returns:
+        The effective min_rows_per_file value to use
+    """
+    if num_rows_per_file is not None:
+        import warnings
+
+        warnings.warn(
+            "`num_rows_per_file` is deprecated and will be removed in a future release. "
+            "Use `min_rows_per_file` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if min_rows_per_file is not None:
+            raise ValueError(
+                "Cannot specify both `num_rows_per_file` and `min_rows_per_file`. 
" + "Use `min_rows_per_file` as `num_rows_per_file` is deprecated." + ) + return num_rows_per_file + return min_rows_per_file + + def is_nan(value): try: return isinstance(value, float) and np.isnan(value) diff --git a/python/ray/data/block.py b/python/ray/data/block.py index fc8e78b0736e4..a0d09b12569e2 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -2,7 +2,7 @@ import logging import os import time -from dataclasses import dataclass +from dataclasses import dataclass, asdict, fields from enum import Enum from typing import ( TYPE_CHECKING, @@ -132,6 +132,11 @@ def _apply_batch_size( return given_batch_size +@DeveloperAPI +def to_stats(metas: List["BlockMetadata"]) -> List["BlockStats"]: + return [m.to_stats() for m in metas] + + @DeveloperAPI class BlockExecStats: """Execution stats for this block. @@ -204,28 +209,47 @@ def build(self) -> "BlockExecStats": @DeveloperAPI @dataclass -class BlockMetadata: - """Metadata about the block.""" +class BlockStats: + """Statistics about the block produced""" #: The number of rows contained in this block, or None. num_rows: Optional[int] #: The approximate size in bytes of this block, or None. size_bytes: Optional[int] + #: Execution stats for this block. + exec_stats: Optional[BlockExecStats] + + def __post_init__(self): + if self.size_bytes is not None: + # Require size_bytes to be int, ray.util.metrics objects + # will not take other types like numpy.int64 + assert isinstance(self.size_bytes, int) + + +_BLOCK_STATS_FIELD_NAMES = {f.name for f in fields(BlockStats)} + + +@DeveloperAPI +@dataclass +class BlockMetadata(BlockStats): + """Metadata about the block.""" + #: The pyarrow schema or types of the block elements, or None. schema: Optional[Union[type, "pyarrow.lib.Schema"]] #: The list of file paths used to generate this block, or #: the empty list if indeterminate. input_files: Optional[List[str]] - #: Execution stats for this block. - exec_stats: Optional[BlockExecStats] + + def to_stats(self): + return BlockStats( + **{k: v for k, v in asdict(self).items() if k in _BLOCK_STATS_FIELD_NAMES} + ) def __post_init__(self): + super().__post_init__() + if self.input_files is None: self.input_files = [] - if self.size_bytes is not None: - # Require size_bytes to be int, ray.util.metrics objects - # will not take other types like numpy.int64 - assert isinstance(self.size_bytes, int) @DeveloperAPI diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 6e79541397123..60deb88128adf 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -82,7 +82,12 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.split import _get_num_rows, _split_at_indices from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, StatsManager -from ray.data._internal.util import AllToAllAPI, ConsumptionAPI, get_compute_strategy +from ray.data._internal.util import ( + AllToAllAPI, + ConsumptionAPI, + _validate_rows_per_file_args, + get_compute_strategy, +) from ray.data.aggregate import AggregateFn from ray.data.block import ( VALID_BATCH_FORMATS, @@ -1278,8 +1283,7 @@ def filter( # TODO: (srinathk) bind the expression to the actual schema. 
# If fn is a string, convert it to a pyarrow.dataset.Expression # Initialize ExpressionEvaluator with valid columns, if available - evaluator = ExpressionEvaluator() - resolved_expr = evaluator.get_filters(expression=expr) + resolved_expr = ExpressionEvaluator.get_filters(expression=expr) compute = TaskPoolStrategy(size=concurrency) else: @@ -2984,9 +2988,10 @@ def write_parquet( arrow_open_stream_args: Optional[Dict[str, Any]] = None, filename_provider: Optional[FilenameProvider] = None, arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, **arrow_parquet_args, ) -> None: """Writes the :class:`~ray.data.Dataset` to parquet files under the provided ``path``. @@ -3042,17 +3047,18 @@ def write_parquet( instead of ``arrow_parquet_args`` if any of your write arguments can't pickled, or if you'd like to lazily resolve the write arguments for each dataset block. - num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. arrow_parquet_args: Options to pass to `pyarrow.parquet.write_table() None: """Writes the :class:`~ray.data.Dataset` to JSON and JSONL files. @@ -3162,17 +3173,18 @@ def write_json( instead of ``pandas_json_args`` if any of your write arguments can't be pickled, or if you'd like to lazily resolve the write arguments for each dataset block. - num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. 
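The same docstring rewrite repeats for each file-based writer below. In user code the rename looks like this (a sketch; the output paths are assumptions):

```python
# Usage sketch of the renamed row-count hint on the write_* APIs.
import ray

ds = ray.data.range(100)

# Preferred spelling going forward:
ds.write_parquet("/tmp/parquet_out", min_rows_per_file=50)

# Deprecated spelling is still accepted but warns via the shim above:
ds.write_json("/tmp/json_out", num_rows_per_file=50)
```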
concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. + num_rows_per_file: Deprecated. Use ``min_rows_per_file`` instead. pandas_json_args: These args are passed to `pandas.DataFrame.to_json() `_, @@ -3183,11 +3195,15 @@ def write_json( if pandas_json_args_fn is None: pandas_json_args_fn = lambda: {} # noqa: E731 + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + datasink = JSONDatasink( path, pandas_json_args_fn=pandas_json_args_fn, pandas_json_args=pandas_json_args, - num_rows_per_file=num_rows_per_file, + min_rows_per_file=effective_min_rows, filesystem=filesystem, try_create_dir=try_create_dir, open_stream_args=arrow_open_stream_args, @@ -3283,9 +3299,10 @@ def write_csv( arrow_open_stream_args: Optional[Dict[str, Any]] = None, filename_provider: Optional[FilenameProvider] = None, arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, **arrow_csv_args, ) -> None: """Writes the :class:`~ray.data.Dataset` to CSV files. @@ -3348,17 +3365,18 @@ def write_csv( Use this argument instead of ``arrow_csv_args`` if any of your write arguments cannot be pickled, or if you'd like to lazily resolve the write arguments for each dataset block. - num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. 
arrow_csv_args: Options to pass to `pyarrow.write.write_csv `_ @@ -3367,11 +3385,15 @@ def write_csv( if arrow_csv_args_fn is None: arrow_csv_args_fn = lambda: {} # noqa: E731 + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + datasink = CSVDatasink( path, arrow_csv_args_fn=arrow_csv_args_fn, arrow_csv_args=arrow_csv_args, - num_rows_per_file=num_rows_per_file, + min_rows_per_file=effective_min_rows, filesystem=filesystem, try_create_dir=try_create_dir, open_stream_args=arrow_open_stream_args, @@ -3395,9 +3417,10 @@ def write_tfrecords( try_create_dir: bool = True, arrow_open_stream_args: Optional[Dict[str, Any]] = None, filename_provider: Optional[FilenameProvider] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, ) -> None: """Write the :class:`~ray.data.Dataset` to TFRecord files. @@ -3453,23 +3476,27 @@ def write_tfrecords( filename_provider: A :class:`~ray.data.datasource.FilenameProvider` implementation. Use this parameter to customize what your filenames look like. - num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. - + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. """ + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + datasink = TFRecordDatasink( path=path, tf_schema=tf_schema, - num_rows_per_file=num_rows_per_file, + min_rows_per_file=effective_min_rows, filesystem=filesystem, try_create_dir=try_create_dir, open_stream_args=arrow_open_stream_args, @@ -3492,10 +3519,11 @@ def write_webdataset( try_create_dir: bool = True, arrow_open_stream_args: Optional[Dict[str, Any]] = None, filename_provider: Optional[FilenameProvider] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, ray_remote_args: Dict[str, Any] = None, encoder: Optional[Union[bool, str, callable, list]] = True, concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, ) -> None: """Writes the dataset to `WebDataset `_ files. @@ -3540,23 +3568,27 @@ def write_webdataset( filename_provider: A :class:`~ray.data.datasource.FilenameProvider` implementation. Use this parameter to customize what your filenames look like. 
- num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. - + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. """ + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) + datasink = WebDatasetDatasink( path, encoder=encoder, - num_rows_per_file=num_rows_per_file, + min_rows_per_file=effective_min_rows, filesystem=filesystem, try_create_dir=try_create_dir, open_stream_args=arrow_open_stream_args, @@ -3580,9 +3612,10 @@ def write_numpy( try_create_dir: bool = True, arrow_open_stream_args: Optional[Dict[str, Any]] = None, filename_provider: Optional[FilenameProvider] = None, - num_rows_per_file: Optional[int] = None, + min_rows_per_file: Optional[int] = None, ray_remote_args: Dict[str, Any] = None, concurrency: Optional[int] = None, + num_rows_per_file: Optional[int] = None, ) -> None: """Writes a column of the :class:`~ray.data.Dataset` to .npy files. @@ -3630,23 +3663,27 @@ def write_numpy( filename_provider: A :class:`~ray.data.datasource.FilenameProvider` implementation. Use this parameter to customize what your filenames look like. - num_rows_per_file: [Experimental] The target number of rows to write to each - file. If ``None``, Ray Data writes a system-chosen number of rows to - each file. The specified value is a hint, not a strict limit. Ray Data - might write more or fewer rows to each file. In specific, if the number - of rows per block is larger than the specified value, Ray Data writes - the number of rows per block to each file. + min_rows_per_file: [Experimental] The target minimum number of rows to write + to each file. If ``None``, Ray Data writes a system-chosen number of + rows to each file. If the number of rows per block is larger than the + specified value, Ray Data writes the number of rows per block to each file. + The specified value is a hint, not a strict limit. Ray Data + might write more or fewer rows to each file. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this to control number of tasks to run concurrently. This doesn't change the total number of tasks run. By default, concurrency is dynamically decided based on the available resources. + num_rows_per_file: [Deprecated] Use min_rows_per_file instead. 
""" + effective_min_rows = _validate_rows_per_file_args( + num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file + ) datasink = NumpyDatasink( path, column, - num_rows_per_file=num_rows_per_file, + min_rows_per_file=effective_min_rows, filesystem=filesystem, try_create_dir=try_create_dir, open_stream_args=arrow_open_stream_args, diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index 904b6a54f1fc0..fee67910ba98a 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -103,7 +103,7 @@ def supports_distributed_writes(self) -> bool: return True @property - def num_rows_per_write(self) -> Optional[int]: + def min_rows_per_write(self) -> Optional[int]: """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call. If ``None``, Ray Data passes a system-chosen number of rows. diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 4320cd64ff0c8..91f5b33b71777 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -18,9 +18,11 @@ import ray from ray.data._internal.util import ( + RetryingContextManager, + RetryingPyFileSystem, _check_pyarrow_version, _is_local_scheme, - call_with_retry, + iterate_with_retry, make_async_gen, ) from ray.data.block import Block, BlockAccessor @@ -55,12 +57,6 @@ # 16 file size fetches from S3 takes ~1.5 seconds with Arrow's S3FileSystem. PATHS_PER_FILE_SIZE_FETCH_TASK = 16 -# The max retry backoff in seconds for opening file. -OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS = 32 - -# The max number of attempts for opening file. -OPEN_FILE_MAX_ATTEMPTS = 10 - @DeveloperAPI @dataclass @@ -137,6 +133,7 @@ def __init__( ) self._schema = schema + self._data_context = DataContext.get_current() self._open_stream_args = open_stream_args self._meta_provider = meta_provider self._partition_filter = partition_filter @@ -144,6 +141,9 @@ def __init__( self._ignore_missing_paths = ignore_missing_paths self._include_paths = include_paths paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem) + self._filesystem = RetryingPyFileSystem.wrap( + self._filesystem, context=self._data_context + ) paths, file_sizes = map( list, zip( @@ -214,7 +214,6 @@ def estimate_inmemory_data_size(self) -> Optional[int]: def get_read_tasks(self, parallelism: int) -> List[ReadTask]: import numpy as np - ctx = DataContext.get_current() open_stream_args = self._open_stream_args partitioning = self._partitioning @@ -229,34 +228,33 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: ] paths, file_sizes = list(map(list, zip(*shuffled_files_metadata))) - read_stream = self._read_stream filesystem = _wrap_s3_serialization_workaround(self._filesystem) if open_stream_args is None: open_stream_args = {} - open_input_source = self._open_input_source - def read_files( read_paths: Iterable[str], ) -> Iterable[Block]: nonlocal filesystem, open_stream_args, partitioning - DataContext._set_current(ctx) fs = _unwrap_s3_serialization_workaround(filesystem) + for read_path in read_paths: partitions: Dict[str, str] = {} if partitioning is not None: parse = PathPartitionParser(partitioning) partitions = parse(read_path) - with _open_file_with_retry( - read_path, - lambda read_path=read_path: open_input_source( - fs, read_path, **open_stream_args - ), + with RetryingContextManager( + self._open_input_source(fs, read_path, **open_stream_args), + 
context=self._data_context, ) as f: - for block in read_stream(f, read_path): + for block in iterate_with_retry( + lambda: self._read_stream(f, read_path), + description="read stream iteratively", + match=self._data_context.retried_io_errors, + ): if partitions: block = _add_partitions(block, partitions) if self._include_paths: @@ -272,7 +270,7 @@ def read_task_fn(): # TODO: We should refactor the code so that we can get the results in # order even when using multiple threads. - if ctx.execution_options.preserve_order: + if self._data_context.execution_options.preserve_order: num_threads = 0 if num_threads > 0: @@ -322,15 +320,15 @@ def read_task_fn(): def _open_input_source( self, - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", path: str, **open_args, ) -> "pyarrow.NativeFile": """Opens a source path for reading and returns the associated Arrow NativeFile. The default implementation opens the source path as a sequential input stream, - using ctx.streaming_read_buffer_size as the buffer size if none is given by the - caller. + using self._data_context.streaming_read_buffer_size as the buffer size if none + is given by the caller. Implementations that do not support streaming reads (e.g. that require random access) should override this method. @@ -338,8 +336,6 @@ def _open_input_source( import pyarrow as pa from pyarrow.fs import HadoopFileSystem - ctx = DataContext.get_current() - compression = open_args.get("compression", None) if compression is None: try: @@ -359,7 +355,7 @@ def _open_input_source( buffer_size = open_args.pop("buffer_size", None) if buffer_size is None: - buffer_size = ctx.streaming_read_buffer_size + buffer_size = self._data_context.streaming_read_buffer_size if compression == "snappy": # Arrow doesn't support streaming Snappy decompression since the canonical @@ -369,19 +365,13 @@ def _open_input_source( else: open_args["compression"] = compression - file = call_with_retry( - lambda: filesystem.open_input_stream( - path, buffer_size=buffer_size, **open_args - ), - description=f"open file {path}", - match=ctx.retried_io_errors, - ) + file = filesystem.open_input_stream(path, buffer_size=buffer_size, **open_args) if compression == "snappy": import snappy stream = io.BytesIO() - if isinstance(filesystem, HadoopFileSystem): + if isinstance(filesystem.unwrap(), HadoopFileSystem): snappy.hadoop_snappy.stream_decompress(src=file, dst=stream) else: snappy.stream_decompress(src=file, dst=stream) @@ -483,23 +473,43 @@ def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"): import pyarrow as pa import pyarrow.fs - if isinstance(filesystem, pa.fs.S3FileSystem): - return _S3FileSystemWrapper(filesystem) + wrap_retries = False + fs_to_be_wrapped = filesystem # Only unwrap for S3FileSystemWrapper + context = None + if isinstance(fs_to_be_wrapped, RetryingPyFileSystem): + wrap_retries = True + context = fs_to_be_wrapped.data_context + fs_to_be_wrapped = fs_to_be_wrapped.unwrap() + if isinstance(fs_to_be_wrapped, pa.fs.S3FileSystem): + return _S3FileSystemWrapper( + fs_to_be_wrapped, wrap_retries=wrap_retries, context=context + ) return filesystem def _unwrap_s3_serialization_workaround( - filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"] + filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"], + context: Optional[DataContext] = None, ): if isinstance(filesystem, _S3FileSystemWrapper): - return filesystem.unwrap() - else: - return filesystem + wrap_retries = filesystem._wrap_retries + context = 
filesystem._context + filesystem = filesystem.unwrap() + if wrap_retries: + filesystem = RetryingPyFileSystem.wrap(filesystem, context=context) + return filesystem class _S3FileSystemWrapper: - def __init__(self, fs: "pyarrow.fs.S3FileSystem"): + def __init__( + self, + fs: "pyarrow.fs.S3FileSystem", + wrap_retries: bool = False, + context: Optional[DataContext] = None, + ): self._fs = fs + self._wrap_retries = wrap_retries + self._context = context def unwrap(self): return self._fs @@ -538,30 +548,6 @@ def _resolve_kwargs( return kwargs -def _open_file_with_retry( - file_path: str, - open_file: Callable[[], "pyarrow.NativeFile"], -) -> "pyarrow.NativeFile": - """Open file with an exponential backoff retry strategy. - - This is to avoid transient task failure with remote storage (such as S3), - when the remote storage throttles the requests. - """ - if OPEN_FILE_MAX_ATTEMPTS < 1: - raise ValueError( - "OPEN_FILE_MAX_ATTEMPTS cannot be negative or 0. Get: " - f"{OPEN_FILE_MAX_ATTEMPTS}" - ) - - return call_with_retry( - open_file, - description=f"open file {file_path}", - match=DataContext.get_current().retried_io_errors, - max_attempts=OPEN_FILE_MAX_ATTEMPTS, - max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS, - ) - - def _validate_shuffle_arg(shuffle: Optional[str]) -> None: if not ( shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig) diff --git a/python/ray/data/datasource/file_datasink.py b/python/ray/data/datasource/file_datasink.py index ad4e1a16aed3b..5ea082a3026e5 100644 --- a/python/ray/data/datasource/file_datasink.py +++ b/python/ray/data/datasource/file_datasink.py @@ -6,7 +6,11 @@ from ray._private.utils import _add_creatable_buckets_param_if_s3_uri from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.execution.interfaces import TaskContext -from ray.data._internal.util import _is_local_scheme, call_with_retry +from ray.data._internal.util import ( + RetryingPyFileSystem, + _is_local_scheme, + call_with_retry, +) from ray.data.block import Block, BlockAccessor from ray.data.context import DataContext from ray.data.datasource.datasink import Datasink, WriteResult @@ -23,10 +27,6 @@ logger = logging.getLogger(__name__) -WRITE_FILE_MAX_ATTEMPTS = 10 -WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS = 32 - - class _FileDatasink(Datasink[None]): def __init__( self, @@ -62,8 +62,12 @@ def __init__( dataset_uuid=dataset_uuid, file_format=file_format ) + self._data_context = DataContext.get_current() self.unresolved_path = path paths, self.filesystem = _resolve_paths_and_filesystem(path, filesystem) + self.filesystem = RetryingPyFileSystem.wrap( + self.filesystem, context=self._data_context + ) assert len(paths) == 1, len(paths) self.path = paths[0] @@ -101,9 +105,7 @@ def _create_dir(self, dest) -> bool: # should not. 
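The datasink code below now calls `call_with_retry` with only `DataContext.retried_io_errors` as the match list, relying on the helper's defaults instead of the removed `WRITE_FILE_*` constants. A behavioral sketch of the helper (the error text mirrors the retryable message used by the tests later in this diff):

```python
# Sketch: call_with_retry re-invokes the callable until it succeeds, the
# error no longer matches, or the attempts run out.
from ray.data._internal.util import call_with_retry
from ray.data.context import DataContext

attempts = {"n": 0}

def flaky_op():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise OSError("AWS Error INTERNAL_FAILURE")  # retryable by default
    return "ok"

result = call_with_retry(
    flaky_op,
    description="demo operation",
    match=DataContext.get_current().retried_io_errors,
)
assert result == "ok" and attempts["n"] == 3
```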
parsed_uri = urlparse(dest) is_s3_uri = parsed_uri.scheme == "s3" - skip_create_dir_for_s3 = ( - is_s3_uri and not DataContext.get_current().s3_try_create_dir - ) + skip_create_dir_for_s3 = is_s3_uri and not self._data_context.s3_try_create_dir if self.try_create_dir and not skip_create_dir_for_s3: if self.filesystem.get_file_info(dest).type is FileType.NotFound: @@ -190,20 +192,16 @@ def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): row, ctx.task_idx, block_index, row_index ) write_path = posixpath.join(self.path, filename) + logger.debug(f"Writing {write_path} file.") - def write_row_to_path(row, write_path): + def write_row_to_path(): with self.open_output_stream(write_path) as file: self.write_row_to_file(row, file) - logger.debug(f"Writing {write_path} file.") call_with_retry( - lambda row=row, write_path=write_path: write_row_to_path( - row, write_path - ), + write_row_to_path, description=f"write '{write_path}'", - match=DataContext.get_current().retried_io_errors, - max_attempts=WRITE_FILE_MAX_ATTEMPTS, - max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, + match=self._data_context.retried_io_errors, ) @@ -227,11 +225,11 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): """ # noqa: E501 def __init__( - self, path, *, num_rows_per_file: Optional[int] = None, **file_datasink_kwargs + self, path, *, min_rows_per_file: Optional[int] = None, **file_datasink_kwargs ): super().__init__(path, **file_datasink_kwargs) - self._num_rows_per_file = num_rows_per_file + self._min_rows_per_file = min_rows_per_file def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): """Write a block of data to a file. @@ -256,11 +254,9 @@ def write_block_to_path(): call_with_retry( write_block_to_path, description=f"write '{write_path}'", - match=DataContext.get_current().retried_io_errors, - max_attempts=WRITE_FILE_MAX_ATTEMPTS, - max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, + match=self._data_context.retried_io_errors, ) @property - def num_rows_per_write(self) -> Optional[int]: - return self._num_rows_per_file + def min_rows_per_write(self) -> Optional[int]: + return self._min_rows_per_file diff --git a/python/ray/data/datasource/file_meta_provider.py b/python/ray/data/datasource/file_meta_provider.py index c6654e9e2708f..d9113dbe1a5da 100644 --- a/python/ray/data/datasource/file_meta_provider.py +++ b/python/ray/data/datasource/file_meta_provider.py @@ -16,10 +16,9 @@ import numpy as np -import ray from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn -from ray.data._internal.util import call_with_retry +from ray.data._internal.util import RetryingPyFileSystem from ray.data.block import BlockMetadata from ray.data.datasource.partitioning import Partitioning from ray.util.annotations import DeveloperAPI @@ -111,7 +110,7 @@ def _get_block_metadata( def expand_paths( self, paths: List[str], - filesystem: Optional["pyarrow.fs.FileSystem"], + filesystem: Optional["RetryingPyFileSystem"], partitioning: Optional[Partitioning] = None, ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: @@ -172,7 +171,7 @@ def _get_block_metadata( def expand_paths( self, paths: List[str], - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", partitioning: Optional[Partitioning] = None, ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: @@ -197,7 +196,7 @@ class FastFileMetadataProvider(DefaultFileMetadataProvider): def 
expand_paths( self, paths: List[str], - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", partitioning: Optional[Partitioning] = None, ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: @@ -254,7 +253,7 @@ def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str: def _expand_paths( paths: List[str], - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", partitioning: Optional[Partitioning], ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: @@ -274,10 +273,14 @@ def _expand_paths( # provided paths on the client; this should be a single file info request. # 3. If more than threshold requests required, parallelize them via Ray tasks. # 1. Small # of paths case. + is_local = isinstance(filesystem, LocalFileSystem) + if isinstance(filesystem, RetryingPyFileSystem): + is_local = isinstance(filesystem.unwrap(), LocalFileSystem) + if ( len(paths) < FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD # Local file systems are very fast to hit. - or isinstance(filesystem, LocalFileSystem) + or is_local ): yield from _get_file_infos_serial(paths, filesystem, ignore_missing_paths) else: @@ -302,7 +305,7 @@ def _expand_paths( def _get_file_infos_serial( paths: List[str], - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: for path in paths: @@ -349,7 +352,7 @@ def _get_file_infos_common_path_prefix( def _get_file_infos_parallel( paths: List[str], - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: from ray.data.datasource.file_based_datasource import ( @@ -413,19 +416,14 @@ def _fetch_metadata_parallel( def _get_file_infos( - path: str, filesystem: "pyarrow.fs.FileSystem", ignore_missing_path: bool = False + path: str, filesystem: "RetryingPyFileSystem", ignore_missing_path: bool = False ) -> List[Tuple[str, int]]: """Get the file info for all files at or under the provided path.""" from pyarrow.fs import FileType file_infos = [] try: - ctx = ray.data.DataContext.get_current() - file_info = call_with_retry( - lambda: filesystem.get_file_info(path), - description="get file info", - match=ctx.retried_io_errors, - ) + file_info = filesystem.get_file_info(path) except OSError as e: _handle_read_os_error(e, path) if file_info.type == FileType.Directory: @@ -443,7 +441,7 @@ def _get_file_infos( def _expand_directory( path: str, - filesystem: "pyarrow.fs.FileSystem", + filesystem: "RetryingPyFileSystem", exclude_prefixes: Optional[List[str]] = None, ignore_missing_path: bool = False, ) -> List[Tuple[str, int]]: diff --git a/python/ray/data/examples/data/different-extensions/data.csv b/python/ray/data/examples/data/different-extensions/data.csv index 301800ec1f13c..2ba1bd9e02241 100644 --- a/python/ray/data/examples/data/different-extensions/data.csv +++ b/python/ray/data/examples/data/different-extensions/data.csv @@ -1,2 +1,2 @@ a,b -0,1 \ No newline at end of file +0,1 diff --git a/python/ray/data/examples/data/simple.txt b/python/ray/data/examples/data/simple.txt index 87dd8a3807c0f..e4a2e58b19e47 100644 --- a/python/ray/data/examples/data/simple.txt +++ b/python/ray/data/examples/data/simple.txt @@ -1,2 +1,2 @@ hello world -attention is all you need! \ No newline at end of file +attention is all you need! 
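Stepping back from the example-data fixups for a moment: the read path earlier in this diff replaced the one-shot `_open_file_with_retry` with retries around both opening (`RetryingContextManager`) and block iteration (`iterate_with_retry`). A sketch of the iteration helper's contract, under the assumption that the factory is re-invoked on a matching failure:

```python
# Sketch: iterate_with_retry restarts the iterable factory when an
# error matching `match` interrupts iteration.
from ray.data._internal.util import iterate_with_retry

calls = {"n": 0}

def make_iter():
    calls["n"] += 1
    if calls["n"] == 1:
        raise OSError("AWS Error INTERNAL_FAILURE")
    return iter(range(3))

items = list(
    iterate_with_retry(make_iter, description="demo iteration", match=["AWS Error"])
)
assert items == [0, 1, 2] and calls["n"] == 2
```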
diff --git a/python/ray/data/examples/data/sms_spam_collection_subset.txt b/python/ray/data/examples/data/sms_spam_collection_subset.txt index 15ac29800b490..5b8f65132bb35 100644 --- a/python/ray/data/examples/data/sms_spam_collection_subset.txt +++ b/python/ray/data/examples/data/sms_spam_collection_subset.txt @@ -7,4 +7,4 @@ spam FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like ham Even my brother is not like to speak with me. They treat me like aids patent. ham As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune spam WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only. -spam Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030 \ No newline at end of file +spam Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030 diff --git a/python/ray/data/examples/data/year=2022/month=09/sales.csv b/python/ray/data/examples/data/year=2022/month=09/sales.csv index d21ebba956636..7a3a64abea608 100644 --- a/python/ray/data/examples/data/year=2022/month=09/sales.csv +++ b/python/ray/data/examples/data/year=2022/month=09/sales.csv @@ -1,2 +1,2 @@ order_number,quantity -10107,30 \ No newline at end of file +10107,30 diff --git a/python/ray/data/examples/data/year=2022/month=09/sales.json b/python/ray/data/examples/data/year=2022/month=09/sales.json index 0529c0bea5374..22e8905162c91 100644 --- a/python/ray/data/examples/data/year=2022/month=09/sales.json +++ b/python/ray/data/examples/data/year=2022/month=09/sales.json @@ -1,4 +1,4 @@ { "order_number": 10107, "quantity": 30 -} \ No newline at end of file +} diff --git a/python/ray/data/llm.py b/python/ray/data/llm.py new file mode 100644 index 0000000000000..a0dba5bb9f831 --- /dev/null +++ b/python/ray/data/llm.py @@ -0,0 +1,79 @@ +from ray.llm._internal.batch.processor import ( + ProcessorConfig as _ProcessorConfig, + Processor, + HttpRequestProcessorConfig as _HttpRequestProcessorConfig, +) +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ProcessorConfig(_ProcessorConfig): + """The processor configuration.""" + + pass + + +@PublicAPI(stability="alpha") +class HttpRequestProcessorConfig(_HttpRequestProcessorConfig): + """The configuration for the HTTP request processor. + + Examples: + .. testcode:: + :skipif: True + + import ray + from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor + + config = HttpRequestProcessorConfig( + url="https://api.openai.com/v1/chat/completions", + headers={"Authorization": "Bearer sk-..."}, + concurrency=1, + ) + processor = build_llm_processor( + config, + preprocess=lambda row: dict( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a calculator"}, + {"role": "user", "content": f"{row['id']} ** 3 = ?"}, + ], + temperature=0.3, + max_tokens=20, + ), + postprocess=lambda row: dict( + resp=row["choices"][0]["message"]["content"], + ), + ) + + ds = ray.data.range(10) + ds = processor(ds) + for row in ds.take_all(): + print(row) + """ + + pass + + +@PublicAPI(stability="alpha") +def build_llm_processor(config: ProcessorConfig, **kwargs) -> Processor: + """Build a LLM processor using the given config. 
+ + Args: + config: The processor config. + **kwargs: Additional keyword arguments to pass to the processor. + See `Processor` for argument details. + + Returns: + The built processor. + """ + from ray.llm._internal.batch.processor import ProcessorBuilder + + return ProcessorBuilder.build(config, **kwargs) + + +__all__ = [ + "ProcessorConfig", + "Processor", + "HttpRequestProcessorConfig", + "build_llm_processor", +] diff --git a/python/ray/data/preprocessors/encoder.py b/python/ray/data/preprocessors/encoder.py index 8bd6af80f6b19..d907871775968 100644 --- a/python/ray/data/preprocessors/encoder.py +++ b/python/ray/data/preprocessors/encoder.py @@ -608,7 +608,8 @@ def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]: for batch in value_counts.iter_batches(batch_size=None): for col, counters in batch.items(): for counter in counters: - final_counters[col] += counter + counter = {k: v for k, v in counter.items() if v is not None} + final_counters[col] += Counter(counter) # Inspect if there is any NA values. for col in columns: diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 8322805da0836..ff49356741766 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -20,6 +20,7 @@ import ray from ray._private.auto_init_hook import wrap_auto_init from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray +from ray.data._internal.datasource.audio_datasource import AudioDatasource from ray.data._internal.datasource.avro_datasource import AvroDatasource from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource from ray.data._internal.datasource.binary_datasource import BinaryDatasource @@ -45,6 +46,7 @@ from ray.data._internal.datasource.text_datasource import TextDatasource from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource from ray.data._internal.datasource.torch_datasource import TorchDatasource +from ray.data._internal.datasource.video_datasource import VideoDatasource from ray.data._internal.datasource.webdataset_datasource import WebDatasetDatasource from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.logical.operators.from_operators import ( @@ -398,6 +400,7 @@ def read_datasource( # TODO(hchen/chengsu): Remove the duplicated get_read_tasks call here after # removing LazyBlockList code path. read_tasks = datasource_or_legacy_reader.get_read_tasks(requested_parallelism) + import uuid stats = DatasetStats( @@ -423,6 +426,177 @@ def read_datasource( ) +@PublicAPI(stability="alpha") +def read_audio( + paths: Union[str, List[str]], + *, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + arrow_open_stream_args: Optional[Dict[str, Any]] = None, + partition_filter: Optional[PathPartitionFilter] = None, + partitioning: Optional[Partitioning] = None, + include_paths: bool = False, + ignore_missing_paths: bool = False, + file_extensions: Optional[List[str]] = AudioDatasource._FILE_EXTENSIONS, + shuffle: Union[Literal["files"], None] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +): + """Creates a :class:`~ray.data.Dataset` from audio files. 
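Returning briefly to the `get_pd_value_counts` fix above (the audio reader's docstring continues below): filtering `None` values before building a `Counter` matters because `Counter` arithmetic requires numeric counts. A minimal reproduction of the failure mode the filter guards against:

```python
# Sketch of the failure the encoder fix avoids: a None count in the
# incoming mapping breaks Counter addition unless it is filtered first.
from collections import Counter

final_counter = Counter({"red": 2})
incoming = {"red": 1, None: None}  # e.g. a missing-category marker

filtered = {k: v for k, v in incoming.items() if v is not None}
final_counter += Counter(filtered)
assert final_counter == Counter({"red": 3})
```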
+
+    Examples:
+        >>> import ray
+        >>> path = "s3://anonymous@air-example-data-2/6G-audio-data-LibriSpeech-train-clean-100-flac/train-clean-100/5022/29411/5022-29411-0000.flac"
+        >>> ds = ray.data.read_audio(path)
+        >>> ds.schema()
+        Column       Type
+        ------       ----
+        amplitude    numpy.ndarray(shape=(1, 191760), dtype=float)
+        sample_rate  int64
+
+    Args:
+        paths: A single file or directory, or a list of file or directory paths.
+            A list of paths can contain both files and directories.
+        filesystem: The pyarrow filesystem
+            implementation to read from. These filesystems are specified in the
+            `pyarrow docs `_. Specify this parameter if
+            you need to provide specific configurations to the filesystem. By default,
+            the filesystem is automatically selected based on the scheme of the paths.
+            For example, if the path begins with ``s3://``, the `S3FileSystem` is used.
+        arrow_open_stream_args: kwargs passed to
+            `pyarrow.fs.FileSystem.open_input_file `_.
+            when opening input files to read.
+        partition_filter: A
+            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use
+            with a custom callback to read only selected partitions of a dataset.
+        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
+            that describes how paths are organized. Defaults to ``None``.
+        include_paths: If ``True``, include the path to each audio file. File paths are
+            stored in the ``'path'`` column.
+        ignore_missing_paths: If True, ignores any file/directory paths in ``paths``
+            that are not found. Defaults to False.
+        file_extensions: A list of file extensions to filter files by.
+        concurrency: The maximum number of Ray tasks to run concurrently. Set this
+            to control number of tasks to run concurrently. This doesn't change the
+            total number of tasks run or the total number of output blocks. By default,
+            concurrency is dynamically decided based on the available resources.
+        override_num_blocks: Override the number of output blocks from all read tasks.
+            By default, the number of output blocks is dynamically decided based on
+            input data size and available resources. You shouldn't manually set this
+            value in most cases.
+        ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks.
+
+    Returns:
+        A :class:`~ray.data.Dataset` containing audio amplitudes and associated
+        metadata.
+    """  # noqa: E501
+    datasource = AudioDatasource(
+        paths,
+        filesystem=filesystem,
+        open_stream_args=arrow_open_stream_args,
+        meta_provider=DefaultFileMetadataProvider(),
+        partition_filter=partition_filter,
+        partitioning=partitioning,
+        ignore_missing_paths=ignore_missing_paths,
+        shuffle=shuffle,
+        include_paths=include_paths,
+        file_extensions=file_extensions,
+    )
+    return read_datasource(
+        datasource,
+        ray_remote_args=ray_remote_args,
+        concurrency=concurrency,
+        override_num_blocks=override_num_blocks,
+    )
+
+
+@PublicAPI(stability="alpha")
+def read_videos(
+    paths: Union[str, List[str]],
+    *,
+    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
+    partition_filter: Optional[PathPartitionFilter] = None,
+    partitioning: Optional[Partitioning] = None,
+    include_paths: bool = False,
+    ignore_missing_paths: bool = False,
+    file_extensions: Optional[List[str]] = VideoDatasource._FILE_EXTENSIONS,
+    shuffle: Union[Literal["files"], None] = None,
+    concurrency: Optional[int] = None,
+    override_num_blocks: Optional[int] = None,
+    ray_remote_args: Optional[Dict[str, Any]] = None,
+):
+    """Creates a :class:`~ray.data.Dataset` from video files.
+
+    Each row in the resulting dataset represents a video frame.
+
+    Examples:
+        >>> import ray
+        >>> path = "s3://anonymous@ray-example-data/basketball.mp4"
+        >>> ds = ray.data.read_videos(path)
+        >>> ds.schema()
+        Column       Type
+        ------       ----
+        frame        numpy.ndarray(shape=(720, 1280, 3), dtype=uint8)
+        frame_index  int64
+
+    Args:
+        paths: A single file or directory, or a list of file or directory paths.
+            A list of paths can contain both files and directories.
+        filesystem: The pyarrow filesystem
+            implementation to read from. These filesystems are specified in the
+            `pyarrow docs `_. Specify this parameter if
+            you need to provide specific configurations to the filesystem. By default,
+            the filesystem is automatically selected based on the scheme of the paths.
+            For example, if the path begins with ``s3://``, the `S3FileSystem` is used.
+        arrow_open_stream_args: kwargs passed to
+            `pyarrow.fs.FileSystem.open_input_file `_.
+            when opening input files to read.
+        partition_filter: A
+            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use
+            with a custom callback to read only selected partitions of a dataset.
+        partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
+            that describes how paths are organized. Defaults to ``None``.
+        include_paths: If ``True``, include the path to each video. File paths are
+            stored in the ``'path'`` column.
+        ignore_missing_paths: If True, ignores any file/directory paths in ``paths``
+            that are not found. Defaults to False.
+        file_extensions: A list of file extensions to filter files by.
+        concurrency: The maximum number of Ray tasks to run concurrently. Set this
+            to control number of tasks to run concurrently. This doesn't change the
+            total number of tasks run or the total number of output blocks. By default,
+            concurrency is dynamically decided based on the available resources.
+        override_num_blocks: Override the number of output blocks from all read tasks.
+            By default, the number of output blocks is dynamically decided based on
+            input data size and available resources. You shouldn't manually set this
+            value in most cases.
+        ray_remote_args: kwargs passed to :meth:`~ray.remote` in the read tasks.
+
+    Returns:
+        A :class:`~ray.data.Dataset` containing video frames from the video files.
+    """
+    datasource = VideoDatasource(
+        paths,
+        filesystem=filesystem,
+        open_stream_args=arrow_open_stream_args,
+        meta_provider=DefaultFileMetadataProvider(),
+        partition_filter=partition_filter,
+        partitioning=partitioning,
+        ignore_missing_paths=ignore_missing_paths,
+        shuffle=shuffle,
+        include_paths=include_paths,
+        file_extensions=file_extensions,
+    )
+    return read_datasource(
+        datasource,
+        ray_remote_args=ray_remote_args,
+        concurrency=concurrency,
+        override_num_blocks=override_num_blocks,
+    )
+
+
 @PublicAPI(stability="alpha")
 def read_mongo(
     uri: str,
@@ -2095,6 +2269,8 @@ def read_sql(
     sql: str,
     connection_factory: Callable[[], Connection],
     *,
+    shard_keys: Optional[list[str]] = None,
+    shard_hash_fn: str = "MD5",
     parallelism: int = -1,
     ray_remote_args: Optional[Dict[str, Any]] = None,
     concurrency: Optional[int] = None,
@@ -2105,14 +2281,16 @@ def read_sql(
 
     .. note::
 
-        By default, ``read_sql`` launches multiple read tasks, and each task executes a
-        ``LIMIT`` and ``OFFSET`` to fetch a subset of the rows. However, for many
-        databases, ``OFFSET`` is slow.
+        Parallelism is supported by databases that support sharding. This means
+        that the database needs to support all of the following operations:
+        ``MOD``, ``ABS``, and ``CONCAT``.
+
+        You can use ``shard_hash_fn`` to specify the hash function to use for sharding.
+        The default is ``MD5``, but other common alternatives include ``hash``,
+        ``unicode``, and ``SHA``.
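Under those constraints, a sharded read might look like the following hypothetical sketch; the DB-API driver, table, and column names are assumptions (sqlite3 stands in for a database whose dialect actually provides ``MD5``, ``MOD``, ``ABS``, and ``CONCAT``):

```python
# Hypothetical sharded read_sql usage based on the new arguments above.
import ray

def create_connection():
    import sqlite3  # stand-in driver; real sharding needs MD5 support
    return sqlite3.connect("example.db")

ds = ray.data.read_sql(
    "SELECT * FROM movie",
    create_connection,
    shard_keys=["title"],    # columns fed to the shard hash
    shard_hash_fn="MD5",     # default; some databases want "hash" or "SHA"
    override_num_blocks=4,   # number of shards, hence parallel read tasks
)
```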
- As a workaround, set ``override_num_blocks=1`` to directly fetch all rows in a - single task. Note that this approach requires all result rows to fit in the - memory of single task. If the rows don't fit, your program may raise an out of - memory error. + If the database does not support sharding, the read operation will be + executed in a single task. Examples: @@ -2165,6 +2343,10 @@ def create_connection(): connection_factory: A function that takes no arguments and returns a Python DB API2 `Connection object `_. + shard_keys: The keys to shard the data by. + shard_hash_fn: The hash function string to use for sharding. Defaults to "MD5". + For other databases, common alternatives include "hash" and "SHA". + This is applied to the shard keys. parallelism: This argument is deprecated. Use ``override_num_blocks`` argument. ray_remote_args: kwargs passed to :func:`ray.remote` in the read tasks. concurrency: The maximum number of Ray tasks to run concurrently. Set this @@ -2172,6 +2354,7 @@ def create_connection(): total number of tasks run or the total number of output blocks. By default, concurrency is dynamically decided based on the available resources. override_num_blocks: Override the number of output blocks from all read tasks. + This is used for sharding when shard_keys is provided. By default, the number of output blocks is dynamically decided based on input data size and available resources. You shouldn't manually set this value in most cases. @@ -2179,13 +2362,21 @@ def create_connection(): Returns: A :class:`Dataset` containing the queried data. """ - if parallelism != -1 and parallelism != 1: - raise ValueError( - "To ensure correctness, 'read_sql' always launches one task. The " - "'parallelism' argument you specified can't be used." - ) + datasource = SQLDatasource( + sql=sql, + shard_keys=shard_keys, + shard_hash_fn=shard_hash_fn, + connection_factory=connection_factory, + ) + if override_num_blocks and override_num_blocks > 1: + if shard_keys is None: + raise ValueError("shard_keys must be provided when override_num_blocks > 1") + + if not datasource.supports_sharding(override_num_blocks): + raise ValueError( + "Database does not support sharding. Please set override_num_blocks to 1." + ) - datasource = SQLDatasource(sql=sql, connection_factory=connection_factory) return read_datasource( datasource, parallelism=parallelism, diff --git a/python/ray/data/tests/preprocessors/test_encoder.py b/python/ray/data/tests/preprocessors/test_encoder.py index bfae00596439c..5fa5570568590 100644 --- a/python/ray/data/tests/preprocessors/test_encoder.py +++ b/python/ray/data/tests/preprocessors/test_encoder.py @@ -14,6 +14,38 @@ ) +def test_ordinal_encoder_strings(): + """Test the OrdinalEncoder for strings.""" + + input_dataframe = pd.DataFrame({"sex": ["male"] * 2000 + ["female"]}) + + ds = ray.data.from_pandas(input_dataframe) + encoder = OrdinalEncoder(columns=["sex"]) + encoded_ds = encoder.fit_transform(ds) + encoded_ds_pd = encoded_ds.to_pandas() + + # Check if the "sex" column exists and is correctly encoded as integers + assert ( + "sex" in encoded_ds_pd.columns + ), "The 'sex' column is missing in the encoded DataFrame" + assert ( + encoded_ds_pd["sex"].dtype == "int64" + ), "The 'sex' column is not encoded as integers" + + # Verify that the encoding worked as expected. 
+    # We expect "female" to be encoded as 0 and "male" as 1, since codes
+    # are assigned in sorted order of the unique values
+    unique_values = encoded_ds_pd["sex"].unique()
+    assert set(unique_values) == {
+        0,
+        1,
+    }, f"Unexpected unique values in 'sex' column: {unique_values}"
+    expected_encoding = {"male": 1, "female": 0}
+    for original, encoded in zip(input_dataframe["sex"], encoded_ds_pd["sex"]):
+        assert (
+            encoded == expected_encoding[original]
+        ), f"Expected {original} to be encoded as {expected_encoding[original]}, but got {encoded}"  # noqa: E501
+
+
 def test_ordinal_encoder():
     """Tests basic OrdinalEncoder functionality."""
     col_a = ["red", "green", "blue", "red"]
diff --git a/python/ray/data/tests/test_actor_pool_map_operator.py b/python/ray/data/tests/test_actor_pool_map_operator.py
index f979d78904954..6c316b8cc4cf7 100644
--- a/python/ray/data/tests/test_actor_pool_map_operator.py
+++ b/python/ray/data/tests/test_actor_pool_map_operator.py
@@ -25,6 +25,9 @@ def __init__(self, node_id: str = "node1"):
     def get_location(self) -> str:
         return self.node_id
 
+    def on_exit(self):
+        pass
+
 
 class TestActorPool(unittest.TestCase):
     def setup_class(self):
diff --git a/python/ray/data/tests/test_arrow_serialization.py b/python/ray/data/tests/test_arrow_serialization.py
index 232ed32cc7497..d8c65d5ad2a60 100644
--- a/python/ray/data/tests/test_arrow_serialization.py
+++ b/python/ray/data/tests/test_arrow_serialization.py
@@ -234,10 +234,7 @@ def fixed_size_list_array():
 @pytest.fixture
 def map_array():
     return pa.array(
-        [
-            [(key, item) for key, item in zip("abcdefghij", range(10))]
-            for _ in range(1000)
-        ],
+        [list(zip("abcdefghij", range(10))) for _ in range(1000)],
         type=pa.map_(pa.string(), pa.int64()),
     )
 
@@ -349,10 +346,7 @@ def complex_nested_array():
             ]
         ),
         pa.array(
-            [
-                [(key, item) for key, item in zip("abcdefghij", range(10))]
-                for _ in range(1000)
-            ],
+            [list(zip("abcdefghij", range(10))) for _ in range(1000)],
             type=pa.map_(pa.string(), pa.int64()),
         ),
     ],
diff --git a/python/ray/data/tests/test_audio.py b/python/ray/data/tests/test_audio.py
new file mode 100644
index 0000000000000..0ba9ca4d6b71e
--- /dev/null
+++ b/python/ray/data/tests/test_audio.py
@@ -0,0 +1,38 @@
+import numpy as np
+import pytest
+
+import ray
+from ray.tests.conftest import *  # noqa
+
+NUM_AUDIO_FILES = 10
+
+
+@pytest.fixture
+def audio_uri():
+    root = "s3://anonymous@air-example-data-2/6G-audio-data-LibriSpeech-train-clean-100-flac"  # noqa: E501
+    return [
+        f"{root}/train-clean-100/5022/29411/5022-29411-{n:04}.flac"
+        for n in range(NUM_AUDIO_FILES)
+    ]
+
+
+def test_read_audio(ray_start_regular_shared, audio_uri):
+    ds = ray.data.read_audio(audio_uri)
+
+    # Verify basic audio properties
+    assert ds.count() == NUM_AUDIO_FILES, ds.count()
+    assert ds.schema().names == ["amplitude", "sample_rate"], ds.schema()
+
+    # Check the sample rate
+    assert all(row["sample_rate"] == 16000 for row in ds.take_all())
+
+    for row in ds.take_all():
+        assert row["amplitude"].ndim == 2
+        assert row["amplitude"].shape[0] == 1
+        assert row["amplitude"].dtype == np.float32
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/data/tests/test_csv.py b/python/ray/data/tests/test_csv.py
index fc6709110119a..3aab217c7666a 100644
--- a/python/ray/data/tests/test_csv.py
+++ b/python/ray/data/tests/test_csv.py
@@ -825,17 +825,17 @@ def test_csv_invalid_file_handler(ray_start_regular_shared, tmp_path):
     )
 
 
-@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50])
-def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, 
num_rows_per_file): +@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50]) +def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file): ray.data.range(100, override_num_blocks=20).write_csv( - tmp_path, num_rows_per_file=num_rows_per_file + tmp_path, min_rows_per_file=min_rows_per_file ) for filename in os.listdir(tmp_path): with open(os.path.join(tmp_path, filename), "r") as file: # Subtract 1 from the number of lines to account for the header. num_rows_written = len(file.read().splitlines()) - 1 - assert num_rows_written == num_rows_per_file + assert num_rows_written == min_rows_per_file if __name__ == "__main__": diff --git a/python/ray/data/tests/test_datasink.py b/python/ray/data/tests/test_datasink.py index 8b5eaff9f2c1d..cfbb32328a226 100644 --- a/python/ray/data/tests/test_datasink.py +++ b/python/ray/data/tests/test_datasink.py @@ -105,21 +105,21 @@ def get_node_id(): assert node_ids == {bar_node_id} -@pytest.mark.parametrize("num_rows_per_write", [5, 10, 50]) -def test_num_rows_per_write(tmp_path, ray_start_regular_shared, num_rows_per_write): +@pytest.mark.parametrize("min_rows_per_write", [5, 10, 50]) +def test_min_rows_per_write(tmp_path, ray_start_regular_shared, min_rows_per_write): class MockDatasink(Datasink[None]): - def __init__(self, num_rows_per_write): - self._num_rows_per_write = num_rows_per_write + def __init__(self, min_rows_per_write): + self._min_rows_per_write = min_rows_per_write def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None: - assert sum(len(block) for block in blocks) == self._num_rows_per_write + assert sum(len(block) for block in blocks) == self._min_rows_per_write @property - def num_rows_per_write(self): - return self._num_rows_per_write + def min_rows_per_write(self): + return self._min_rows_per_write ray.data.range(100, override_num_blocks=20).write_datasink( - MockDatasink(num_rows_per_write) + MockDatasink(min_rows_per_write) ) diff --git a/python/ray/data/tests/test_exceptions.py b/python/ray/data/tests/test_exceptions.py index 2b0de96dad04d..9818c3f379e03 100644 --- a/python/ray/data/tests/test_exceptions.py +++ b/python/ray/data/tests/test_exceptions.py @@ -21,8 +21,7 @@ def test_user_exception( ctx.log_internal_stack_trace_to_stdout = log_internal_stack_trace_to_stdout def f(row): - 1 / 0 - return row + _ = 1 / 0 with pytest.raises(UserCodeException) as exc_info: ray.data.range(1).map(f).take_all() @@ -80,7 +79,7 @@ def test_full_traceback_logged_with_ray_debugger( monkeypatch.setenv("RAY_DEBUG_POST_MORTEM", 1) def f(row): - 1 / 0 + _ = 1 / 0 return row with pytest.raises(Exception) as exc_info: diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 5790ec891c73e..837927e42f1ba 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -130,7 +130,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: large_object = np.zeros((128, 1024, 1024), dtype=np.uint8) # 128 MiB def read_fn(): - large_object + _ = large_object yield pd.DataFrame({"column": [0]}) return [ReadTask(read_fn, BlockMetadata(1, None, None, None, None))] diff --git a/python/ray/data/tests/test_expression_evaluator.py b/python/ray/data/tests/test_expression_evaluator.py index 08e29a78ad42d..040b42e192466 100644 --- a/python/ray/data/tests/test_expression_evaluator.py +++ b/python/ray/data/tests/test_expression_evaluator.py @@ -33,8 +33,13 @@ def sample_data(tmpdir_factory): "is_student": 
[False, True, False, False, True, None], # Including a None value } + # Define the schema explicitly + schema = pa.schema( + [("age", pa.float64()), ("city", pa.string()), ("is_student", pa.bool_())] + ) + # Create a PyArrow table from the sample data - table = pa.table(data) + table = pa.table(data, schema=schema) # Use tmpdir_factory to create a temporary directory temp_dir = tmpdir_factory.mktemp("data") @@ -44,7 +49,7 @@ def sample_data(tmpdir_factory): pq.write_table(table, str(parquet_file)) # Yield the path to the Parquet file for testing - yield str(parquet_file) + yield str(parquet_file), schema expressions_and_expected_data = [ @@ -290,13 +295,13 @@ def sample_data(tmpdir_factory): @pytest.mark.parametrize("expression, expected_data", expressions_and_expected_data) def test_filter(sample_data, expression, expected_data): """Test the filter functionality of the ExpressionEvaluator.""" - # Instantiate the ExpressionEvaluator with valid column names - evaluator = ExpressionEvaluator() - filters = evaluator.get_filters(expression) + # Instantiate the ExpressionEvaluator with valid column names + sample_data_path, _ = sample_data + filters = ExpressionEvaluator.get_filters(expression=expression) # Read the table from the Parquet file with the applied filters - filtered_table = pq.read_table(sample_data, filters=filters) + filtered_table = pq.read_table(sample_data_path, filters=filters) # Convert the filtered table back to a list of dictionaries for comparison result = filtered_table.to_pandas().to_dict(orient="records") @@ -314,11 +319,11 @@ def convert_nan_to_none(data): def test_filter_bad_expression(sample_data): - evaluator = ExpressionEvaluator() with pytest.raises(ValueError, match="Invalid syntax in the expression"): - evaluator.get_filters("bad filter") + ExpressionEvaluator.get_filters(expression="bad filter") - filters = evaluator.get_filters("hi > 3") + filters = ExpressionEvaluator.get_filters(expression="hi > 3") + sample_data_path, _ = sample_data with pytest.raises(pa.ArrowInvalid): - pq.read_table(sample_data, filters=filters) + pq.read_table(sample_data_path, filters=filters) diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py index 8eb985c8a7449..8054f6bbf8838 100644 --- a/python/ray/data/tests/test_file_based_datasource.py +++ b/python/ray/data/tests/test_file_based_datasource.py @@ -8,11 +8,7 @@ import ray from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data.block import Block -from ray.data.datasource.file_based_datasource import ( - OPEN_FILE_MAX_ATTEMPTS, - FileBasedDatasource, - _open_file_with_retry, -) +from ray.data.datasource.file_based_datasource import FileBasedDatasource from ray.data.datasource.path_util import _has_file_extension, _is_local_windows_path @@ -94,36 +90,35 @@ def test_file_extensions(ray_start_regular_shared, tmp_path): assert ds.input_files() == [csv_path] -def test_open_file_with_retry(ray_start_regular_shared): - class FlakyFileOpener: - def __init__(self, max_attempts: int): - self.retry_attempts = 0 - self.max_attempts = max_attempts - - def open(self): - self.retry_attempts += 1 - if self.retry_attempts < self.max_attempts: - raise OSError( - "When creating key x in bucket y: AWS Error SLOW_DOWN during " - "PutObject operation: Please reduce your request rate." - ) - return "dummy" - - original_max_attempts = OPEN_FILE_MAX_ATTEMPTS - try: - # Test openning file successfully after retries. 
- opener = FlakyFileOpener(3) - assert _open_file_with_retry("dummy", lambda: opener.open()) == "dummy" - - # Test exhausting retries and failed eventually. - ray.data.datasource.file_based_datasource.OPEN_FILE_MAX_ATTEMPTS = 3 - opener = FlakyFileOpener(4) - with pytest.raises(OSError): - _open_file_with_retry("dummy", lambda: opener.open()) - finally: - ray.data.datasource.file_based_datasource.OPEN_FILE_MAX_ATTEMPTS = ( - original_max_attempts - ) +def test_flaky_datasource(ray_start_regular_shared): + + from ray.data._internal.datasource.csv_datasource import CSVDatasource + + class Counter: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + class FlakyCSVDatasource(CSVDatasource): + def __init__(self, paths, **csv_datasource_kwargs): + super().__init__(paths, **csv_datasource_kwargs) + CounterActor = ray.remote(Counter) + self.counter = CounterActor.remote() + + def _read_stream(self, f: "pyarrow.NativeFile", path: str): + count = self.counter.increment.remote() + if ray.get(count) == 1: + raise RuntimeError("AWS Error INTERNAL_FAILURE") + else: + for block in CSVDatasource._read_stream(self, f, path): + yield block + + datasource = FlakyCSVDatasource(["example://iris.csv"]) + ds = ray.data.read_datasource(datasource) + assert len(ds.take()) == 20 def test_windows_path(): diff --git a/python/ray/data/tests/test_file_datasink.py b/python/ray/data/tests/test_file_datasink.py index 8b1be4f5e5464..75de5be52a4af 100644 --- a/python/ray/data/tests/test_file_datasink.py +++ b/python/ray/data/tests/test_file_datasink.py @@ -147,8 +147,8 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): assert os.path.isdir(path) -@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50]) -def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file): +@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50]) +def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file): class MockFileDatasink(BlockBasedFileDatasink): def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): for _ in range(block.num_rows()): @@ -157,14 +157,14 @@ def write_block_to_file(self, block: BlockAccessor, file: "pyarrow.NativeFile"): ds = ray.data.range(100, override_num_blocks=20) ds.write_datasink( - MockFileDatasink(path=tmp_path, num_rows_per_file=num_rows_per_file) + MockFileDatasink(path=tmp_path, min_rows_per_file=min_rows_per_file) ) num_rows_written_total = 0 for filename in os.listdir(tmp_path): with open(os.path.join(tmp_path, filename), "r") as file: num_rows_written = len(file.read().splitlines()) - assert num_rows_written == num_rows_per_file + assert num_rows_written == min_rows_per_file num_rows_written_total += num_rows_written assert num_rows_written_total == 100 diff --git a/python/ray/data/tests/test_json.py b/python/ray/data/tests/test_json.py index 822da43469641..9a179c67af743 100644 --- a/python/ray/data/tests/test_json.py +++ b/python/ray/data/tests/test_json.py @@ -645,16 +645,16 @@ def test_json_read_across_blocks(ray_start_regular_shared, fs, data_path, endpoi dsdf = ds.to_pandas() -@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50]) -def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file): +@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50]) +def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file): ray.data.range(100, override_num_blocks=20).write_json( - tmp_path, 
num_rows_per_file=num_rows_per_file + tmp_path, min_rows_per_file=min_rows_per_file ) for filename in os.listdir(tmp_path): with open(os.path.join(tmp_path, filename), "r") as file: num_rows_written = len(file.read().splitlines()) - assert num_rows_written == num_rows_per_file + assert num_rows_written == min_rows_per_file def test_mixed_gzipped_json_files(ray_start_regular_shared, tmp_path): diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index a8b6e65ad3327..4f07b239e584e 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1500,6 +1500,31 @@ def test_random_sample_checks(ray_start_regular_shared): ray.data.range(1).random_sample(10) +def test_actor_udf_cleanup(ray_start_regular_shared, tmp_path): + """Test that for the actor map operator, the UDF object is deleted properly.""" + test_file = tmp_path / "test.txt" + + # Simulate the case that the UDF depends on some external resources that + # need to be cleaned up. + class StatefulUDF: + def __init__(self): + with open(test_file, "w") as f: + f.write("test") + + def __call__(self, row): + return row + + def __del__(self): + # Delete the file when the UDF is deleted. + os.remove(test_file) + + ds = ray.data.range(10) + ds = ds.map(StatefulUDF, concurrency=1) + assert sorted(extract_values("id", ds.take_all())) == list(range(10)) + + wait_for_condition(lambda: not os.path.exists(test_file)) + + # NOTE: All tests above share a Ray cluster, while the tests below do not. These # tests should only be carefully reordered to retain this invariant! def test_actor_pool_strategy_default_num_actors(shutdown_only): diff --git a/python/ray/data/tests/test_numpy.py b/python/ray/data/tests/test_numpy.py index 7b1fd5c1d3dc7..d71edb6526e05 100644 --- a/python/ray/data/tests/test_numpy.py +++ b/python/ray/data/tests/test_numpy.py @@ -292,15 +292,15 @@ def test_numpy_write(ray_start_regular_shared, fs, data_path, endpoint_url): np.testing.assert_equal(extract_values("data", ds.take(1)), [np.array([0])]) -@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50]) -def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file): +@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50]) +def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file): ray.data.range(100, override_num_blocks=20).write_numpy( - tmp_path, column="id", num_rows_per_file=num_rows_per_file + tmp_path, column="id", min_rows_per_file=min_rows_per_file ) for filename in os.listdir(tmp_path): array = np.load(os.path.join(tmp_path, filename)) - assert len(array) == num_rows_per_file + assert len(array) == min_rows_per_file if __name__ == "__main__": diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py index 412839d71a74c..e60afe556313b 100644 --- a/python/ray/data/tests/test_parquet.py +++ b/python/ray/data/tests/test_parquet.py @@ -1210,17 +1210,17 @@ def test_parquet_bulk_columns(ray_start_regular_shared): assert ds.columns() == ["variety"] -@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50]) -def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file): +@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50]) +def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file): import pyarrow.parquet as pq ray.data.range(100, override_num_blocks=20).write_parquet( - tmp_path, num_rows_per_file=num_rows_per_file + tmp_path, min_rows_per_file=min_rows_per_file ) for 
filename in os.listdir(tmp_path):
         table = pq.read_table(os.path.join(tmp_path, filename))
-        assert len(table) == num_rows_per_file
+        assert len(table) == min_rows_per_file
 
 
 @pytest.mark.parametrize("shuffle", [True, False, "file"])
@@ -1393,7 +1393,7 @@ def test_write_auto_infer_nullable_fields(
     ctx.target_max_block_size = 1
     ds = ray.data.range(len(row_data)).map(lambda row: row_data[row["id"]])
     # So we force writing to a single file.
-    ds.write_parquet(tmp_path, num_rows_per_file=2)
+    ds.write_parquet(tmp_path, min_rows_per_file=2)
 
 
 def test_seed_file_shuffle(restore_data_context, tmp_path):
diff --git a/python/ray/data/tests/test_sql.py b/python/ray/data/tests/test_sql.py
index 677e73457aa34..ffc9a657913fe 100644
--- a/python/ray/data/tests/test_sql.py
+++ b/python/ray/data/tests/test_sql.py
@@ -22,13 +22,6 @@ def temp_database_fixture() -> Generator[str, None, None]:
         yield file.name
 
 
-def test_read_sql_with_parallelism_warns(temp_database):
-    with pytest.raises(ValueError):
-        ray.data.read_sql(
-            "SELECT * FROM movie", lambda: sqlite3.connect(temp_database), parallelism=2
-        )
-
-
 def test_read_sql(temp_database: str):
     connection = sqlite3.connect(temp_database)
     connection.execute("CREATE TABLE movie(title, year, score)")
@@ -49,6 +42,73 @@ def test_read_sql(temp_database: str):
     assert sorted(actual_values) == sorted(expected_values)
 
 
+def test_read_sql_with_parallelism_fallback(temp_database: str):
+    connection = sqlite3.connect(temp_database)
+    connection.execute("CREATE TABLE grade(name, id, score)")
+    base_tuple = ("xiaoming", 1, 8.2)
+    # Generate 500 elements
+    expected_values = [
+        (f"{base_tuple[0]}{i}", i, base_tuple[2] + i + 1) for i in range(500)
+    ]
+    connection.executemany("INSERT INTO grade VALUES (?, ?, ?)", expected_values)
+    connection.commit()
+    connection.close()
+
+    num_blocks = 2
+    dataset = ray.data.read_sql(
+        "SELECT * FROM grade",
+        lambda: sqlite3.connect(temp_database),
+        override_num_blocks=num_blocks,
+        shard_hash_fn="unicode",
+        shard_keys=["id"],
+    )
+    dataset = dataset.materialize()
+    assert dataset.num_blocks() == num_blocks
+
+    actual_values = [tuple(record.values()) for record in dataset.take_all()]
+    assert sorted(actual_values) == sorted(expected_values)
+
+
+# for mysql test
+@pytest.mark.skip(reason="skip this test because mysql env is not ready")
+def test_read_sql_with_parallelism_mysql(temp_database: str):
+    # connect mysql
+    import pymysql
+
+    connection = pymysql.connect(
+        host="10.10.xx.xx", user="root", password="22222", database="test"
+    )
+    cursor = connection.cursor()
+
+    cursor.execute(
+        "CREATE TABLE IF NOT EXISTS grade (name VARCHAR(255), id INT, score FLOAT)"
+    )
+
+    base_tuple = ("xiaoming", 1, 8.2)
+    expected_values = [
+        (f"{base_tuple[0]}{i}", i, base_tuple[2] + i + 1) for i in range(200)
+    ]
+
+    cursor.executemany(
+        "INSERT INTO grade (name, id, score) VALUES (%s, %s, %s)", expected_values
+    )
+    connection.commit()
+
+    cursor.close()
+    connection.close()
+
+    dataset = ray.data.read_sql(
+        "SELECT * FROM grade",
+        lambda: pymysql.connect(host="xxxxx", user="xx", password="xx", database="xx"),
+        parallelism=4,
+        shard_keys=["id"],
+    )
+    actual_values = [tuple(record.values()) for record in dataset.take_all()]
+
+    assert sorted(actual_values) == sorted(expected_values)
+    assert dataset.materialize().num_blocks() == 4
+
+
 def test_write_sql(temp_database: str):
     connection = sqlite3.connect(temp_database)
     connection.cursor().execute("CREATE TABLE test(string, number)")
diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py
index 1a4b21d409000..c2e508bd6fd09 100644
--- a/python/ray/data/tests/test_streaming_executor.py
+++ b/python/ray/data/tests/test_streaming_executor.py
@@ -1,7 +1,7 @@
 import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -680,7 +680,7 @@ def after_execution_fails(self, error: Exception):
     remove_execution_callback(callback, ctx)
     assert get_execution_callbacks(ctx) == []
 
-    # Test the failure case.
+    # Test the case where the dataset fails due to an error in the UDF.
     ds = ray.data.range(10)
     ctx = ds.context
     ctx.raise_original_map_exception = True
@@ -698,6 +698,27 @@ def map_fn(_):
     error = callback._execution_error
     assert isinstance(error, ValueError), error
 
+    # Test the case where the dataset is canceled by "ctrl-c".
+    ds = ray.data.range(10)
+    ctx = ds.context
+    callback = CustomExecutionCallback()
+    add_execution_callback(callback, ctx)
+
+    def patched_get_output_blocking(*args, **kwargs):
+        raise KeyboardInterrupt()
+
+    with patch(
+        "ray.data._internal.execution.streaming_executor.OpState.get_output_blocking",
+        new=patched_get_output_blocking,
+    ):
+        with pytest.raises(KeyboardInterrupt):
+            ds.take_all()
+
+    assert callback._before_execution_starts_called
+    assert not callback._after_execution_succeeds_called
+    error = callback._execution_error
+    assert isinstance(error, KeyboardInterrupt), error
+
 
 if __name__ == "__main__":
     import sys
diff --git a/python/ray/data/tests/test_tfrecords.py b/python/ray/data/tests/test_tfrecords.py
index 4d5a5699fd831..710b8d4385b99 100644
--- a/python/ray/data/tests/test_tfrecords.py
+++ b/python/ray/data/tests/test_tfrecords.py
@@ -742,15 +742,15 @@ def test_read_with_invalid_schema(
     )
 
 
-@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50])
-def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file):
+@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50])
+def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file):
     ray.data.range(100, override_num_blocks=20).write_tfrecords(
-        tmp_path, num_rows_per_file=num_rows_per_file
+        tmp_path, min_rows_per_file=min_rows_per_file
     )
 
     for filename in os.listdir(tmp_path):
         dataset = tf.data.TFRecordDataset(os.path.join(tmp_path, filename))
-        assert len(list(dataset)) == num_rows_per_file
+        assert len(list(dataset)) == min_rows_per_file
 
 
 def read_tfrecords_with_tfx_read_override(paths, tfx_read=False, **read_opts):
diff --git a/python/ray/data/tests/test_transform_pyarrow.py b/python/ray/data/tests/test_transform_pyarrow.py
index a221bd6c76839..8cffe30da4983 100644
--- a/python/ray/data/tests/test_transform_pyarrow.py
+++ b/python/ray/data/tests/test_transform_pyarrow.py
@@ -5,6 +5,8 @@
 import pandas as pd
 import pyarrow as pa
 import pytest
+from packaging.version import parse as parse_version
+from ray._private.utils import _get_pyarrow_version
 
 import ray
 from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
@@ -198,9 +200,6 @@ def test_arrow_concat_tensor_extension_uniform_but_different():
 # fails for this case.
 
 
-@pytest.mark.skipif(
-    not _object_extension_type_allowed(), reason="Object extension type not supported."
-) def test_arrow_concat_with_objects(): obj = types.SimpleNamespace(a=1, b="test") t1 = pa.table({"a": [3, 4], "b": [7, 8]}) @@ -214,6 +213,382 @@ def test_arrow_concat_with_objects(): assert t3.column("b").to_pylist() == [7, 8, 0, 1] +@pytest.mark.skipif( + parse_version(_get_pyarrow_version()) < parse_version("17.0.0"), + reason="Requires PyArrow version 17 or higher", +) +def test_struct_with_different_field_names(): + # Ensures that when concatenating tables with struct columns having different + # field names, missing fields in each struct are filled with None in the + # resulting table. + + t1 = pa.table( + { + "a": [1, 2], + "d": pa.array( + [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + t2 = pa.table( + { + "a": [3], + "d": pa.array( + [{"x": 3, "z": "c"}], + type=pa.struct([("x", pa.int32()), ("z", pa.string())]), + ), + } + ) + + # Concatenate tables with different field names in struct + t3 = concat([t1, t2]) + + assert isinstance(t3, pa.Table) + assert len(t3) == 3 + + # Check the entire schema + expected_schema = pa.schema( + [ + ("a", pa.int64()), + ( + "d", + pa.struct( + [ + ("x", pa.int32()), + ("y", pa.string()), + ("z", pa.string()), + ] + ), + ), + ] + ) + assert t3.schema == expected_schema + + # Check that missing fields are filled with None + assert t3.column("a").to_pylist() == [1, 2, 3] + assert t3.column("d").to_pylist() == [ + {"x": 1, "y": "a", "z": None}, + {"x": 2, "y": "b", "z": None}, + {"x": 3, "y": None, "z": "c"}, + ] + + +@pytest.mark.skipif( + parse_version(_get_pyarrow_version()) < parse_version("17.0.0"), + reason="Requires PyArrow version 17 or higher", +) +def test_nested_structs(): + # Checks that deeply nested structs (3 levels of nesting) are handled properly + # during concatenation and the resulting table preserves the correct nesting + # structure. 
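+    # For illustration, concat() produces the field-wise union of the input
+    # struct types, e.g. struct<y: struct<p>> concatenated with
+    # struct<y: struct<q>> yields struct<y: struct<p, q>>, and fields absent
+    # from a given row are filled with None, as asserted below.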
+ + t1 = pa.table( + { + "a": [1], + "d": pa.array( + [ + { + "x": { + "y": {"p": 1}, # Missing "q" + "z": {"m": 3}, # Missing "n" + }, + "w": 5, + } + ], + type=pa.struct( + [ + ( + "x", + pa.struct( + [ + ( + "y", + pa.struct([("p", pa.int32())]), # Only "p" + ), + ( + "z", + pa.struct([("m", pa.int32())]), # Only "m" + ), + ] + ), + ), + ("w", pa.int32()), + ] + ), + ), + } + ) + + t2 = pa.table( + { + "a": [2], + "d": pa.array( + [ + { + "x": { + "y": {"q": 7}, # Missing "p" + "z": {"n": 9}, # Missing "m" + }, + "w": 10, + } + ], + type=pa.struct( + [ + ( + "x", + pa.struct( + [ + ( + "y", + pa.struct([("q", pa.int32())]), # Only "q" + ), + ( + "z", + pa.struct([("n", pa.int32())]), # Only "n" + ), + ] + ), + ), + ("w", pa.int32()), + ] + ), + ), + } + ) + + # Concatenate tables with nested structs and missing fields + t3 = concat([t1, t2]) + assert isinstance(t3, pa.Table) + assert len(t3) == 2 + + # Validate the schema of the resulting table + expected_schema = pa.schema( + [ + ("a", pa.int64()), + ( + "d", + pa.struct( + [ + ( + "x", + pa.struct( + [ + ( + "y", + pa.struct( + [("p", pa.int32()), ("q", pa.int32())] + ), + ), + ( + "z", + pa.struct( + [("m", pa.int32()), ("n", pa.int32())] + ), + ), + ] + ), + ), + ("w", pa.int32()), + ] + ), + ), + ] + ) + assert t3.schema == expected_schema + + # Validate the data in the concatenated table + assert t3.column("a").to_pylist() == [1, 2] + assert t3.column("d").to_pylist() == [ + { + "x": { + "y": {"p": 1, "q": None}, # Missing "q" filled with None + "z": {"m": 3, "n": None}, # Missing "n" filled with None + }, + "w": 5, + }, + { + "x": { + "y": {"p": None, "q": 7}, # Missing "p" filled with None + "z": {"m": None, "n": 9}, # Missing "m" filled with None + }, + "w": 10, + }, + ] + + +def test_struct_with_null_values(): + # Ensures that when concatenating tables with struct columns containing null + # values, the null values are properly handled, and the result reflects the + # expected structure. + + # Define the first table with struct containing null values + t1 = pa.table( + { + "a": [1, 2], + "d": pa.array( + [{"x": 1, "y": "a"}, None], # Second row is null + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + # Define the second table with struct containing a null value + t2 = pa.table( + { + "a": [3], + "d": pa.array( + [None], # Entire struct is null + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + # Concatenate tables with struct columns containing null values + t3 = concat([t1, t2]) + assert isinstance(t3, pa.Table) + assert len(t3) == 3 + + # Validate the schema of the resulting table + expected_schema = pa.schema( + [ + ("a", pa.int64()), + ("d", pa.struct([("x", pa.int32()), ("y", pa.string())])), + ] + ) + assert ( + t3.schema == expected_schema + ), f"Expected schema: {expected_schema}, but got {t3.schema}" + + # Verify the PyArrow table content + assert t3.column("a").to_pylist() == [1, 2, 3] + + # Adjust expected to match the format of the actual result + expected = [ + {"x": 1, "y": "a"}, + None, # Entire struct is None, not {"x": None, "y": None} + None, # Entire struct is None, not {"x": None, "y": None} + ] + + result = t3.column("d").to_pylist() + assert result == expected, f"Expected {expected}, but got {result}" + + +def test_struct_with_mismatched_lengths(): + # Verifies that when concatenating tables with struct columns of different lengths, + # the missing values are properly padded with None in the resulting table. 
+ # Define the first table with 2 rows and a struct column + t1 = pa.table( + { + "a": [1, 2], + "d": pa.array( + [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + # Define the second table with 1 row and a struct column + t2 = pa.table( + { + "a": [3], + "d": pa.array( + [{"x": 3, "y": "c"}], + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + # Concatenate tables with struct columns of different lengths + t3 = concat([t1, t2]) + assert isinstance(t3, pa.Table) + assert len(t3) == 3 # Check that the resulting table has the correct number of rows + + # Validate the schema of the resulting table + expected_schema = pa.schema( + [ + ("a", pa.int64()), + ("d", pa.struct([("x", pa.int32()), ("y", pa.string())])), + ] + ) + assert ( + t3.schema == expected_schema + ), f"Expected schema: {expected_schema}, but got {t3.schema}" + + # Verify the content of the resulting table + assert t3.column("a").to_pylist() == [1, 2, 3] + expected = [ + {"x": 1, "y": "a"}, + {"x": 2, "y": "b"}, + {"x": 3, "y": "c"}, + ] + result = t3.column("d").to_pylist() + + assert result == expected, f"Expected {expected}, but got {result}" + + +def test_struct_with_empty_arrays(): + # Checks the behavior when concatenating tables with structs containing empty + # arrays, verifying that null structs are correctly handled. + + # Define the first table with valid struct data + t1 = pa.table( + { + "a": [1, 2], + "d": pa.array( + [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], + type=pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), + } + ) + + # Define the second table with null struct value (empty arrays for fields) + x_array = pa.array([None], type=pa.int32()) + y_array = pa.array([None], type=pa.string()) + + # Create a struct array from null field arrays + null_struct_array = pa.StructArray.from_arrays( + [x_array, y_array], + ["x", "y"], + mask=pa.array([True]), + ) + + t2 = pa.table({"a": [3], "d": null_struct_array}) + + # Concatenate tables with struct columns containing null values + t3 = concat([t1, t2]) + + # Verify that the concatenated result is a valid PyArrow Table + assert isinstance(t3, pa.Table) + assert len(t3) == 3 # Check that the concatenated table has 3 rows + + # Validate the schema of the resulting concatenated table + expected_schema = pa.schema( + [ + ("a", pa.int64()), # Assuming 'a' is an integer column + ( + "d", + pa.struct([("x", pa.int32()), ("y", pa.string())]), + ), # Struct column 'd' + ] + ) + assert ( + t3.schema == expected_schema + ), f"Expected schema: {expected_schema}, but got {t3.schema}" + + # Verify the content of the concatenated table + assert t3.column("a").to_pylist() == [1, 2, 3] + expected = [ + {"x": 1, "y": "a"}, + {"x": 2, "y": "b"}, + None, # Entire struct is None, as PyArrow handles it + ] + result = t3.column("d").to_pylist() + + assert result == expected, f"Expected {expected}, but got {result}" + + def test_arrow_concat_object_with_tensor_fails(): obj = types.SimpleNamespace(a=1, b="test") t1 = pa.table({"a": ArrowPythonObjectArray.from_objects([obj, obj]), "b": [0, 1]}) diff --git a/python/ray/data/tests/test_video.py b/python/ray/data/tests/test_video.py new file mode 100644 index 0000000000000..b700eb1cc80fd --- /dev/null +++ b/python/ray/data/tests/test_video.py @@ -0,0 +1,28 @@ +import pyarrow as pa +import pytest + +import ray + + +def test_read_videos(): + uri = "s3://anonymous@ray-example-data/basketball.mp4" + ds = ray.data.read_videos(uri) + + assert ds.count() == 
333
+    assert ds.schema().names == ["frame", "frame_index"]
+
+    frame_indices = ds.select_columns(["frame_index"]).take_all()
+    assert sorted(frame_indices, key=lambda item: item["frame_index"]) == [
+        {"frame_index": i} for i in range(333)
+    ]
+
+    frame_type, frame_index_type = ds.schema().types
+    assert frame_type.shape == (720, 1280, 3)
+    assert frame_type.scalar_type == pa.uint8()
+    assert frame_index_type == pa.int64()
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/data/tests/test_webdataset.py b/python/ray/data/tests/test_webdataset.py
index feb62b543fd49..c9a541f0d3ca8 100644
--- a/python/ray/data/tests/test_webdataset.py
+++ b/python/ray/data/tests/test_webdataset.py
@@ -269,15 +269,15 @@ def test_webdataset_decoding(ray_start_2_cpus, tmp_path):
     assert meta_json["e"]["img_filename"] == "for_test.jpg"
 
 
-@pytest.mark.parametrize("num_rows_per_file", [5, 10, 50])
-def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_per_file):
+@pytest.mark.parametrize("min_rows_per_file", [5, 10, 50])
+def test_write_min_rows_per_file(tmp_path, ray_start_regular_shared, min_rows_per_file):
     ray.data.from_items(
         [{"id": str(i)} for i in range(100)], override_num_blocks=20
-    ).write_webdataset(tmp_path, num_rows_per_file=num_rows_per_file)
+    ).write_webdataset(tmp_path, min_rows_per_file=min_rows_per_file)
 
     for filename in os.listdir(tmp_path):
         dataset = wds.WebDataset(os.path.join(tmp_path, filename))
-        assert len(list(dataset)) == num_rows_per_file
+        assert len(list(dataset)) == min_rows_per_file
 
 
 if __name__ == "__main__":
diff --git a/python/ray/experimental/channel/common.py b/python/ray/experimental/channel/common.py
index c395422d5daef..09b1e49743774 100644
--- a/python/ray/experimental/channel/common.py
+++ b/python/ray/experimental/channel/common.py
@@ -421,6 +421,12 @@ def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
     def release_channel_buffers(self, timeout: Optional[float] = None) -> None:
         for c in self._input_channels:
             start_time = time.monotonic()
+            assert hasattr(c, "release_buffer"), (
+                "release_buffer() is only supported for shared memory channels "
+                "(e.g., Channel, BufferedSharedMemoryChannel, CompositeChannel) "
+                "and used between the last actor and the driver, but got a channel"
+                f" of type {type(c)}."
+            )
             c.release_buffer(timeout)
             if timeout is not None:
                 timeout -= time.monotonic() - start_time
diff --git a/python/ray/experimental/channel/shared_memory_channel.py b/python/ray/experimental/channel/shared_memory_channel.py
index 661437fae3cde..11cac4bc3439c 100644
--- a/python/ray/experimental/channel/shared_memory_channel.py
+++ b/python/ray/experimental/channel/shared_memory_channel.py
@@ -18,14 +18,6 @@
 # entry/init points.
 logger = logging.getLogger(__name__)
 
-DEFAULT_MAX_BUFFER_SIZE = int(1e6)  # 100 mB
-# The min buffer size must be large enough to at least fit an instance of the
-# _ResizeChannel class along with any metadata.
-MIN_BUFFER_SIZE = int(1000)  # 1000 bytes
-# For shared memory channels, the default number of buffers per channel to
-# allocate.
-DEFAULT_NUM_SHM_BUFFERS = 1
-
 
 def _create_channel_ref(
     self,
@@ -109,14 +101,21 @@ def __init__(
             that can be passed between tasks in the DAG. The buffers will be
             automatically resized if larger messages are written to the
             channel.
-        num_shm_buffers: The number of shared memory buffer per channel.
+        num_shm_buffers: The number of shared memory buffers per channel.
+ Note: In the case of multiple nodes, we only support 1 shared + memory buffer. """ super().__init__() + + from ray.dag import DAGContext + + ctx = DAGContext.get_current() + if buffer_size_bytes is None: - buffer_size_bytes = DEFAULT_MAX_BUFFER_SIZE + buffer_size_bytes = ctx.buffer_size_bytes self.buffer_size_bytes = buffer_size_bytes if num_shm_buffers is None: - num_shm_buffers = DEFAULT_NUM_SHM_BUFFERS + num_shm_buffers = 1 self._num_shm_buffers = num_shm_buffers def create_channel( @@ -192,6 +191,9 @@ def __init__( elif isinstance(typ, int): typ = SharedMemoryType(buffer_size_bytes=typ) + # The min buffer size must be large enough to at least fit an instance of the + # _ResizeChannel class along with any metadata. + MIN_BUFFER_SIZE = int(1000) # 1000 bytes if typ.buffer_size_bytes < MIN_BUFFER_SIZE: raise ValueError( "typ.buffer_size_bytes must be at least MIN_BUFFER_SIZE " @@ -540,7 +542,8 @@ class BufferedSharedMemoryChannel(ChannelInterface): Args: writer: The actor that may write to the channel. None signifies the driver. reader_and_node_list: A list of tuples, where each tuple contains a reader - actor handle and the node ID where the actor is located. + actor handle and the node ID where the actor is located. Note that currently + we only support this for readers on the same node as the writer. num_shm_buffers: Number of shared memory buffers to read/write. typ: Type information about the values passed through the channel. Either an integer representing the max buffer size in bytes @@ -653,6 +656,9 @@ class CompositeChannel(ChannelInterface): writer: The actor that may write to the channel. None signifies the driver. reader_and_node_list: A list of tuples, where each tuple contains a reader actor handle and the node ID where the actor is located. + num_shm_buffers: The number of shared memory buffers per channel. + Note: In the case of multiple nodes, we only support 1 shared + memory buffer. driver_actor_id: If this channel is read by a driver and that driver is an actual actor, this will be the actor ID of that driver actor. """ @@ -699,14 +705,29 @@ def __init__( actor_id = self._get_actor_id(self._writer) self._channel_dict[actor_id] = local_channel # There are some remote readers which are not the same Ray actor as the writer. - # Create a shared memory channel for the writer and the remote readers. 
-        if len(remote_reader_and_node_list) != 0:
+        # We create a BufferedSharedMemoryChannel for readers on the same node, and
+        # a single Channel for readers on different nodes due to
+        # https://github.com/ray-project/ray/issues/49044
+        (
+            readers_same_node,
+            readers_different_node,
+        ) = utils.split_actors_by_node_locality(
+            utils.get_actor_node(self._writer), remote_reader_and_node_list
+        )
+
+        if len(readers_same_node) != 0:
             remote_channel = BufferedSharedMemoryChannel(
-                self._writer, remote_reader_and_node_list, num_shm_buffers
+                self._writer, readers_same_node, num_shm_buffers
             )
             self._channels.add(remote_channel)
+            for reader, _ in readers_same_node:
+                actor_id = self._get_actor_id(reader)
+                self._channel_dict[actor_id] = remote_channel
 
-            for reader, _ in remote_reader_and_node_list:
+        if len(readers_different_node) != 0:
+            remote_channel = Channel(self._writer, readers_different_node)
+            self._channels.add(remote_channel)
+            for reader, _ in readers_different_node:
                 actor_id = self._get_actor_id(reader)
                 self._channel_dict[actor_id] = remote_channel
diff --git a/python/ray/experimental/channel/utils.py b/python/ray/experimental/channel/utils.py
index 88560b3bc1c87..6df57828130a8 100644
--- a/python/ray/experimental/channel/utils.py
+++ b/python/ray/experimental/channel/utils.py
@@ -41,3 +41,52 @@ def split_readers_by_locality(
             local_readers.append((reader, node))
 
     return remote_readers, local_readers
+
+
+def split_actors_by_node_locality(
+    node: str,
+    actor_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
+) -> Tuple[
+    List[Tuple["ray.actor.ActorHandle", str]], List[Tuple["ray.actor.ActorHandle", str]]
+]:
+    """Split actors into local and remote actors based on the given node. The
+    local actors are on the given node; the remote actors are on a different node.
+
+    Args:
+        node: The node to compare the actors' nodes against
+        actor_and_node_list: List of (actor, node) tuples
+
+    Returns:
+        Tuple containing:
+        - List of (actor, node) tuples for actors on the same node
+        - List of (actor, node) tuples for actors on a different node
+    """
+    actors_on_same_node = []
+    actors_on_different_node = []
+
+    for actor, actor_node in actor_and_node_list:
+        if node == actor_node:
+            actors_on_same_node.append((actor, actor_node))
+        else:
+            actors_on_different_node.append((actor, actor_node))
+
+    return actors_on_same_node, actors_on_different_node
+
+
+def get_actor_node(actor: Optional["ray.actor.ActorHandle"]) -> str:
+    """Get the ID of the node that the actor is on.
+
+    Args:
+        actor: The actor handle of the actor
+
+    Returns:
+        The node ID of the actor
+    """
+    if actor is None or actor == ray.get_runtime_context().current_actor:
+        return ray.get_runtime_context().get_node_id()
+    else:
+        return ray.get(
+            actor.__ray_call__.remote(
+                lambda self: ray.get_runtime_context().get_node_id()
+            )
+        )
diff --git a/python/ray/experimental/tqdm_ray.py b/python/ray/experimental/tqdm_ray.py
index e1cd885a8b3c4..e5bcd4943d9ef 100644
--- a/python/ray/experimental/tqdm_ray.py
+++ b/python/ray/experimental/tqdm_ray.py
@@ -30,6 +30,7 @@
 # Global manager singleton.
_manager: Optional["_BarManager"] = None +_mgr_lock = threading.Lock() _print = builtins.print @@ -383,13 +384,15 @@ def _update_offsets(self): def instance() -> _BarManager: """Get or create a BarManager for this process.""" global _manager - if _manager is None: - _manager = _BarManager() - if env_bool("RAY_TQDM_PATCH_PRINT", True): - import builtins - builtins.print = safe_print - return _manager + with _mgr_lock: + if _manager is None: + _manager = _BarManager() + if env_bool("RAY_TQDM_PATCH_PRINT", True): + import builtins + + builtins.print = safe_print + return _manager if __name__ == "__main__": diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index a55e101758b33..c256459f67b43 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -555,6 +555,10 @@ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil: c_string &serialized_reply ) + CRayStatus AsyncGetClusterStatus( + int64_t timeout_ms, + const OptionalItemPyCallback[CGetClusterStatusReply] &callback) + CRayStatus ReportAutoscalingState( int64_t timeout_ms, const c_string &serialized_state @@ -728,6 +732,12 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil: void ParseFromString(const c_string &serialized) const c_string &SerializeAsString() const +cdef extern from "src/ray/protobuf/autoscaler.pb.h" nogil: + cdef cppclass CGetClusterStatusReply "ray::rpc::autoscaler::GetClusterStatusReply": + c_string serialized_cluster_status() const + void ParseFromString(const c_string &serialized) + const c_string &SerializeAsString() const + cdef extern from "ray/common/task/task_spec.h" nogil: cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup": CConcurrencyGroup( diff --git a/python/ray/includes/gcs_client.pxi b/python/ray/includes/gcs_client.pxi index 4e36348e678ab..1bd2c3e2c23a9 100644 --- a/python/ray/includes/gcs_client.pxi +++ b/python/ray/includes/gcs_client.pxi @@ -28,9 +28,10 @@ from ray.includes.common cimport ( MultiItemPyCallback, OptionalItemPyCallback, StatusPyCallback, + CGetClusterStatusReply, ) from ray.includes.optional cimport optional, make_optional -from ray.core.generated import gcs_pb2 +from ray.core.generated import gcs_pb2, autoscaler_pb2 from cython.operator import dereference, postincrement cimport cpython @@ -548,6 +549,28 @@ cdef class InnerGcsClient: return serialized_reply + def async_get_cluster_status( + self, + timeout_s=None + ) -> Future[autoscaler_pb2.GetClusterStatusReply]: + cdef: + int64_t timeout_ms = round(1000 * timeout_s) if timeout_s else -1 + fut = incremented_fut() + with nogil: + check_status_timeout_as_rpc_error( + self.inner.get() + .Autoscaler() + .AsyncGetClusterStatus( + timeout_ms, + OptionalItemPyCallback[CGetClusterStatusReply]( + &convert_get_cluster_status_reply, + assign_and_decrement_fut, + fut + ) + ) + ) + return asyncio.wrap_future(fut) + def report_autoscaling_state( self, serialzied_state: c_string, @@ -700,6 +723,21 @@ cdef convert_get_all_actor_info( except Exception as e: return None, e +cdef convert_get_cluster_status_reply( + CRayStatus status, optional[CGetClusterStatusReply]&& c_data +) with gil: # -> Tuple[autoscaler_pb2.GetClusterStatusReply, Exception] + cdef c_string serialized_reply + try: + check_status_timeout_as_rpc_error(status) + assert c_data.has_value() + with nogil: + serialized_reply = c_data.value().SerializeAsString() + proto = autoscaler_pb2.GetClusterStatusReply() + proto.ParseFromString(serialized_reply) + return proto, None + except Exception as e: + return None, e + cdef 
convert_status(CRayStatus status) with gil:  # -> None
    try:
diff --git a/python/ray/llm/_internal/__init__.py b/python/ray/llm/_internal/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/python/ray/llm/_internal/batch/__init__.py b/python/ray/llm/_internal/batch/__init__.py
new file mode 100644
index 0000000000000..c1720b034f474
--- /dev/null
+++ b/python/ray/llm/_internal/batch/__init__.py
@@ -0,0 +1,13 @@
+from ray.llm._internal.batch.processor import (
+    Processor,
+    ProcessorConfig,
+    ProcessorBuilder,
+    HttpRequestProcessorConfig,
+)
+
+__all__ = [
+    "Processor",
+    "ProcessorConfig",
+    "ProcessorBuilder",
+    "HttpRequestProcessorConfig",
+]
diff --git a/python/ray/llm/_internal/batch/processor/__init__.py b/python/ray/llm/_internal/batch/processor/__init__.py
new file mode 100644
index 0000000000000..2c806c72a2a4c
--- /dev/null
+++ b/python/ray/llm/_internal/batch/processor/__init__.py
@@ -0,0 +1,9 @@
+from .base import ProcessorConfig, ProcessorBuilder, Processor
+from .http_request_proc import HttpRequestProcessorConfig
+
+__all__ = [
+    "ProcessorConfig",
+    "ProcessorBuilder",
+    "HttpRequestProcessorConfig",
+    "Processor",
+]
diff --git a/python/ray/llm/_internal/batch/processor/base.py b/python/ray/llm/_internal/batch/processor/base.py
new file mode 100644
index 0000000000000..eda5562b41e51
--- /dev/null
+++ b/python/ray/llm/_internal/batch/processor/base.py
@@ -0,0 +1,198 @@
+from collections import OrderedDict
+from typing import Optional, List, Type, Callable, Dict
+
+from pydantic import BaseModel, Field
+
+from ray.data.block import UserDefinedFunction
+from ray.data import Dataset
+from ray.util.annotations import PublicAPI, DeveloperAPI
+
+from ray.llm._internal.batch.stages import (
+    StatefulStage,
+    wrap_preprocess,
+    wrap_postprocess,
+)
+
+
+class ProcessorConfig(BaseModel):
+    """The processor configuration."""
+
+    batch_size: int = Field(
+        description="Large batch sizes are likely to saturate the compute resources "
+        "and could achieve higher throughput. On the other hand, small batch sizes "
+        "are more fault-tolerant and could reduce bubbles in the data pipeline. "
+        "You can tune the batch size to balance the throughput and fault-tolerance "
+        "based on your use case.",
+    )
+    accelerator_type: Optional[str] = Field(
+        default=None,
+        description="The accelerator type used by the LLM stage in a processor. "
+        "Defaults to None, meaning that only the CPU will be used.",
+    )
+    concurrency: int = Field(
+        default=1,
+        description="The number of workers for data parallelism. Defaults to 1.",
+    )
+
+    class Config:
+        validate_assignment = True
+        arbitrary_types_allowed = True
+
+
+@PublicAPI(stability="alpha")
+class Processor:
+    """A processor is composed of a preprocess stage, followed by one or more
+    processing stages, and finally a postprocess stage. We use processor as a
+    paradigm for processing data using LLMs.
+
+    Args:
+        config: The processor config.
+        preprocess: An optional lambda function that takes a row (dict) as input
+            and returns a preprocessed row (dict). The output row must contain the
+            required fields for the following processing stages.
+        postprocess: An optional lambda function that takes a row (dict) as input
+            and returns a postprocessed row (dict).
+    """
+
+    # The internally used data column name ("__data"). Your input
+    # dataset should not contain this column. If you want to use this column
+    # in your input dataset, you have to derive and customize Processor.
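+    # For illustration (hypothetical row values): an input row {"prompt": "hi"}
+    # travels between stages as {"__data": {"prompt": "hi", ...stage outputs}}
+    # and is unwrapped again by the postprocess wrapper before rows are returned.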
+    data_column: str = "__data"
+
+    def __init__(
+        self,
+        config: ProcessorConfig,
+        stages: List[StatefulStage],
+        preprocess: Optional[UserDefinedFunction] = None,
+        postprocess: Optional[UserDefinedFunction] = None,
+    ):
+        self.config = config
+        self.preprocess = None
+        self.postprocess = None
+        self.stages: OrderedDict[str, StatefulStage] = OrderedDict()
+
+        if preprocess is not None:
+            self.preprocess = wrap_preprocess(
+                preprocess,
+                self.data_column,
+            )
+
+        if postprocess is not None:
+            self.postprocess = wrap_postprocess(
+                postprocess,
+                self.data_column,
+            )
+
+        for stage in stages:
+            self._append_stage(stage)
+
+    def __call__(self, dataset: Dataset) -> Dataset:
+        """Execute the processor:
+        preprocess -> stages -> postprocess.
+        Note that the dataset won't be materialized during the execution.
+
+        Args:
+            dataset: The input dataset.
+
+        Returns:
+            The output dataset.
+        """
+        if self.preprocess is not None:
+            dataset = dataset.map(self.preprocess)
+
+        # Apply stages.
+        for stage in self.stages.values():
+            kwargs = stage.get_dataset_map_batches_kwargs(
+                batch_size=self.config.batch_size,
+                data_column=self.data_column,
+            )
+            dataset = dataset.map_batches(stage.fn, **kwargs)
+
+        if self.postprocess is not None:
+            dataset = dataset.map(self.postprocess)
+        return dataset
+
+    def _append_stage(self, stage: StatefulStage) -> None:
+        """Append a stage before postprocess. The stage class name will be used as
+        the stage name. If there are multiple stages with the same type, a suffix
+        will be added to the stage name to avoid conflicts.
+
+        Args:
+            stage: The stage to append.
+        """
+        stage_name = type(stage).__name__
+
+        # When a processor has multiple stages with the same type,
+        # append an index suffix to the stage name to avoid conflicts.
+        if stage_name in self.stages:
+            num_same_type_stage = len([s for s in self.stages.values() if s is stage])
+            stage_name = f"{stage_name}_{num_same_type_stage + 1}"
+        self.stages[stage_name] = stage
+
+    def list_stage_names(self) -> List[str]:
+        """List the stage names of this processor in order. Preprocess and postprocess
+        are not included.
+
+        Returns:
+            A list of stage names.
+        """
+        return list(self.stages.keys())
+
+    def get_stage_by_name(self, name: str) -> StatefulStage:
+        """Get a particular stage by its name. If the stage is not found,
+        a ValueError will be raised.
+
+        Args:
+            name: The stage name.
+
+        Returns:
+            The pipeline stage.
+        """
+        if name in self.stages:
+            return self.stages[name]
+        raise ValueError(f"Stage {name} not found")
+
+
+@DeveloperAPI
+class ProcessorBuilder:
+    """Build a processor based on the configuration."""
+
+    _registry: Dict[str, Callable] = {}
+
+    @classmethod
+    def register(cls, config_type: Type[ProcessorConfig], builder: Callable) -> None:
+        """Associate a particular processor config type
+        with its build function.
+        """
+        type_name = config_type.__name__
+        if type_name in cls._registry:
+            raise ValueError(f"Processor config type {type_name} already registered.")
+        cls._registry[type_name] = builder
+
+    @classmethod
+    def build(
+        cls,
+        config: ProcessorConfig,
+        override_stage_config_fn: Optional[Callable] = None,
+        **kwargs,
+    ) -> Processor:
+        """Build a processor.
+
+        Args:
+            config: The processor config.
+            override_stage_config_fn: An optional callable that is invoked with
+                each (stage name, stage) pair to customize the stage configuration.
+
+        Returns:
+            The built processor.
+        """
+        type_name = type(config).__name__
+        if type_name not in cls._registry:
+            raise ValueError(
+                f"Processor config type {type_name} not registered. "
+                f"Available types: {cls._registry.keys()}"
+            )
+        processor = cls._registry[type_name](config, **kwargs)
+        if override_stage_config_fn is not None:
+            for name, stage in processor.stages.items():
+                override_stage_config_fn(name, stage)
+        return processor
diff --git a/python/ray/llm/_internal/batch/processor/http_request_proc.py b/python/ray/llm/_internal/batch/processor/http_request_proc.py
new file mode 100644
index 0000000000000..e075a15f037af
--- /dev/null
+++ b/python/ray/llm/_internal/batch/processor/http_request_proc.py
@@ -0,0 +1,66 @@
+"""The HTTP request processor."""
+
+from typing import Any, Dict, Optional
+
+from pydantic import Field
+
+from ray.llm._internal.batch.processor.base import (
+    Processor,
+    ProcessorConfig,
+    ProcessorBuilder,
+)
+from ray.llm._internal.batch.stages import HttpRequestStage
+
+
+class HttpRequestProcessorConfig(ProcessorConfig):
+    """The configuration for the HTTP request processor."""
+
+    batch_size: int = Field(
+        default=64,
+        description="The batch size.",
+    )
+    url: str = Field(
+        description="The URL to query.",
+    )
+    headers: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="The request headers. Note that "
+        "'Content-Type: application/json' is always added to the headers, "
+        "because request bodies are always sent as JSON.",
+    )
+    qps: Optional[int] = Field(
+        default=None,
+        description="The maximum number of requests per second, to avoid hitting "
+        "rate limits. If None, requests are sent without rate limiting.",
+    )
+
+
+def build_http_request_processor(
+    config: HttpRequestProcessorConfig, **kwargs
+) -> Processor:
+    """Construct a Processor and configure stages.
+
+    Args:
+        config: The configuration for the processor.
+        **kwargs: The keyword arguments for the processor.
+
+    Returns:
+        The constructed processor.
+ """ + stages = [ + HttpRequestStage( + fn_constructor_kwargs=dict( + url=config.url, + additional_header=config.headers, + qps=config.qps, + ), + map_batches_kwargs=dict( + concurrency=config.concurrency, + ), + ) + ] + processor = Processor(config, stages, **kwargs) + return processor + + +ProcessorBuilder.register(HttpRequestProcessorConfig, build_http_request_processor) diff --git a/python/ray/llm/_internal/batch/stages/__init__.py b/python/ray/llm/_internal/batch/stages/__init__.py new file mode 100644 index 0000000000000..a23cc9eed06f0 --- /dev/null +++ b/python/ray/llm/_internal/batch/stages/__init__.py @@ -0,0 +1,20 @@ +from ray.llm._internal.batch.stages.base import ( + StatefulStage, + wrap_preprocess, + wrap_postprocess, +) +from ray.llm._internal.batch.stages.http_request_stage import HttpRequestStage +from ray.llm._internal.batch.stages.chat_template_stage import ChatTemplateStage +from ray.llm._internal.batch.stages.prepare_image_stage import PrepareImageStage +from ray.llm._internal.batch.stages.tokenize_stage import TokenizeStage, DetokenizeStage + +__all__ = [ + "StatefulStage", + "HttpRequestStage", + "ChatTemplateStage", + "TokenizeStage", + "DetokenizeStage", + "wrap_preprocess", + "wrap_postprocess", + "PrepareImageStage", +] diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py new file mode 100644 index 0000000000000..4248a88c7b102 --- /dev/null +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -0,0 +1,248 @@ +"""The base class for all stages.""" +import logging +from typing import Any, Dict, AsyncIterator, List, Callable + +import pyarrow +from pydantic import BaseModel, Field +from ray.data.block import UserDefinedFunction + +logger = logging.getLogger(__name__) + + +def wrap_preprocess( + fn: UserDefinedFunction, + processor_data_column: str, +) -> Callable: + """Wrap the preprocess function, so that the output schema of the + preprocess is normalized to {processor_data_column: fn(row), other input columns}. + + Args: + fn: The function to be applied. + processor_data_column: The internal data column name of the processor. + + Returns: + The wrapped function. + """ + + def _preprocess(row: dict[str, Any]) -> dict[str, Any]: + # First put everything into processor_data_column. + outputs = {processor_data_column: row} + + # Then apply the preprocess function and add its outputs. + preprocess_output = fn(row) + outputs[processor_data_column].update(preprocess_output) + return outputs + + return _preprocess + + +def wrap_postprocess( + fn: UserDefinedFunction, + processor_data_column: str, +) -> Callable: + """Wrap the postprocess function to remove the processor_data_column. + Note that we fully rely on users to determine which columns to carry over. + + Args: + fn: The function to be applied. + processor_data_column: The internal data column name of the processor. + + Returns: + The wrapped function. + """ + + def _postprocess(row: dict[str, Any]) -> dict[str, Any]: + if processor_data_column not in row: + raise ValueError( + f"[Internal] {processor_data_column} not found in row {row}" + ) + + return fn(row[processor_data_column]) + + return _postprocess + + +class StatefulStageUDF: + """A stage UDF wrapper that processes the input and output columns + before and after the UDF. + + Args: + data_column: The internal data column name of the processor. 
The
+            __call__ method will take the data column as the input of the udf
+            method, and encapsulate the output of the udf method into the data
+            column for the next stage.
+    """
+
+    def __init__(self, data_column: str):
+        self.data_column = data_column
+
+    async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]:
+        """A stage UDF wrapper that processes the input and output columns
+        before and after the UDF. The expected schema of "batch" is:
+        {data_column: {
+            dataset columns,
+            other intermediate columns
+         },
+         ...other metadata columns...,
+        }.
+
+        The input of the UDF will then be [dataset columns and other intermediate columns].
+        In addition, while the output of the UDF depends on the UDF implementation,
+        the output schema is expected to be
+        {data_column: {
+            dataset columns,
+            other intermediate columns,
+            UDF output columns (will override above columns if they have the same name)
+         },
+         ...other metadata columns...,
+        }.
+        And this will become the input of the next stage.
+
+        Examples:
+            Input dataset columns: {A, B, C}
+            Preprocess: (lambda row: {"D": row["A"] + 1})
+                Input:
+                    UDF input: {A, B, C}
+                    UDF output: {D}
+                Output: {__data: {A, B, C, D}}
+            Stage 1:
+                Input: {__data: {A, B, C, D}}
+                    UDF input: {A, B, C, D}
+                    UDF output: {E}
+                Output: {__data: {A, B, C, D, E}}
+            Stage 2:
+                Input: {__data: {A, B, C, D, E}}
+                    UDF input: {A, B, C, D, E}
+                    UDF output: {F, E}  # E is in-place updated.
+                Output: {__data: {A, B, C, D, E, F}}
+            Postprocess: (lambda row: {"G": row["F"], "A": row["A"], "E": row["E"]})
+                Input: {__data: {A, B, C, D, E, F}}
+                    UDF input: {A, B, C, D, E, F}
+                    UDF output: {G, A, E}
+                Output: {G, A, E}  # User chooses to keep G, A, E.
+
+        Args:
+            batch: The input batch.
+
+        Returns:
+            An async iterator of the outputs.
+        """
+        # Handle the case where the batch is empty.
+        # FIXME: This should not happen.
+        if isinstance(batch, pyarrow.lib.Table) and batch.num_rows == 0:
+            yield {}
+            return
+
+        if self.data_column not in batch:
+            raise ValueError(
+                f"[Internal] {self.data_column} not found in batch {batch}"
+            )
+
+        inputs = batch.pop(self.data_column)
+        if hasattr(inputs, "tolist"):
+            inputs = inputs.tolist()
+        self.validate_inputs(inputs)
+
+        # Always stream the outputs one by one to better overlap
+        # consecutive batches. For example, when the output batch size is 64, Ray Data
+        # will collect 64 outputs, and 1) send the batch of 64 to the next stage,
+        # 2) get the next batch of this stage. Assuming the input batch size
+        # is 63 and we yield all 63 results at once, then Ray Data will wait
+        # for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand,
+        # if we stream outputs one-by-one, Ray Data can form a batch of 64 before
+        # the second batch is done.
+        idx = 0
+        async for output in self.udf(inputs):
+            # Add stage outputs to the data column of the row.
+            inputs[idx].update(output)
+            yield {self.data_column: [inputs[idx]]}
+            idx += 1
+
+    def validate_inputs(self, inputs: List[Dict[str, Any]]):
+        """Validate the inputs to make sure the required keys are present.
+
+        Args:
+            inputs: The inputs.
+
+        Raises:
+            ValueError: If the required keys are not found.
+        """
+        expected_input_keys = self.expected_input_keys
+        if not expected_input_keys:
+            return
+
+        input_keys = set(inputs[0].keys())
+        missing_required = set(expected_input_keys) - input_keys
+        if missing_required:
+            raise ValueError(
+                f"Required input keys {missing_required} not found at the input of "
+                f"{self.__class__.__name__}. Input keys: {input_keys}"
+            )
+
+    @property
+    def expected_input_keys(self) -> List[str]:
+        """A list of required input keys. Missing required keys will raise
+        an exception.
+
+        Returns:
+            A list of required input keys.
+        """
+        return []
+
+    async def udf(self, rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
+        raise NotImplementedError("StageUDF must implement the udf method")
+
+
+class StatefulStage(BaseModel):
+    """
+    A basic building block to compose a Processor.
+    """
+
+    fn: StatefulStageUDF = Field(
+        description="The well-optimized stateful UDF for this stage."
+    )
+    fn_constructor_kwargs: Dict[str, Any] = Field(
+        description="The keyword arguments of the UDF constructor."
+    )
+    map_batches_kwargs: Dict[str, Any] = Field(
+        description="The arguments of .map_batches()."
+    )
+
+    def get_dataset_map_batches_kwargs(
+        self,
+        batch_size: int,
+        data_column: str,
+    ) -> Dict[str, Any]:
+        """We separate fn and fn_constructor_kwargs in Stage for better UX,
+        so we combine them with other map_batches_kwargs together in this method.
+
+        Args:
+            batch_size: The batch size set by the processor config.
+            data_column: The data column name set by the processor.
+
+        Returns:
+            The dataset map_batches kwargs.
+        """
+        kwargs = self.map_batches_kwargs.copy()
+        batch_size_in_kwargs = kwargs.get("batch_size", batch_size)
+        if batch_size_in_kwargs != batch_size:
+            logger.warning(
+                "batch_size is set to %d in map_batches_kwargs, but it will be "
+                "overridden by the batch size configured by the processor %d.",
+                batch_size_in_kwargs,
+                batch_size,
+            )
+        kwargs["batch_size"] = batch_size
+
+        kwargs.update({"fn_constructor_kwargs": self.fn_constructor_kwargs})
+        if "data_column" in kwargs["fn_constructor_kwargs"]:
+            raise ValueError(
+                "'data_column' cannot be used in fn_constructor_kwargs."
+            )
+
+        kwargs["fn_constructor_kwargs"]["data_column"] = data_column
+        return kwargs
+
+    class Config:
+        arbitrary_types_allowed = True
+        validate_assignment = True
diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py
new file mode 100644
index 0000000000000..f41ee1aaec159
--- /dev/null
+++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py
@@ -0,0 +1,62 @@
+"""Apply chat template stage"""
+
+from typing import Any, Dict, AsyncIterator, List
+
+from ray.llm._internal.batch.stages.base import (
+    StatefulStage,
+    StatefulStageUDF,
+)
+from ray.llm._internal.batch.utils import get_cached_tokenizer
+
+
+class ChatTemplateUDF(StatefulStageUDF):
+    def __init__(
+        self,
+        data_column: str,
+        model: str,
+    ):
+        """
+        Initialize the ChatTemplateUDF.
+
+        Args:
+            data_column: The data column name.
+            model: The model to use for the chat template.
+        """
+        from transformers import AutoTokenizer
+
+        super().__init__(data_column)
+        self.tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained(model))
+
+    async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
+        """
+        Apply chat template to the given batch.
+
+        Args:
+            batch: A list of rows to process.
+
+        Yields:
+            A generator of rows with the chat template applied.
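+
+            For illustration (the exact string depends on the model's chat
+            template, so the value shown is hypothetical): a row
+            {"messages": [{"role": "user", "content": "hi"}]} yields
+            {"prompt": "<s>[INST] hi [/INST]"}.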
+ """ + for prompt in self.tokenizer.apply_chat_template( + [row["messages"].tolist() for row in batch], + tokenize=False, + add_generation_prompt=True, + ): + yield {"prompt": prompt} + + @property + def expected_input_keys(self) -> List[str]: + """The expected input keys.""" + return ["messages"] + + +class ChatTemplateStage(StatefulStage): + """ + A stage that applies chat template. + """ + + fn: StatefulStageUDF = ChatTemplateUDF + fn_constructor_kwargs: Dict[str, Any] + map_batches_kwargs: Dict[str, Any] = dict( + concurrency=1, + ) diff --git a/python/ray/llm/_internal/batch/stages/http_request_stage.py b/python/ray/llm/_internal/batch/stages/http_request_stage.py new file mode 100644 index 0000000000000..9b765ec2b8d02 --- /dev/null +++ b/python/ray/llm/_internal/batch/stages/http_request_stage.py @@ -0,0 +1,94 @@ +"""HTTP Request Stage""" + +import aiohttp +import asyncio +import time +import numpy as np +from typing import Any, Dict, AsyncIterator, Optional, List + +from ray.llm._internal.batch.stages.base import StatefulStage, StatefulStageUDF + + +class HttpRequestUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + url: str, + additional_header: Optional[Dict[str, Any]] = None, + qps: Optional[int] = None, + ): + """ + Initialize the HttpRequestUDF. + + Args: + data_column: The data column name. + url: The URL to send the HTTP request to. + additional_header: The additional headers to send with the HTTP request. + qps: The maximum number of requests per second. + """ + super().__init__(data_column) + self.url = url + self.additional_header = additional_header or {} + self.qps = qps + + async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]: + """ + Send HTTP requests to the given URL. + + Args: + batch: A list of rows to send. + + Yields: + A generator of rows of the response of the HTTP request. + """ + async with aiohttp.ClientSession() as session: + start_time = time.time() + request_count = 0 + pending_requests = [] + headers = { + "Content-Type": "application/json", + **self.additional_header, + } + + # First send all requests based on QPS + for row in batch: + # Rate limit based on qps if specified + if self.qps is not None: + request_count += 1 + expected_time = request_count / self.qps + elapsed = time.time() - start_time + if elapsed < expected_time: + await asyncio.sleep(expected_time - elapsed) + + # Normalize the row to a JSON body. + json_body = {} + for key, value in row.items(): + if isinstance(value, np.ndarray): + json_body[key] = value.tolist() + else: + json_body[key] = value + + # Create request but don't await it yet + request = session.post( + self.url, + headers=headers, + json=json_body, + ) + pending_requests.append(request) + + # Now receive all responses + for request in pending_requests: + async with await request as response: + yield await response.json() + + +class HttpRequestStage(StatefulStage): + """ + A stage that sends HTTP requests. 
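+
+    Example (illustrative; the URL, header, and qps values below are
+    placeholders, not defaults):
+        HttpRequestStage(
+            fn_constructor_kwargs=dict(
+                url="http://localhost:8000/v1/chat/completions",
+                additional_header={"Authorization": "Bearer <token>"},
+                qps=10,
+            ),
+            map_batches_kwargs=dict(concurrency=1),
+        )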
+ """ + + fn: StatefulStageUDF = HttpRequestUDF + fn_constructor_kwargs: Dict[str, Any] + map_batches_kwargs: Dict[str, Any] = dict( + concurrency=1, + ) diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py new file mode 100644 index 0000000000000..88641a8343ab7 --- /dev/null +++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py @@ -0,0 +1,371 @@ +"""Prepare Image Stage""" +import requests +import aiohttp +import asyncio +import base64 +import logging +import importlib +from urllib.parse import urlparse +from pathlib import Path +from io import BytesIO +from typing import ( + TYPE_CHECKING, + Any, + Dict, + AsyncIterator, + List, + Union, + Optional, + MutableMapping, + Mapping, +) + +from ray.llm._internal.batch.stages.base import ( + StatefulStage, + StatefulStageUDF, +) + +# TODO: Remove the guard once Pillow is added into the dependencies. +if TYPE_CHECKING: + from PIL import Image + +logger = logging.getLogger(__name__) + +_ImageType = Union[str, "Image.Image"] + + +class HTTPConnection: + """Adapted from vllm.connections.HTTPConnection. + Helper class to send HTTP requests. + """ + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession() + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError( + "Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'." 
+ ) + + def _headers(self, **extras: str) -> MutableMapping[str, str]: + return {"User-Agent": "RayLLM-Batch", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers or {} + + return client.get( + url, headers=self._headers(**extra_headers), stream=stream, timeout=timeout + ) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, headers=self._headers(**extra_headers), timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +class ImageProcessor: + """Download and load images.""" + + def __init__(self): + self.Image = importlib.import_module("PIL.Image") + self.http_connection = HTTPConnection() + + async def download_image_from_url(self, image_url: str) -> Optional[bytes]: + """Download the image from the Internet with up to 3 retries. + + Args: + image_url: The image URL to download. + + Returns: + The image bytes (None if failed to download). + """ + for _ in range(3): + try: + image_raw = await self.http_connection.async_get_bytes( + image_url, timeout=5 + ) + return image_raw + except Exception: + await asyncio.sleep(1) + return None + + async def load_image_bytes_from_url(self, image_urls: List[str]) -> List[bytes]: + """Load an image from a URL. + + Args: + image_urls: The image URLs to load. 
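+                All URLs are fetched concurrently; an entry is None if the
+                download still fails after the retries.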
+ + Returns: + The image bytes. + """ + return await asyncio.gather( + *[self.download_image_from_url(image_url) for image_url in image_urls] + ) + + async def fetch_images( + self, image_urls: List[str], *, image_mode: Optional[str] = None + ) -> List["Image.Image"]: + """ + Adapted from vllm.multimodal.utils.fetch_image. + Load a PIL image from a HTTP or base64 data URL. + + Args: + image_urls: A list of URLs of the images. + image_mode: The mode of the image. If None, the image is not converted. + + Returns: + A list of loaded images. + """ + + def _load_image_from_bytes(b: bytes): + image = self.Image.open(BytesIO(b)) + image.load() + return image + + def _load_image_from_data_url(image_url: str): + # Only split once and assume the second part is the base64 encoded image + _, image_base64 = image_url.split(",", 1) + return _load_image_from_bytes(base64.b64decode(image_base64)) + + # Check if all image URLs are of the same type. + if image_urls[0].startswith("http"): + image_url_prefix = "http" + elif image_urls[0].startswith("data:image"): + image_url_prefix = "data:image" + else: + raise ValueError(f"Invalid image URL prefix: {image_urls[0]}") + + if not all(url.startswith(image_url_prefix) for url in image_urls): + raise ValueError( + f"All image URLs must have the same prefix, got {image_url_prefix=}" + ) + + if image_url_prefix == "http": + image_raws = await self.load_image_bytes_from_url(image_urls) + images = [_load_image_from_bytes(image_raw) for image_raw in image_raws] + elif image_url_prefix == "data:image": + images = [_load_image_from_data_url(image_url) for image_url in image_urls] + else: + raise ValueError( + "Invalid 'image_url': A valid 'image_url' must start " + "with either 'data:image' or 'http'." + ) + + if image_mode is not None and images[0].mode != image_mode: + images = [image.convert(image_mode) for image in images] + return images + + async def process(self, images: List[_ImageType]) -> List["Image.Image"]: + """Load and resize an image for the model. + Args: + image: A list of images. + + Returns: + A list of processed images. + """ + if not images: + return [] + + # Check if all images are of the same type. + image_type = type(images[0]) + if not all(isinstance(img, image_type) for img in images): + raise ValueError(f"All images must be of the same type, got {image_type=}") + + if not issubclass(image_type, self.Image.Image): + images = await self.fetch_images(images) + return images + + +class PrepareImageUDF(StatefulStageUDF): + def __init__(self, data_column: str): + super().__init__(data_column) + self.Image = importlib.import_module("PIL.Image") + self.image_processor = ImageProcessor() + + def extract_image_info(self, messages: List[Dict]) -> List[_ImageType]: + """Extract vision information such as image and video from chat messages. + + Args: + messages: List of chat messages. + + Returns: + List of _ImageType. + """ + + image_info: List[_ImageType] = [] + for message in messages: + if not isinstance(message["content"], list): + continue + for content in message["content"]: + if content["type"] not in ("image", "image_url"): + continue + image = content[content["type"]] + if not isinstance(image, str) and not isinstance( + image, self.Image.Image + ): + raise ValueError(f"Cannot handle image type {type(image)}") + image_info.append(image) + return image_info + + async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]: + messages = [row["messages"] for row in batch] + + # Process all images in this batch. 
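+        # Flatten the per-request image lists so that every image in the
+        # batch is fetched/decoded in one concurrent pass, then slice the
+        # processed images back out per request using the original counts.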
+ all_image_info = [self.extract_image_info(message) for message in messages] + flat_all_image_info = [img for imgs in all_image_info for img in imgs] + flat_all_images = await self.image_processor.process(flat_all_image_info) + + idx = 0 + for image_info_per_req in all_image_info: + num_images_in_req = len(image_info_per_req) + if num_images_in_req == 0: + yield {} + else: + images = flat_all_images[idx : idx + num_images_in_req] + yield { + "image": images, + "image_sizes": [(img.width, img.height) for img in images], + } + idx += num_images_in_req + + @property + def expected_input_keys(self) -> List[str]: + """The expected input keys.""" + return ["messages"] + + +class PrepareImageStage(StatefulStage): + """A stage to prepare images from OpenAI chat template messages.""" + + fn: StatefulStageUDF = PrepareImageUDF + fn_constructor_kwargs: Dict[str, Any] + map_batches_kwargs: Dict[str, Any] = dict( + concurrency=1, + ) diff --git a/python/ray/llm/_internal/batch/stages/tokenize_stage.py b/python/ray/llm/_internal/batch/stages/tokenize_stage.py new file mode 100644 index 0000000000000..68e7bc53e3c0c --- /dev/null +++ b/python/ray/llm/_internal/batch/stages/tokenize_stage.py @@ -0,0 +1,116 @@ +"""Tokenize and detokenize stage""" + +from typing import Any, Dict, AsyncIterator, List + +from ray.llm._internal.batch.stages.base import ( + StatefulStage, + StatefulStageUDF, +) +from ray.llm._internal.batch.utils import get_cached_tokenizer + + +class TokenizeUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + model: str, + ): + """ + Initialize the TokenizeUDF. + + Args: + data_column: The data column name. + model: The model to use for the chat template. + """ + from transformers import AutoTokenizer + + super().__init__(data_column) + self.tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained(model)) + + async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]: + """ + Tokenize the given batch. + + Args: + batch: A list of rows to send. + + Yields: + A generator of rows with the tokenized prompt. + """ + for row, prompt_token_ids in zip( + batch, + self.tokenizer([row["prompt"] for row in batch])["input_ids"], + ): + yield {"tokenized_prompt": prompt_token_ids} + + @property + def expected_input_keys(self) -> List[str]: + """The expected input keys.""" + return ["prompt"] + + +class TokenizeStage(StatefulStage): + """ + A stage that tokenizes the input. + """ + + fn: StatefulStageUDF = TokenizeUDF + fn_constructor_kwargs: Dict[str, Any] + map_batches_kwargs: Dict[str, Any] = dict( + concurrency=1, + ) + + +class DetokenizeUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + model: str, + ): + """ + Initialize the DetokenizeUDF. + + Args: + data_column: The data column name. + model: The model to use for the chat template. + """ + from transformers import AutoTokenizer + + super().__init__(data_column) + self.tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained(model)) + + async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]: + """ + Detokenize the given batch. + + Args: + batch: A list of rows to send. + + Yields: + A generator of rows with the detokenized prompt. 
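+
+        Example (illustrative):
+            Input row:  {"generated_tokens": [1, 2, 3]}
+            Output row: {"generated_text": "<decoded text>"}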
+ """ + for row, generated_text in zip( + batch, + self.tokenizer.batch_decode( + [row["generated_tokens"] for row in batch], + skip_special_tokens=True, + ), + ): + yield {"generated_text": generated_text} + + @property + def expected_input_keys(self) -> List[str]: + """The expected input keys.""" + return ["generated_tokens"] + + +class DetokenizeStage(StatefulStage): + """ + A stage that detokenizes the input. + """ + + fn: StatefulStageUDF = DetokenizeUDF + fn_constructor_kwargs: Dict[str, Any] + map_batches_kwargs: Dict[str, Any] = dict( + concurrency=1, + ) diff --git a/python/ray/llm/_internal/batch/utils.py b/python/ray/llm/_internal/batch/utils.py new file mode 100644 index 0000000000000..5cca5d3dd880c --- /dev/null +++ b/python/ray/llm/_internal/batch/utils.py @@ -0,0 +1,62 @@ +"""Utility functions for batch processing.""" +import logging +from typing import TYPE_CHECKING, Any, Union + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +AnyTokenizer = Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", Any] + +logger = logging.getLogger(__name__) + + +def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: + """Get tokenizer with cached properties. + This will patch the tokenizer object in place. + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access. + Args: + tokenizer: The tokenizer object. + Returns: + The patched tokenizer object. + """ + chat_template = getattr(tokenizer, "chat_template", None) + # For VLM, the text tokenizer is wrapped by a processor. + if hasattr(tokenizer, "tokenizer"): + tokenizer = tokenizer.tokenizer + # Some VLM's tokenizer has chat_template attribute (e.g. Qwen/Qwen2-VL-7B-Instruct), + # however some other VLM's tokenizer does not have chat_template attribute (e.g. + # mistral-community/pixtral-12b). Therefore, we cache the processor's chat_template. 
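+    # Prefer the processor-level chat_template captured above; fall back to
+    # the inner text tokenizer's chat_template only if the processor does
+    # not define one.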
+ if chat_template is None: + chat_template = getattr(tokenizer, "chat_template", None) + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + @property + def chat_template(self): + return chat_template + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer diff --git a/python/ray/llm/tests/batch/processor/BUILD b/python/ray/llm/tests/batch/processor/BUILD new file mode 100644 index 0000000000000..3777a36bc9b45 --- /dev/null +++ b/python/ray/llm/tests/batch/processor/BUILD @@ -0,0 +1,8 @@ +load("//bazel:python.bzl", "py_test_module_list") + +py_test_module_list( + files = glob(["test_*.py"]), + size = "small", + tags = ["exclusive", "team:llm"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/llm/tests/batch/processor/test_base.py b/python/ray/llm/tests/batch/processor/test_base.py new file mode 100644 index 0000000000000..b186117ac551d --- /dev/null +++ b/python/ray/llm/tests/batch/processor/test_base.py @@ -0,0 +1,168 @@ +import sys +from typing import Any, AsyncIterator, Dict, List + +import pytest + +import ray +from ray.llm._internal.batch.processor.base import ( + Processor, + ProcessorConfig, + ProcessorBuilder, +) +from ray.llm._internal.batch.stages.base import StatefulStage, StatefulStageUDF + + +def test_empty_processor(): + """Test processor with only preprocess and postprocess.""" + + processor = Processor( + config=ProcessorConfig( + batch_size=64, + accelerator_type=None, + concurrency=1, + ), + stages=[], + # {id} -> {__data: {id, val}} + preprocess=lambda row: {"val": row["id"] + 5}, + # {__data: {id, val}} -> {id, result} + postprocess=lambda row: {"result": row["val"], "id": row["id"]}, + ) + + ds = ray.data.range(5) + ds = processor(ds).take_all() + for row in ds: + assert "val" not in row + assert "id" in row + assert "result" in row + + +@pytest.mark.parametrize("has_extra", [True, False]) +def test_processor_with_stages(has_extra: bool): + """Test processor with multiple stages.""" + + class DummyStatefulStageUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + factor: int, + ): + super().__init__(data_column) + self.factor = factor + + async def udf( + self, batch: List[Dict[str, Any]] + ) -> AsyncIterator[Dict[str, Any]]: + for row in batch: + answer = row["val"] * self.factor + if "extra" in row: # Optional input column. + answer += row["extra"] + yield { + # Use the same name to chain multiple dummy stages. 
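+                        # Stage N+1 then reads the "val" written by stage N.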
+ "val": answer, + } + + @property + def expected_input_keys(self) -> List[str]: + return ["val"] + + class DummyStage(StatefulStage): + fn: StatefulStageUDF = DummyStatefulStageUDF + fn_constructor_kwargs: Dict[str, Any] = {} + map_batches_kwargs: Dict[str, Any] = dict(concurrency=1) + + stages = [ + DummyStage(fn_constructor_kwargs=dict(factor=2)), + DummyStage(fn_constructor_kwargs=dict(factor=3)), + ] + + processor = Processor( + config=ProcessorConfig( + accelerator_type=None, + concurrency=1, + batch_size=64, + ), + stages=stages, + preprocess=lambda row: {"val": row["id"]}, + postprocess=lambda row: {"result": row["val"], "id": row["id"]}, + ) + + # Check the stage names. + stage_names = processor.list_stage_names() + assert stage_names == [ + "DummyStage", + "DummyStage_1", + ] + + # Check the stages. + for stage_name, stage in zip(stage_names, stages): + assert processor.get_stage_by_name(stage_name) == stage + + ds = ray.data.range(5) + ds = ds.map( + lambda row: { + "id": row["id"], + **({"extra": 1} if has_extra else {}), + } + ) + + ds = processor(ds).take_all() + extra = 1 if has_extra else 0 + for row in ds: + assert "id" in row + assert "result" in row + + # The final output should be the result of the last stage. + assert row["result"] == (row["id"] * 2 + extra) * 3 + extra + + +def test_builder(): + class DummyStatefulStageUDF(StatefulStageUDF): + async def udf( + self, batch: List[Dict[str, Any]] + ) -> AsyncIterator[Dict[str, Any]]: + for row in batch: + yield row + + class DummyStage(StatefulStage): + fn: StatefulStageUDF = DummyStatefulStageUDF + fn_constructor_kwargs: Dict[str, Any] = {} + map_batches_kwargs: Dict[str, Any] = {} + + class TestBuilderDummyProcessorConfig(ProcessorConfig): + pass + + def build_processor(config: ProcessorConfig) -> Processor: + stages = [ + DummyStage( + fn_constructor_kwargs=dict(), + map_batches_kwargs=dict(concurrency=1), + ) + ] + processor = Processor(config, stages) + return processor + + ProcessorBuilder.register(TestBuilderDummyProcessorConfig, build_processor) + + processor = ProcessorBuilder.build(TestBuilderDummyProcessorConfig(batch_size=64)) + assert isinstance(processor.config, TestBuilderDummyProcessorConfig) + assert processor.list_stage_names() == ["DummyStage"] + assert ( + processor.get_stage_by_name("DummyStage").map_batches_kwargs["concurrency"] == 1 + ) + + def overrider(name: str, stage: StatefulStage): + if name.startswith("DummyStage"): + stage.map_batches_kwargs["concurrency"] = 2 + + processor = ProcessorBuilder.build( + TestBuilderDummyProcessorConfig(batch_size=64), + override_stage_config_fn=overrider, + ) + assert processor.list_stage_names() == ["DummyStage"] + assert ( + processor.get_stage_by_name("DummyStage").map_batches_kwargs["concurrency"] == 2 + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/batch/processor/test_http_request_proc.py b/python/ray/llm/tests/batch/processor/test_http_request_proc.py new file mode 100644 index 0000000000000..96e27b094271e --- /dev/null +++ b/python/ray/llm/tests/batch/processor/test_http_request_proc.py @@ -0,0 +1,23 @@ +from ray.llm._internal.batch.processor import ProcessorBuilder +from ray.llm._internal.batch.processor.http_request_proc import ( + HttpRequestProcessorConfig, +) + + +def test_http_request_processor(): + config = HttpRequestProcessorConfig( + url="http://localhost:8000", + headers={"Authorization": "Bearer 1234567890"}, + qps=2, + concurrency=4, + batch_size=64, + ) + processor = 
ProcessorBuilder.build(config) + assert processor.list_stage_names() == ["HttpRequestStage"] + stage = processor.get_stage_by_name("HttpRequestStage") + assert stage.map_batches_kwargs["concurrency"] == 4 + assert stage.fn_constructor_kwargs["url"] == "http://localhost:8000" + assert stage.fn_constructor_kwargs["additional_header"] == { + "Authorization": "Bearer 1234567890" + } + assert stage.fn_constructor_kwargs["qps"] == 2 diff --git a/python/ray/llm/tests/batch/stages/BUILD b/python/ray/llm/tests/batch/stages/BUILD new file mode 100644 index 0000000000000..3777a36bc9b45 --- /dev/null +++ b/python/ray/llm/tests/batch/stages/BUILD @@ -0,0 +1,8 @@ +load("//bazel:python.bzl", "py_test_module_list") + +py_test_module_list( + files = glob(["test_*.py"]), + size = "small", + tags = ["exclusive", "team:llm"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/llm/tests/batch/stages/test_base.py b/python/ray/llm/tests/batch/stages/test_base.py new file mode 100644 index 0000000000000..c99754a23b598 --- /dev/null +++ b/python/ray/llm/tests/batch/stages/test_base.py @@ -0,0 +1,109 @@ +import sys +import pytest +from typing import List, Any, AsyncIterator, Dict +from ray.llm._internal.batch.stages.base import ( + wrap_preprocess, + wrap_postprocess, + StatefulStage, + StatefulStageUDF, +) + + +def test_wrap_preprocess(): + # Test function that doubles a number + def double(x: dict) -> dict: + return {"value": x["id"] * 2} + + # Test with carry_over=True + wrapped = wrap_preprocess(double, "__data") + result = wrapped({"id": 5, "extra": "memo"}) + assert result == {"__data": {"id": 5, "extra": "memo", "value": 10}} + + +def test_wrap_postprocess(): + # Test function that converts number to string + def to_string(x: dict) -> dict: + return { + "result": str(x["value"]), + "extra": x["extra"], + } + + # Test with carry_over=True + wrapped = wrap_postprocess(to_string, "__data") + result = wrapped({"__data": {"id": 5, "extra": "memo", "value": 10}}) + assert result == {"extra": "memo", "result": "10"} + + # Test missing input column + with pytest.raises(ValueError): + wrapped({"wrong_key": 42}) + + +class TestStatefulStageUDF: + class SimpleUDF(StatefulStageUDF): + async def udf( + self, rows: list[Dict[str, Any]] + ) -> AsyncIterator[Dict[str, Any]]: + for row in rows: + yield {"processed": row["value"] * 2} + + @property + def expected_input_keys(self) -> List[str]: + return ["value"] + + @pytest.mark.asyncio + async def test_basic_processing(self): + udf = self.SimpleUDF(data_column="__data") + + batch = { + "__data": [{"value": 1, "extra": "a"}, {"value": 2, "extra": "b"}], + } + + results = [] + async for result in udf(batch): + results.append(result) + + assert len(results) == 2 + assert results[0] == { + "__data": [{"processed": 2, "value": 1, "extra": "a"}], + } + assert results[1] == { + "__data": [{"processed": 4, "value": 2, "extra": "b"}], + } + + @pytest.mark.asyncio + async def test_missing_data_column(self): + udf = self.SimpleUDF(data_column="__data") + + batch = {"extra": ["a"]} + + with pytest.raises(ValueError): + async for _ in udf(batch): + pass + + @pytest.mark.asyncio + async def test_missing_required_key(self): + udf = self.SimpleUDF(data_column="__data") + + batch = {"__data": [{"wrong_key": 1}]} + + with pytest.raises(ValueError): + async for _ in udf(batch): + pass + + +def test_stateful_stage(): + udf = TestStatefulStageUDF.SimpleUDF(data_column="__data") + + stage = StatefulStage( + fn=udf, + fn_constructor_kwargs={"data_column": "__data"}, + 
map_batches_kwargs={"batch_size": 10}, + ) + + assert stage.fn == udf + assert stage.fn_constructor_kwargs == {"data_column": "__data"} + assert stage.map_batches_kwargs == {"batch_size": 10} + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py new file mode 100644 index 0000000000000..4d822d6fcd439 --- /dev/null +++ b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py @@ -0,0 +1,99 @@ +import pytest +from unittest.mock import patch, MagicMock +from ray.llm._internal.batch.stages.chat_template_stage import ChatTemplateUDF + + +@pytest.fixture +def mock_tokenizer_setup(): + with patch( + "ray.llm._internal.batch.stages.chat_template_stage.get_cached_tokenizer" + ) as mock_get_tokenizer, patch("transformers.AutoTokenizer") as mock_auto_tokenizer: + mock_tokenizer = MagicMock() + mock_get_tokenizer.return_value = mock_tokenizer + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + yield mock_tokenizer + + +@pytest.mark.asyncio +async def test_chat_template_udf_basic(mock_tokenizer_setup): + mock_tokenizer = mock_tokenizer_setup + mock_tokenizer.apply_chat_template.return_value = ["Hello AI"] + + udf = ChatTemplateUDF( + data_column="__data", + model="test-model", + ) + + batch = [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + } + ] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 1 + assert results[0] == {"prompt": "Hello AI"} + mock_tokenizer.apply_chat_template.assert_called_once_with( + [batch[0]["messages"].tolist()], + tokenize=False, + add_generation_prompt=True, + ) + + +@pytest.mark.asyncio +async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup): + mock_tokenizer = mock_tokenizer_setup + mock_tokenizer.apply_chat_template.return_value = [ + "Hello AI", + "How are you?", + ] + + udf = ChatTemplateUDF( + data_column="__data", + model="test-model", + ) + + batch = [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + }, + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "How are you?"}] + ) + }, + ] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 2 + assert results[0] == {"prompt": "Hello AI"} + assert results[1] == {"prompt": "How are you?"} + mock_tokenizer.apply_chat_template.assert_called_once_with( + [msg["messages"].tolist() for msg in batch], + tokenize=False, + add_generation_prompt=True, + ) + + +def test_chat_template_udf_expected_input_keys(mock_tokenizer_setup): + mock_tokenizer = mock_tokenizer_setup + mock_tokenizer.apply_chat_template.return_value = [ + "Hello AI", + "How are you?", + ] + + udf = ChatTemplateUDF( + data_column="__data", + model="test-model", + ) + assert udf.expected_input_keys == ["messages"] diff --git a/python/ray/llm/tests/batch/stages/test_http_request_stage.py b/python/ray/llm/tests/batch/stages/test_http_request_stage.py new file mode 100644 index 0000000000000..636979f893ac6 --- /dev/null +++ b/python/ray/llm/tests/batch/stages/test_http_request_stage.py @@ -0,0 +1,87 @@ +import pytest +from unittest.mock import AsyncMock, patch +from ray.llm._internal.batch.stages.http_request_stage import HttpRequestUDF + + +@pytest.fixture +def mock_response(): + async def mock_json(): + return {"response": "test"} + + mock = 
AsyncMock() + mock.json = mock_json + return mock + + +@pytest.fixture +def mock_session(): + async def mock_post(*args, **kwargs): + return mock_response() + + mock = AsyncMock() + mock.post = AsyncMock(side_effect=mock_post) + return mock + + +@pytest.mark.asyncio +async def test_http_request_udf_basic(): + udf = HttpRequestUDF( + data_column="__data", + url="http://test.com/api", + additional_header={"Authorization": "Bearer 1234567890"}, + qps=None, + ) + + batch = [{"text": "hello", "metadata": "test"}] + + with patch("aiohttp.ClientSession") as mock_session_cls: + session = AsyncMock() + session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"response": "test"} + ) + mock_session_cls.return_value.__aenter__.return_value = session + + async for result in udf.udf(batch): + assert result == {"response": "test"} + + session.post.assert_called_once_with( + "http://test.com/api", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer 1234567890", + }, + json={"text": "hello", "metadata": "test"}, + ) + + +@pytest.mark.asyncio +async def test_http_request_udf_with_qps(): + udf = HttpRequestUDF( + data_column="__data", + url="http://test.com/api", + qps=2, + ) + + batch = [{"text": "hello1"}, {"text": "hello2"}] + + with patch("aiohttp.ClientSession") as mock_session_cls, patch( + "time.time" + ) as mock_time, patch("asyncio.sleep") as mock_sleep: + + session = AsyncMock() + session.post.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"response": "test"} + ) + mock_session_cls.return_value.__aenter__.return_value = session + + # Mock time to test QPS limiting. Req2 cannot be sent until 0.5s, + # so the asyncio.sleep should be called once. + # [start_time, req1_time, req2_time] + mock_time.side_effect = [0, 0.1, 0.2] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 2 + assert mock_sleep.called # Should have called sleep for QPS limiting diff --git a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py new file mode 100644 index 0000000000000..d91885c99ae0f --- /dev/null +++ b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py @@ -0,0 +1,156 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from PIL import Image +import io +import base64 + +from ray.llm._internal.batch.stages.prepare_image_stage import ( + PrepareImageUDF, + ImageProcessor, +) + + +@pytest.fixture +def mock_image(): + # Create a small test image + img = Image.new("RGB", (100, 100), color="red") + return img + + +@pytest.fixture +def mock_http_connection(): + with patch( + "ray.llm._internal.batch.stages.prepare_image_stage.HTTPConnection" + ) as mock: + connection = MagicMock() + connection.async_get_bytes = AsyncMock() + mock.return_value = connection + yield connection + + +@pytest.fixture +def mock_image_processor(mock_http_connection, mock_image): + with patch( + "ray.llm._internal.batch.stages.prepare_image_stage.ImageProcessor" + ) as mock: + processor = MagicMock() + processor.process = AsyncMock( + side_effect=lambda images: [mock_image] * len(images) + ) + mock.return_value = processor + yield processor + + +@pytest.mark.asyncio +async def test_prepare_image_udf_basic(mock_image_processor, mock_image): + udf = PrepareImageUDF(data_column="__data") + + # Test batch with one message containing an image URL + batch = [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": 
"http://example.com/image.jpg"} + ] + } + ] + } + ] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 1 + assert "image" in results[0] + assert "image_sizes" in results[0] + assert len(results[0]["image"]) == 1 + assert all(isinstance(img, Image.Image) for img in results[0]["image"]) + + +@pytest.mark.asyncio +async def test_prepare_image_udf_multiple_images(mock_image_processor, mock_image): + udf = PrepareImageUDF(data_column="__data") + + # Test batch with multiple images in one message + batch = [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": "http://example.com/image1.jpg"}, + {"type": "image", "image": "http://example.com/image2.jpg"}, + ] + } + ] + } + ] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 1 + assert len(results[0]["image"]) == 2 + assert len(results[0]["image_sizes"]) == 2 + + +@pytest.mark.asyncio +async def test_prepare_image_udf_no_images(mock_image_processor): + udf = PrepareImageUDF(data_column="__data") + + # Test batch with no images + batch = [{"messages": [{"content": "Hello, world!"}]}] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 1 + assert results[0] == {} + + +@pytest.mark.asyncio +async def test_image_processor_fetch_images(mock_http_connection, mock_image): + processor = ImageProcessor() + + # Create a base64 image + img_byte_arr = io.BytesIO() + mock_image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + base64_image = f"data:image/png;base64,{base64.b64encode(img_byte_arr).decode()}" + + # Test HTTP image + mock_http_connection.async_get_bytes.return_value = img_byte_arr + http_images = await processor.fetch_images(["http://example.com/image.jpg"]) + assert len(http_images) == 1 + assert isinstance(http_images[0], Image.Image) + + # Test base64 image + base64_images = await processor.fetch_images([base64_image]) + assert len(base64_images) == 1 + assert isinstance(base64_images[0], Image.Image) + + +def test_prepare_image_udf_expected_keys(): + udf = PrepareImageUDF(data_column="__data") + assert udf.expected_input_keys == ["messages"] + + +@pytest.mark.asyncio +async def test_prepare_image_udf_invalid_image_type(mock_image_processor): + udf = PrepareImageUDF(data_column="__data") + + # Test batch with invalid image type + batch = [ + { + "messages": [ + {"content": [{"type": "image", "image": 123}]} # Invalid image type + ] + } + ] + + with pytest.raises(ValueError, match="Cannot handle image type"): + async for _ in udf.udf(batch): + pass diff --git a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py new file mode 100644 index 0000000000000..8e940e10ebc0e --- /dev/null +++ b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py @@ -0,0 +1,73 @@ +import pytest +from unittest.mock import patch, MagicMock +from ray.llm._internal.batch.stages.tokenize_stage import TokenizeUDF, DetokenizeUDF + + +@pytest.fixture +def mock_tokenizer_setup(): + with patch( + "ray.llm._internal.batch.stages.tokenize_stage.get_cached_tokenizer" + ) as mock_get_tokenizer, patch("transformers.AutoTokenizer") as mock_auto_tokenizer: + mock_tokenizer = MagicMock() + mock_tokenizer.side_effect = lambda texts: { + "input_ids": [[1, 2, 3] for _ in texts] + } + mock_get_tokenizer.return_value = mock_tokenizer + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + yield 
mock_tokenizer + + +@pytest.mark.asyncio +async def test_tokenize_udf_basic(mock_tokenizer_setup): + mock_tokenizer = mock_tokenizer_setup + mock_tokenizer.return_value = [ + {"input_ids": [1, 2, 3]}, + {"input_ids": [4, 5, 6]}, + ] + + udf = TokenizeUDF(data_column="__data", model="test-model") + batch = [{"prompt": "Hello"}, {"prompt": "World"}] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 2 + assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results) + assert all( + original["prompt"] == result["prompt"] + for original, result in zip(batch, results) + ) + + +@pytest.mark.asyncio +async def test_detokenize_udf_basic(mock_tokenizer_setup): + mock_tokenizer = mock_tokenizer_setup + mock_tokenizer.batch_decode.return_value = ["Hello", "World"] + + udf = DetokenizeUDF(data_column="__data", model="test-model") + batch = [ + {"generated_tokens": [1, 2, 3]}, + {"generated_tokens": [4, 5, 6]}, + ] + + results = [] + async for result in udf.udf(batch): + results.append(result) + + assert len(results) == 2 + assert results[0]["generated_text"] == "Hello" + assert results[1]["generated_text"] == "World" + mock_tokenizer.batch_decode.assert_called_once_with( + [[1, 2, 3], [4, 5, 6]], skip_special_tokens=True + ) + + +def test_tokenize_udf_expected_keys(mock_tokenizer_setup): + udf = TokenizeUDF(data_column="__data", model="test-model") + assert udf.expected_input_keys == ["prompt"] + + +def test_detokenize_udf_expected_keys(mock_tokenizer_setup): + udf = DetokenizeUDF(data_column="__data", model="test-model") + assert udf.expected_input_keys == ["generated_tokens"] diff --git a/python/ray/serve/_private/api.py b/python/ray/serve/_private/api.py index 954d71cac10d4..5b9d9f749a5b4 100644 --- a/python/ray/serve/_private/api.py +++ b/python/ray/serve/_private/api.py @@ -105,7 +105,7 @@ def _start_controller( max_concurrency=CONTROLLER_MAX_CONCURRENCY, enable_task_events=RAY_SERVE_ENABLE_TASK_EVENTS, ).remote( - http_config=http_options, + http_options=http_options, grpc_options=grpc_options, global_logging_config=global_logging_config, ) diff --git a/python/ray/serve/_private/benchmarks/proxy_benchmark.py b/python/ray/serve/_private/benchmarks/proxy_benchmark.py index 347c06854a5c0..8f3eca322f29c 100644 --- a/python/ray/serve/_private/benchmarks/proxy_benchmark.py +++ b/python/ray/serve/_private/benchmarks/proxy_benchmark.py @@ -75,7 +75,7 @@ async def fetch_http(session, data): async def fetch_grpc(stub, data): result = await stub.grpc_call(serve_pb2.RawData(nums=data)) - result.output + _ = result.output @ray.remote diff --git a/python/ray/serve/_private/controller.py b/python/ray/serve/_private/controller.py index 4aa6906b241fe..335f781166ed7 100644 --- a/python/ray/serve/_private/controller.py +++ b/python/ray/serve/_private/controller.py @@ -109,7 +109,7 @@ class ServeController: async def __init__( self, *, - http_config: HTTPOptions, + http_options: HTTPOptions, global_logging_config: LoggingConfig, grpc_options: Optional[gRPCOptions] = None, ): @@ -153,7 +153,7 @@ async def __init__( self.cluster_node_info_cache.update() self.proxy_state_manager = ProxyStateManager( - config=http_config, + http_options=http_options, head_node_id=self._controller_node_id, cluster_node_info_cache=self.cluster_node_info_cache, logging_config=self.global_logging_config, @@ -1154,9 +1154,6 @@ def __init__( except ValueError: self._controller = None if self._controller is None: - http_config = HTTPOptions() - logging_config = LoggingConfig() 
- http_config.port = http_proxy_port self._controller = ServeController.options( num_cpus=0, name=SERVE_CONTROLLER_NAME, @@ -1168,8 +1165,8 @@ def __init__( max_concurrency=CONTROLLER_MAX_CONCURRENCY, enable_task_events=RAY_SERVE_ENABLE_TASK_EVENTS, ).remote( - http_config=http_config, - global_logging_config=logging_config, + http_options=HTTPOptions(port=http_proxy_port), + global_logging_config=LoggingConfig(), ) def check_alive(self) -> None: diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 9a4b354adb3c8..791f515c2372f 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -2590,6 +2590,7 @@ def get_deployment_details(self, id: DeploymentID) -> Optional[DeploymentDetails id.name, self.get_deployment(id) ), target_num_replicas=deployment_state._target_state.target_num_replicas, + required_resources=deployment_state.target_info.replica_config.resource_dict, replicas=deployment_state.list_replica_details(), ) diff --git a/python/ray/serve/_private/grpc_util.py b/python/ray/serve/_private/grpc_util.py index 97d77235d7d54..235e10297bf0d 100644 --- a/python/ray/serve/_private/grpc_util.py +++ b/python/ray/serve/_private/grpc_util.py @@ -1,8 +1,12 @@ +import asyncio from typing import Callable, List, Optional, Sequence, Tuple +from unittest.mock import Mock import grpc from grpc.aio._server import Server +from ray.serve.config import gRPCOptions +from ray.serve.generated.serve_pb2_grpc import add_RayServeAPIServiceServicer_to_server from ray.serve._private.constants import DEFAULT_GRPC_SERVER_OPTIONS @@ -17,7 +21,7 @@ def __init__( self, service_handler_factory: Callable, *, - extra_options: Optional[List[Tuple[str, str]]] = None + extra_options: Optional[List[Tuple[str, str]]] = None, ): super().__init__( thread_pool=None, @@ -63,14 +67,33 @@ def add_generic_rpc_handlers( super().add_generic_rpc_handlers(generic_rpc_handlers) -class DummyServicer: - """Dummy servicer for gRPC server to call on. +async def start_grpc_server( + service_handler_factory: Callable, + grpc_options: gRPCOptions, + *, + event_loop: asyncio.AbstractEventLoop, + enable_so_reuseport: bool = False, +) -> asyncio.Task: + """Start a gRPC server that handles requests with the service handler factory. - This is a dummy class that just pass through when calling on any method. - User defined servicer function will attempt to add the method on this class to the - gRPC server, but our gRPC server will override the caller to call gRPCProxy. + Returns a task that blocks until the server exits (e.g., due to error). """ + from ray.serve._private.default_impl import add_grpc_address - def __getattr__(self, attr): - # No-op pass through. Just need this to act as the callable. - pass + server = gRPCGenericServer( + service_handler_factory, + extra_options=[("grpc.so_reuseport", str(int(enable_so_reuseport)))], + ) + add_grpc_address(server, f"[::]:{grpc_options.port}") + + # Add built-in gRPC service and user-defined services to the server. + # We pass a mock servicer because the actual implementation will be overwritten + # in the gRPCGenericServer implementation. 
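+    # The built-in RayServeAPIService servicer is registered alongside any
+    # user-provided servicer functions so their method routes are known to
+    # the server; the generic handler then dispatches all of them.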
+ mock_servicer = Mock() + for servicer_fn in [ + add_RayServeAPIServiceServicer_to_server + ] + grpc_options.grpc_servicer_func_callable: + servicer_fn(mock_servicer, server) + + await server.start() + return event_loop.create_task(server.wait_for_termination()) diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 0fd4a1ac84332..0cb430f6ee213 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -6,18 +6,23 @@ import socket from collections import deque from dataclasses import dataclass +from packaging import version from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type import starlette +import uvicorn from fastapi.encoders import jsonable_encoder +from starlette.datastructures import MutableHeaders +from starlette.middleware import Middleware from starlette.types import ASGIApp, Message, Receive, Scope, Send from uvicorn.config import Config from uvicorn.lifespan.on import LifespanOn from ray._private.pydantic_compat import IS_PYDANTIC_2 +from ray.serve.config import HTTPOptions from ray.serve._private.common import RequestMetadata from ray.serve._private.constants import SERVE_LOGGER_NAME -from ray.serve._private.utils import serve_encoders +from ray.serve._private.utils import serve_encoders, generate_request_id from ray.serve.exceptions import RayServeException logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -513,7 +518,7 @@ async def __del__(self): def validate_http_proxy_callback_return( middlewares: Any, -) -> [starlette.middleware.Middleware]: +) -> [Middleware]: """Validate the return value of HTTP proxy callback. Middlewares should be a list of Starlette middlewares. If it is None, we @@ -532,9 +537,102 @@ def validate_http_proxy_callback_return( # All middlewares must be Starlette middlewares. # https://www.starlette.io/middleware/#using-pure-asgi-middleware for middleware in middlewares: - if not issubclass(type(middleware), starlette.middleware.Middleware): + if not issubclass(type(middleware), Middleware): raise ValueError( "HTTP proxy callback must return a list of Starlette middlewares, " f"instead got {type(middleware)} type item in the list." ) return middlewares + + +class RequestIdMiddleware: + def __init__(self, app: ASGIApp): + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + headers = MutableHeaders(scope=scope) + if "x-request-id" not in headers: + request_id = generate_request_id() + headers.append("x-request-id", request_id) + elif "x-request-id" in headers: + request_id = headers["x-request-id"] + + async def send_with_request_id(message: Message): + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + headers.append("X-Request-ID", request_id) + if message["type"] == "websocket.accept": + message["X-Request-ID"] = request_id + await send(message) + + await self._app(scope, receive, send_with_request_id) + + +def _apply_middlewares(app: ASGIApp, middlewares: List[Callable]) -> ASGIApp: + """Wrap the ASGI app with the provided middlewares. + + The built-in RequestIdMiddleware will always be applied first. + """ + for middleware in [Middleware(RequestIdMiddleware)] + middlewares: + if version.parse(starlette.__version__) < version.parse("0.35.0"): + app = middleware.cls(app, **middleware.options) + else: + # In starlette >= 0.35.0, middleware.options does not exist: + # https://github.com/encode/starlette/pull/2381. 
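+            # Instead, the positional and keyword arguments captured when the
+            # Middleware wrapper was constructed are passed through directly.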
+ app = middleware.cls( + app, + *middleware.args, + **middleware.kwargs, + ) + + return app + + +async def start_asgi_http_server( + app: ASGIApp, + http_options: HTTPOptions, + *, + event_loop: asyncio.AbstractEventLoop, + enable_so_reuseport: bool = False, +) -> asyncio.Task: + """Start an HTTP server to run the ASGI app. + + Returns a task that blocks until the server exits (e.g., due to error). + """ + app = _apply_middlewares(app, http_options.middlewares) + + sock = socket.socket() + if enable_so_reuseport: + set_socket_reuse_port(sock) + + try: + sock.bind((http_options.host, http_options.port)) + except OSError as e: + raise RuntimeError( + f"Failed to bind to address '{http_options.host}:{http_options.port}'." + ) from e + + # NOTE: We have to use lower level uvicorn Config and Server + # class because we want to run the server as a coroutine. The only + # alternative is to call uvicorn.run which is blocking. + server = uvicorn.Server( + config=uvicorn.Config( + lambda: app, + factory=True, + host=http_options.host, + port=http_options.port, + root_path=http_options.root_path, + timeout_keep_alive=http_options.keep_alive_timeout_s, + loop=event_loop, + lifespan="off", + access_log=False, + log_level="warning", + ) + ) + + # NOTE(edoakes): we need to override install_signal_handlers here + # because the existing implementation fails if it isn't running in + # the main thread and uvicorn doesn't expose a way to configure it. + server.install_signal_handlers = lambda: None + + return event_loop.create_task(server.serve(sockets=[sock])) diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 2e445a36f16f3..1f1ff2f7fae95 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -4,18 +4,15 @@ import logging import os import pickle -import socket import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple +from copy import deepcopy +from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple import grpc import starlette import starlette.routing -import uvicorn from packaging import version -from starlette.datastructures import MutableHeaders -from starlette.middleware import Middleware from starlette.types import Receive import ray @@ -31,7 +28,6 @@ ) from ray.serve._private.constants import ( DEFAULT_LATENCY_BUCKET_MS, - DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, PROXY_MIN_DRAINING_PERIOD_S, RAY_SERVE_ENABLE_PROXY_GC_OPTIMIZATIONS, RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH, @@ -41,13 +37,13 @@ SERVE_MULTIPLEXED_MODEL_ID, SERVE_NAMESPACE, ) -from ray.serve._private.default_impl import add_grpc_address, get_proxy_handle -from ray.serve._private.grpc_util import DummyServicer, gRPCGenericServer +from ray.serve._private.default_impl import get_proxy_handle +from ray.serve._private.grpc_util import start_grpc_server from ray.serve._private.http_util import ( MessageQueue, convert_object_to_asgi_messages, receive_http_body, - set_socket_reuse_port, + start_asgi_http_server, validate_http_proxy_callback_return, ) from ray.serve._private.logging_utils import ( @@ -75,10 +71,9 @@ generate_request_id, get_head_node_id, ) -from ray.serve.config import gRPCOptions +from ray.serve.config import HTTPOptions, gRPCOptions from ray.serve.exceptions import BackPressureError, DeploymentUnavailableError from ray.serve.generated.serve_pb2 import HealthzResponse, ListApplicationsResponse -from ray.serve.generated.serve_pb2_grpc import 
add_RayServeAPIServiceServicer_to_server from ray.serve.handle import DeploymentHandle from ray.serve.schema import LoggingConfig from ray.util import metrics @@ -107,9 +102,6 @@ or float(os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S", 0)) or None ) -# Controls whether Ray Serve is operating in debug-mode switching off some -# of the performance optimizations to make troubleshooting easier -RAY_SERVE_DEBUG_MODE = bool(os.environ.get("RAY_SERVE_DEBUG_MODE", 0)) if os.environ.get("SERVE_REQUEST_PROCESSING_TIMEOUT_S") is not None: logger.warning( @@ -1106,68 +1098,71 @@ async def send_request_to_replica( yield status -class RequestIdMiddleware: - def __init__(self, app): - self.app = app +def _set_proxy_default_http_options(http_options: HTTPOptions) -> HTTPOptions: + http_options = deepcopy(http_options) + # Override keep alive setting if the environment variable is set. + # TODO(edoakes): more sane behavior here. + if RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S > 0: + http_options.keep_alive_timeout_s = RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S - async def __call__(self, scope, receive, send): - headers = MutableHeaders(scope=scope) - if "x-request-id" not in headers: - # If X-Request-ID is not set, we - # generate a new request ID. - request_id = generate_request_id() - headers.append("x-request-id", request_id) - elif "x-request-id" in headers: - request_id = headers["x-request-id"] + http_options.request_timeout_s = ( + http_options.request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S + ) - async def send_with_request_id(message: Dict): - if message["type"] == "http.response.start": - headers = MutableHeaders(scope=message) - headers.append("X-Request-ID", request_id) - if message["type"] == "websocket.accept": - message["X-Request-ID"] = request_id - await send(message) + http_options.middlewares = http_options.middlewares or [] + if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH: + logger.info( + "Calling user-provided callback from import path " + f"'{RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH}'." + ) + http_options.middlewares.extend( + validate_http_proxy_callback_return( + call_function_from_import_path( + RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH + ) + ) + ) - await self.app(scope, receive, send_with_request_id) + return http_options @ray.remote(num_cpus=0) class ProxyActor: def __init__( self, - host: str, - port: int, - root_path: str, - node_ip_address: str, + http_options: HTTPOptions, + *, + grpc_options: Optional[gRPCOptions] = None, node_id: NodeId, + node_ip_address: str, logging_config: LoggingConfig, - request_timeout_s: Optional[float] = None, - http_middlewares: Optional[List["starlette.middleware.Middleware"]] = None, - keep_alive_timeout_s: int = DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, - grpc_options: Optional[gRPCOptions] = None, long_poll_client: Optional[LongPollClient] = None, ): # noqa: F821 - self.grpc_options = grpc_options or gRPCOptions() - self.host = host - self.port = port - self.grpc_port = self.grpc_options.port - self.root_path = root_path - self.keep_alive_timeout_s = ( - RAY_SERVE_HTTP_KEEP_ALIVE_TIMEOUT_S or keep_alive_timeout_s - ) - self._uvicorn_server = None - self.node_ip_address = node_ip_address + self._node_id = node_id + self._node_ip_address = node_ip_address + + # Configure proxy default HTTP and gRPC options. 
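+        # These defaults fold in environment overrides (keep-alive timeout,
+        # request timeout, and middleware callback imports) before the
+        # servers are started.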
+ http_options = _set_proxy_default_http_options(http_options) + grpc_options = grpc_options or gRPCOptions() + self._http_options = http_options + self._grpc_options = grpc_options - self.http_setup_complete = asyncio.Event() - self.grpc_setup_complete = asyncio.Event() + # We modify the HTTP and gRPC options above, so delete them to avoid + del http_options, grpc_options + + grpc_enabled = ( + self._grpc_options.port > 0 + and len(self._grpc_options.grpc_servicer_functions) > 0 + ) + event_loop = get_or_create_event_loop() self.long_poll_client = long_poll_client or LongPollClient( ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE), { LongPollNamespace.GLOBAL_LOGGING_CONFIG: self._update_logging_config, LongPollNamespace.ROUTE_TABLE: self._update_routes_in_proxies, }, - call_in_event_loop=get_or_create_event_loop(), + call_in_event_loop=event_loop, ) configure_component_logger( @@ -1176,9 +1171,9 @@ def __init__( logging_config=logging_config, ) - startup_msg = f"Proxy starting on node {node_id} (HTTP port: {port}" - if self.should_start_grpc_service(): - startup_msg += f", gRPC port: {self.grpc_options.port})." + startup_msg = f"Proxy starting on node {self._node_id} (HTTP port: {self._http_options.port}" + if grpc_enabled: + startup_msg += f", gRPC port: {self._grpc_options.port})." else: startup_msg += ")." logger.info(startup_msg) @@ -1194,77 +1189,57 @@ def __init__( component_name="proxy", component_id=node_ip_address ) - if http_middlewares is None: - http_middlewares = [Middleware(RequestIdMiddleware)] - else: - http_middlewares.append(Middleware(RequestIdMiddleware)) - - if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH: - logger.info( - "Calling user-provided callback from import path " - f"'{RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH}'." - ) - middlewares = validate_http_proxy_callback_return( - call_function_from_import_path( - RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH - ) - ) - - http_middlewares.extend(middlewares) - - is_head = node_id == get_head_node_id() + is_head = self._node_id == get_head_node_id() self.proxy_router = ProxyRouter(get_proxy_handle) self.http_proxy = HTTPProxy( - node_id=node_id, - node_ip_address=node_ip_address, + node_id=self._node_id, + node_ip_address=self._node_ip_address, is_head=is_head, self_actor_name=ray.get_runtime_context().get_actor_name(), proxy_router=self.proxy_router, - request_timeout_s=( - request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S - ), + request_timeout_s=self._http_options.request_timeout_s, ) self.grpc_proxy = ( gRPCProxy( - node_id=node_id, - node_ip_address=node_ip_address, + node_id=self._node_id, + node_ip_address=self._node_ip_address, is_head=is_head, proxy_router=self.proxy_router, - request_timeout_s=( - request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S - ), + request_timeout_s=RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S, ) - if self.should_start_grpc_service() + if grpc_enabled else None ) - self.wrapped_http_proxy = self.http_proxy - - for middleware in http_middlewares: - if version.parse(starlette.__version__) < version.parse("0.35.0"): - self.wrapped_http_proxy = middleware.cls( - self.wrapped_http_proxy, **middleware.options - ) - else: - # In starlette >= 0.35.0, middleware.options does not exist: - # https://github.com/encode/starlette/pull/2381. - self.wrapped_http_proxy = middleware.cls( - self.wrapped_http_proxy, - *middleware.args, - **middleware.kwargs, - ) - - # Start running the HTTP server on the event loop. - # This task should be running forever. 
We track it in case of failure. - self.running_task_http = get_or_create_event_loop().create_task( - self.run_http_server() - ) - - # Start running the gRPC server on the event loop. - # This task should be running forever. We track it in case of failure. - self.running_task_grpc = get_or_create_event_loop().create_task( - self.run_grpc_server() + # Start a task to initialize the HTTP server. + # The result of this task is checked in the `ready` method. + self._start_http_server_task = event_loop.create_task( + start_asgi_http_server( + self.http_proxy, + self._http_options, + event_loop=event_loop, + enable_so_reuseport=SOCKET_REUSE_PORT_ENABLED, + ) ) + # A task that runs the HTTP server until it exits (currently runs forever). + # Populated with the result of self._start_http_server_task. + self._running_http_server_task: Optional[asyncio.Task] = None + + # Start a task to initialize the gRPC server. + # The result of this task is checked in the `ready` method. + self._start_grpc_server_task: Optional[asyncio.Task] = None + if grpc_enabled: + self._start_grpc_server_task = event_loop.create_task( + start_grpc_server( + self.grpc_proxy.service_handler_factory, + self._grpc_options, + event_loop=event_loop, + enable_so_reuseport=SOCKET_REUSE_PORT_ENABLED, + ), + ) + # A task that runs the gRPC server until it exits (currently runs forever). + # Populated with the result of self._start_grpc_server_task. + self._running_grpc_server_task: Optional[asyncio.Task] = None _configure_gc_options() @@ -1274,7 +1249,7 @@ def _update_routes_in_proxies(self, endpoints: Dict[DeploymentID, EndpointInfo]) def _update_logging_config(self, logging_config: LoggingConfig): configure_component_logger( component_name="proxy", - component_id=self.node_ip_address, + component_id=self._node_ip_address, logging_config=logging_config, ) @@ -1290,148 +1265,36 @@ def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]: _, handle, _ = self.http_proxy.proxy_router.match_route(route) return handle._router._asyncio_router._replica_scheduler._replica_id_set - def should_start_grpc_service(self) -> bool: - """Determine whether gRPC service should be started. + async def ready(self) -> str: + """Blocks until the proxy HTTP (and optionally gRPC) servers are running. - gRPC service will only be started if a valid port is provided and if the - servicer functions are passed. - """ - return self.grpc_port > 0 and len(self.grpc_options.grpc_servicer_functions) > 0 + Returns JSON-serialized metadata containing the proxy's worker ID and log + file path. - async def ready(self): - """Returns when both HTTP and gRPC proxies are ready to serve traffic. - Or throw exception when either proxy is not able to serve traffic. + Raises any exceptions that occur setting up the HTTP or gRPC server. """ - http_setup_complete_wait_task = get_or_create_event_loop().create_task( - self.http_setup_complete.wait() - ) - grpc_setup_complete_wait_task = get_or_create_event_loop().create_task( - self.grpc_setup_complete.wait() - ) - - waiting_tasks_http = [ - # Either the HTTP setup has completed. - # The event is set inside self.run_http_server. - http_setup_complete_wait_task, - # Or self.run_http_server errored. - self.running_task_http, - ] - done_set_http, _ = await asyncio.wait( - waiting_tasks_http, - return_when=asyncio.FIRST_COMPLETED, - ) - waiting_tasks_grpc = [ - # Either the gRPC setup has completed. - # The event is set inside self.run_grpc_server. - grpc_setup_complete_wait_task, - # Or self.run_grpc_server errored. 
- self.running_task_grpc, - ] - done_set_grpc, _ = await asyncio.wait( - waiting_tasks_grpc, - return_when=asyncio.FIRST_COMPLETED, - ) - - # Return metadata, or re-throw the exception from self.running_task_http and - # self.running_task_grpc. - if self.http_setup_complete.is_set() and self.grpc_setup_complete.is_set(): - # NOTE(zcin): We need to convert the metadata to a json string because - # of cross-language scenarios. Java can't deserialize a Python tuple. - return json.dumps( - [ - ray.get_runtime_context().get_worker_id(), - get_component_logger_file_path(), - ] - ) - else: - proxy_error = None - if not self.http_setup_complete.is_set(): - try: - await done_set_http.pop() - except Exception as e: - logger.exception(e) - proxy_error = e - if not self.grpc_setup_complete.is_set(): - try: - await done_set_grpc.pop() - except Exception as e: - logger.exception(e) - proxy_error = e - raise proxy_error - - async def run_http_server(self): - sock = socket.socket() - if SOCKET_REUSE_PORT_ENABLED: - set_socket_reuse_port(sock) try: - sock.bind((self.host, self.port)) - except OSError: - # The OS failed to bind a socket to the given host and port. - raise ValueError( - f"Failed to bind Ray Serve HTTP proxy to '{self.host}:{self.port}'. " - "Please make sure your http-host and http-port are specified correctly." - ) - - # NOTE: We have to use lower level uvicorn Config and Server - # class because we want to run the server as a coroutine. The only - # alternative is to call uvicorn.run which is blocking. - config = uvicorn.Config( - self.wrapped_http_proxy, - host=self.host, - port=self.port, - loop=_determine_target_loop(), - root_path=self.root_path, - lifespan="off", - log_level="warning", - access_log=False, - timeout_keep_alive=self.keep_alive_timeout_s, - ) - self._uvicorn_server = uvicorn.Server(config=config) - # TODO(edoakes): we need to override install_signal_handlers here - # because the existing implementation fails if it isn't running in - # the main thread and uvicorn doesn't expose a way to configure it. - self._uvicorn_server.install_signal_handlers = lambda: None - - logger.debug( - "Starting HTTP server on node: " - f"{ray.get_runtime_context().get_node_id()} " - f"listening on port {self.port}" - ) - - self.http_setup_complete.set() - await self._uvicorn_server.serve(sockets=[sock]) - - async def run_grpc_server(self): - if not self.should_start_grpc_service(): - return self.grpc_setup_complete.set() - - grpc_server = gRPCGenericServer( - service_handler_factory=self.grpc_proxy.service_handler_factory, - ) - - add_grpc_address(grpc_server, f"[::]:{self.grpc_port}") - - # Dummy servicer is used to be callable for the gRPC server. Serve have a - # custom gRPC server implementation to redirect calls into gRPCProxy. - # See: ray/serve/_private/grpc_util.py - dummy_servicer = DummyServicer() - - # Add Ray Serve gRPC service and methods (e.g. ListApplications and Healthz). - add_RayServeAPIServiceServicer_to_server(dummy_servicer, grpc_server) - - # Iterate through each of user provided gRPC servicer functions and add user - # defined services and methods. 
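# The removed `run_http_server` above documents why Serve drives uvicorn
# through the low-level Config/Server API: `uvicorn.run()` blocks, while
# `Server.serve()` can run as a coroutine on an existing event loop with a
# pre-bound socket. A minimal standalone sketch of that pattern (the
# trivial ASGI app and the port are illustrative only; serves until killed):
import asyncio
import socket

import uvicorn

async def app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})

async def main():
    sock = socket.socket()
    sock.bind(("127.0.0.1", 8000))  # bind errors raise here, not inside uvicorn
    config = uvicorn.Config(app, lifespan="off", access_log=False)
    server = uvicorn.Server(config=config)
    # As in the removed code: skip signal-handler installation, which
    # fails when not running on the main thread.
    server.install_signal_handlers = lambda: None
    await server.serve(sockets=[sock])

asyncio.run(main())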
- for grpc_servicer_function in self.grpc_options.grpc_servicer_func_callable: - grpc_servicer_function(dummy_servicer, grpc_server) + self._running_http_server_task = await self._start_http_server_task + except Exception as e: + logger.exception("Failed to start proxy HTTP server.") + raise e from None - await grpc_server.start() - logger.debug( - "Starting gRPC server on node: " - f"{ray.get_runtime_context().get_node_id()} " - f"listening on port {self.grpc_port}" + try: + if self._start_grpc_server_task is not None: + self._running_grpc_server_task = await self._start_grpc_server_task + except Exception as e: + logger.exception("Failed to start proxy gRPC server.") + raise e from None + + # Return proxy metadata used by the controller. + # NOTE(zcin): We need to convert the metadata to a json string because + # of cross-language scenarios. Java can't deserialize a Python tuple. + return json.dumps( + [ + ray.get_runtime_context().get_worker_id(), + get_component_logger_file_path(), + ] ) - self.grpc_setup_complete.set() - await grpc_server.wait_for_termination() async def update_draining(self, draining: bool, _after: Optional[Any] = None): """Update the draining status of the HTTP and gRPC proxies. @@ -1500,30 +1363,9 @@ def _save_cpu_profile_data(self) -> str: "the RAY_SERVE_ENABLE_CPU_PROFILING env var." ) - async def _uvicorn_keep_alive(self) -> Optional[int]: - """Get the keep alive timeout used for the running uvicorn server. - - Return the timeout_keep_alive config used on the uvicorn server if it's running. - If the server is not running, return None. - """ - if self._uvicorn_server: - return self._uvicorn_server.config.timeout_keep_alive - - -def _determine_target_loop(): - """We determine target loop based on whether RAY_SERVE_DEBUG_MODE is enabled: - - - RAY_SERVE_DEBUG_MODE=0 (default): we use "uvloop" (Cython) providing - high-performance, native implementation of the event-loop - - - RAY_SERVE_DEBUG_MODE=1: we fall back to "asyncio" (pure Python) event-loop - implementation that is considerably slower than "uvloop", - but provides for easy access to the source implementation - """ - if RAY_SERVE_DEBUG_MODE: - return "asyncio" - else: - return "uvloop" + def _get_http_options(self) -> HTTPOptions: + """Internal method to get HTTP options used by the proxy.""" + return self._http_options def _configure_gc_options(): diff --git a/python/ray/serve/_private/proxy_state.py b/python/ray/serve/_private/proxy_state.py index 3325bbe0ee10c..ea81423808294 100644 --- a/python/ray/serve/_private/proxy_state.py +++ b/python/ray/serve/_private/proxy_state.py @@ -3,6 +3,7 @@ import logging import os from abc import ABC, abstractmethod +from copy import deepcopy from typing import Dict, List, Optional, Set, Tuple, Type import ray @@ -98,7 +99,7 @@ def __init__( self, logging_config: LoggingConfig, actor_handle: Optional[ActorHandle] = None, - config: Optional[HTTPOptions] = None, + http_options: Optional[HTTPOptions] = None, grpc_options: Optional[gRPCOptions] = None, name: Optional[str] = None, node_id: Optional[str] = None, @@ -108,7 +109,7 @@ def __init__( ): # initialize with provided proxy actor handle or get or create a new one. 
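# Startup is now two-phase: `__init__` schedules `start_asgi_http_server`
# / `start_grpc_server` as tasks, and `ready()` awaits them so any bind or
# startup error surfaces to the caller. A minimal sketch of the pattern,
# with a hypothetical `start_server` standing in for the real helpers
# (which return the long-running server task):
import asyncio
from typing import Optional

async def start_server() -> asyncio.Task:
    async def serve_forever():
        await asyncio.Event().wait()  # placeholder for the server loop
    # Bind/startup errors would raise here, before the task is returned.
    return asyncio.create_task(serve_forever())

class MiniProxy:
    def __init__(self) -> None:
        # Assumes construction inside the event loop, as in the actor.
        loop = asyncio.get_running_loop()
        # Schedule startup immediately; do not await in the constructor.
        self._start_task = loop.create_task(start_server())
        self._running_task: Optional[asyncio.Task] = None

    async def ready(self) -> str:
        # Awaiting the start task re-raises any startup exception
        # (e.g. a failed socket bind) to the caller of `ready()`.
        self._running_task = await self._start_task
        return "ready"

async def main():
    print(await MiniProxy().ready())

asyncio.run(main())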
self._actor_handle = actor_handle or self._get_or_create_proxy_actor( - config=config, + http_options=http_options, grpc_options=grpc_options, name=name, node_id=node_id, @@ -130,7 +131,7 @@ def __init__( @staticmethod def _get_or_create_proxy_actor( - config: HTTPOptions, + http_options: HTTPOptions, grpc_options: gRPCOptions, name: str, node_id: str, @@ -148,14 +149,14 @@ def _get_or_create_proxy_actor( try: proxy = ray.get_actor(name, namespace=SERVE_NAMESPACE) except ValueError: + addr = f"{http_options.host}:{http_options.port}" logger.info( - f"Starting proxy on node '{node_id}' " - f"listening on '{config.host}:{port}'.", + f"Starting proxy on node '{node_id}' listening on '{addr}'.", extra={"log_to_stderr": False}, ) - proxy = proxy or proxy_actor_class.options( - num_cpus=config.num_cpus, + return proxy or proxy_actor_class.options( + num_cpus=http_options.num_cpus, name=name, namespace=SERVE_NAMESPACE, lifetime="detached", @@ -164,18 +165,12 @@ def _get_or_create_proxy_actor( scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False), enable_task_events=RAY_SERVE_ENABLE_TASK_EVENTS, ).remote( - config.host, - port, - config.root_path, - node_ip_address=node_ip_address, - node_id=node_id, - http_middlewares=config.middlewares, - request_timeout_s=config.request_timeout_s, - keep_alive_timeout_s=config.keep_alive_timeout_s, + http_options, grpc_options=grpc_options, + node_id=node_id, + node_ip_address=node_ip_address, logging_config=logging_config, ) - return proxy @property def actor_id(self) -> str: @@ -543,7 +538,7 @@ class ProxyStateManager: def __init__( self, - config: HTTPOptions, + http_options: HTTPOptions, head_node_id: str, cluster_node_info_cache: ClusterNodeInfoCache, logging_config: LoggingConfig, @@ -553,10 +548,7 @@ def __init__( timer: TimerBase = Timer(), ): self.logging_config = logging_config - if config is not None: - self._config = config - else: - self._config = HTTPOptions() + self._http_options = http_options or HTTPOptions() self._grpc_options = grpc_options or gRPCOptions() self._proxy_states: Dict[NodeId, ProxyState] = dict() self._proxy_restart_counts: Dict[NodeId, int] = dict() @@ -588,7 +580,7 @@ def is_ready_for_shutdown(self) -> bool: ) def get_config(self) -> HTTPOptions: - return self._config + return self._http_options def get_grpc_config(self) -> gRPCOptions: return self._grpc_options @@ -639,7 +631,7 @@ def update(self, proxy_nodes: Set[NodeId] = None) -> Set[str]: def _get_target_nodes(self, proxy_nodes) -> List[Tuple[str, str]]: """Return the list of (node_id, ip_address) to deploy HTTP and gRPC servers on.""" - location = self._config.location + location = self._http_options.location if location == DeploymentMode.NoServer: return [] @@ -679,7 +671,7 @@ def _start_proxy( port based on `TEST_WORKER_NODE_GRPC_PORT` env var. Passed all the required variables into the proxy actor wrapper class and return the proxy actor wrapper. """ - port = self._config.port + http_options = self._http_options grpc_options = self._grpc_options if ( @@ -690,7 +682,8 @@ def _start_proxy( f"`TEST_WORKER_NODE_HTTP_PORT` env var is set. " f"Using it for worker node {node_id}." ) - port = int(os.getenv("TEST_WORKER_NODE_HTTP_PORT")) + http_options = deepcopy(http_options) + http_options.port = int(os.getenv("TEST_WORKER_NODE_HTTP_PORT")) if ( node_id != self._head_node_id @@ -701,16 +694,16 @@ def _start_proxy( f"Using it for worker node {node_id}." 
f"{int(os.getenv('TEST_WORKER_NODE_GRPC_PORT'))}" ) + grpc_options = deepcopy(grpc_options) grpc_options.port = int(os.getenv("TEST_WORKER_NODE_GRPC_PORT")) return self._actor_proxy_wrapper_class( logging_config=self.logging_config, - config=self._config, + http_options=http_options, grpc_options=grpc_options, name=name, node_id=node_id, node_ip_address=node_ip_address, - port=port, proxy_actor_class=self._proxy_actor_class, ) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 089d92fa03f2d..a158921148867 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -50,10 +50,10 @@ GRPC_CONTEXT_ARG_NAME, HEALTH_CHECK_METHOD, RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE, + RAY_SERVE_METRICS_EXPORT_INTERVAL_MS, RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S, RAY_SERVE_RUN_SYNC_IN_THREADPOOL, RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING, - RAY_SERVE_METRICS_EXPORT_INTERVAL_MS, RECONFIGURE_METHOD, SERVE_CONTROLLER_NAME, SERVE_LOGGER_NAME, @@ -344,6 +344,9 @@ def __init__( self._user_callable_initialized_lock = asyncio.Lock() self._initialization_latency: Optional[float] = None + # Flipped to `True` when health checks pass and `False` when they fail. May be + # used by replica subclass implementations. + self._healthy = False # Flipped to `True` once graceful shutdown is initiated. May be used by replica # subclass implementations. self._shutting_down = False @@ -801,13 +804,19 @@ async def perform_graceful_shutdown(self): await self._metrics_manager.shutdown() async def check_health(self): - # If there's no user-defined health check, nothing runs on the user code event - # loop and no future is returned. - f: Optional[ - concurrent.futures.Future - ] = self._user_callable_wrapper.call_user_health_check() - if f is not None: - await asyncio.wrap_future(f) + try: + # If there's no user-defined health check, nothing runs on the user code event + # loop and no future is returned. + f: Optional[ + concurrent.futures.Future + ] = self._user_callable_wrapper.call_user_health_check() + if f is not None: + await asyncio.wrap_future(f) + self._healthy = True + except Exception as e: + logger.warning("Replica health check failed.") + self._healthy = False + raise e from None class Replica(ReplicaBase): @@ -1601,7 +1610,7 @@ async def call_user_method( if request_metadata.is_streaming else None, ) - return await self._handle_user_method_result( + final_result = await self._handle_user_method_result( result, request_metadata, user_method_info, @@ -1610,6 +1619,10 @@ async def call_user_method( asgi_args=asgi_args, ) + if receive_task is not None and not receive_task.done(): + receive_task.cancel() + + return final_result except Exception: if ( request_metadata.is_http_request @@ -1625,11 +1638,23 @@ async def call_user_method( asgi_args, ) - raise - finally: if receive_task is not None and not receive_task.done(): receive_task.cancel() + raise + except asyncio.CancelledError: + user_method_info = self._get_user_method_info(request_metadata.call_method) + if receive_task is not None and not receive_task.done(): + # Do NOT cancel the receive task if the request has been + # cancelled, but the call is a batched call. This is + # because we cannot guarantee cancelling the batched + # call, so in the case that the call continues executing + # we should continue fetching data from the client. 
+ if not hasattr(user_method_info.callable, "set_max_batch_size"): + receive_task.cancel() + + raise + @_run_on_user_code_event_loop async def call_destructor(self): """Explicitly call the `__del__` method of the user callable. diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index ca68fbc1d7ab2..ac1076b7bd437 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -459,10 +459,13 @@ def _run( validate_route_prefix(route_prefix) if _local_testing_mode: + if not isinstance(logging_config, LoggingConfig): + logging_config = LoggingConfig(**(logging_config or {})) + configure_component_logger( component_name="local_test", component_id="-", - logging_config=logging_config or LoggingConfig(), + logging_config=logging_config, stream_handler_only=True, ) built_app = build_app( diff --git a/python/ray/serve/batching.py b/python/ray/serve/batching.py index 3258713fefd47..456cbcf36835c 100644 --- a/python/ray/serve/batching.py +++ b/python/ray/serve/batching.py @@ -290,7 +290,13 @@ async def _process_batch(self, func: Callable) -> None: """Processes queued request batch.""" batch: List[_SingleRequest] = await self.wait_for_batch() - assert len(batch) > 0 + # Remove requests that have been cancelled from the batch. If + # all requests have been cancelled, simply return and wait for + # the next batch. + batch = [req for req in batch if not req.future.cancelled()] + if len(batch) == 0: + return + futures = [item.future for item in batch] # Most of the logic in the function should be wrapped in this try- diff --git a/python/ray/serve/grpc_util.py b/python/ray/serve/grpc_util.py index b0c132c40f18f..17da867f42178 100644 --- a/python/ray/serve/grpc_util.py +++ b/python/ray/serve/grpc_util.py @@ -18,13 +18,13 @@ def __init__(self, grpc_context: grpc._cython.cygrpc._ServicerContext): self._auth_context = grpc_context.auth_context() self._code = grpc_context.code() self._details = grpc_context.details() - self._invocation_metadata = [ + self._invocation_metadata = [ # noqa: C416 (key, value) for key, value in grpc_context.invocation_metadata() ] self._peer = grpc_context.peer() self._peer_identities = grpc_context.peer_identities() self._peer_identity_key = grpc_context.peer_identity_key() - self._trailing_metadata = [ + self._trailing_metadata = [ # noqa: C416 (key, value) for key, value in grpc_context.trailing_metadata() ] self._compression = None diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py index ed696219b0c33..40a74644e64ca 100644 --- a/python/ray/serve/schema.py +++ b/python/ray/serve/schema.py @@ -952,6 +952,9 @@ class DeploymentDetails(BaseModel, extra=Extra.forbid, frozen=True): "number for other deployments." ) ) + required_resources: Dict = Field( + description="The resources required per replica of this deployment." + ) replicas: List[ReplicaDetails] = Field( description="Details about the live replicas of this deployment." 
) diff --git a/python/ray/serve/tests/BUILD b/python/ray/serve/tests/BUILD index 9f3208084538b..3189d0ef9daea 100644 --- a/python/ray/serve/tests/BUILD +++ b/python/ray/serve/tests/BUILD @@ -66,7 +66,6 @@ py_test_module_list( "test_max_replicas_per_node.py", "test_multiplex.py", "test_proxy_response_generator.py", - "test_proxy_state.py", "test_ray_client.py", "test_replica_placement_group.py", "test_request_timeout.py", diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 05862f033b091..2f0c13873c24e 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -844,7 +844,7 @@ def test_status_constructor_error(serve_instance): @serve.deployment class A: def __init__(self): - 1 / 0 + _ = 1 / 0 serve._run(A.bind(), _blocking=False) diff --git a/python/ray/serve/tests/test_callback.py b/python/ray/serve/tests/test_callback.py index 4f6a98c8b358e..926537ef0cd0a 100644 --- a/python/ray/serve/tests/test_callback.py +++ b/python/ray/serve/tests/test_callback.py @@ -12,6 +12,7 @@ from ray import serve from ray._private.test_utils import wait_for_condition from ray.exceptions import RayActorError +from ray.serve.config import HTTPOptions from ray.serve._private.utils import call_function_from_import_path from ray.serve.context import _get_global_client from ray.serve.schema import LoggingConfig, ProxyStatus, ServeInstanceDetails @@ -159,9 +160,7 @@ def test_callback_fail(ray_instance): actor_def = ray.serve._private.proxy.ProxyActor handle = actor_def.remote( - host="http_proxy", - port=123, - root_path="/", + http_options=HTTPOptions(host="http_proxy", root_path="/", port=123), node_ip_address="127.0.0.1", node_id="123", logging_config=LoggingConfig(), @@ -172,7 +171,7 @@ def test_callback_fail(ray_instance): actor_def = ray.serve._private.controller.ServeController handle = actor_def.remote( - http_config={}, + http_options=HTTPOptions(), global_logging_config=LoggingConfig(), ) with pytest.raises(RayActorError, match="cannot be imported"): @@ -193,9 +192,7 @@ def test_http_proxy_return_aribitary_objects(ray_instance): actor_def = ray.serve._private.proxy.ProxyActor handle = actor_def.remote( - host="http_proxy", - port=123, - root_path="/", + http_options=HTTPOptions(host="http_proxy", root_path="/", port=123), node_ip_address="127.0.0.1", node_id="123", logging_config=LoggingConfig(), diff --git a/python/ray/serve/tests/test_config_files/basic_multi_http.yaml b/python/ray/serve/tests/test_config_files/basic_multi_http.yaml index 3476f019a3e51..0d30054bd4375 100644 --- a/python/ray/serve/tests/test_config_files/basic_multi_http.yaml +++ b/python/ray/serve/tests/test_config_files/basic_multi_http.yaml @@ -3,4 +3,4 @@ http_options: applications: - name: "app1" - import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode \ No newline at end of file + import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode diff --git a/python/ray/serve/tests/test_config_files/duplicate_app_names.yaml b/python/ray/serve/tests/test_config_files/duplicate_app_names.yaml index e824205a5a742..ed31cefb18ab3 100644 --- a/python/ray/serve/tests/test_config_files/duplicate_app_names.yaml +++ b/python/ray/serve/tests/test_config_files/duplicate_app_names.yaml @@ -5,4 +5,4 @@ applications: - name: app1 route_prefix: /b - import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode \ No newline at end of file + import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode diff --git 
a/python/ray/serve/tests/test_config_files/duplicate_app_routes.yaml b/python/ray/serve/tests/test_config_files/duplicate_app_routes.yaml index f22c124b268bf..f7b20e90cb1eb 100644 --- a/python/ray/serve/tests/test_config_files/duplicate_app_routes.yaml +++ b/python/ray/serve/tests/test_config_files/duplicate_app_routes.yaml @@ -5,4 +5,4 @@ applications: - name: app2 route_prefix: /alice - import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode \ No newline at end of file + import_path: ray.serve.tests.test_config_files.test_dag.basic_dag.DagNode diff --git a/python/ray/serve/tests/test_config_files/fail.py b/python/ray/serve/tests/test_config_files/fail.py index 4b69ed6aed897..f55cb34286104 100644 --- a/python/ray/serve/tests/test_config_files/fail.py +++ b/python/ray/serve/tests/test_config_files/fail.py @@ -4,7 +4,7 @@ @serve.deployment class A: def __init__(self): - 1 / 0 + _ = 1 / 0 node = A.bind() diff --git a/python/ray/serve/tests/test_config_files/fail_2.py b/python/ray/serve/tests/test_config_files/fail_2.py index 2e95aa93d98f8..897a9ad0f3a5a 100644 --- a/python/ray/serve/tests/test_config_files/fail_2.py +++ b/python/ray/serve/tests/test_config_files/fail_2.py @@ -7,7 +7,7 @@ class A: def __init__(self): time.sleep(5) - 1 / 0 + _ = 1 / 0 node = A.bind() diff --git a/python/ray/serve/tests/test_config_files/import_error.py b/python/ray/serve/tests/test_config_files/import_error.py index 40f9b835493f8..15ea50c141124 100644 --- a/python/ray/serve/tests/test_config_files/import_error.py +++ b/python/ray/serve/tests/test_config_files/import_error.py @@ -1,6 +1,6 @@ from ray import serve -1 / 0 +_ = 1 / 0 @serve.deployment(ray_actor_options={"num_cpus": 0.1}) diff --git a/python/ray/serve/tests/test_controller.py b/python/ray/serve/tests/test_controller.py index 4c83b50dba44c..285ed0ca950b8 100644 --- a/python/ray/serve/tests/test_controller.py +++ b/python/ray/serve/tests/test_controller.py @@ -183,6 +183,7 @@ def autoscaling_app(): }, }, "target_num_replicas": 1, + "required_resources": {"CPU": 1}, "replicas": [ { "node_id": node_id, diff --git a/python/ray/serve/tests/test_fastapi.py b/python/ray/serve/tests/test_fastapi.py index 6c4e0b7ab4c8c..e9edb7b944f01 100644 --- a/python/ray/serve/tests/test_fastapi.py +++ b/python/ray/serve/tests/test_fastapi.py @@ -354,7 +354,7 @@ def test_fastapi_init_lifespan_should_not_shutdown(serve_instance): @app.on_event("shutdown") async def shutdown(): - 1 / 0 + _ = 1 / 0 @serve.deployment @serve.ingress(app) diff --git a/python/ray/serve/tests/test_http_routes.py b/python/ray/serve/tests/test_http_routes.py index 68e8b3acdf0e9..9840f52cc2024 100644 --- a/python/ray/serve/tests/test_http_routes.py +++ b/python/ray/serve/tests/test_http_routes.py @@ -215,7 +215,7 @@ def redirect_twice(self, request: Request): def test_default_error_handling(serve_instance): @serve.deployment def f(): - 1 / 0 + _ = 1 / 0 serve.run(f.bind()) r = requests.get("http://localhost:8000/f") diff --git a/python/ray/serve/tests/test_logging.py b/python/ray/serve/tests/test_logging.py index 4e16d074c0124..125523ddcfa12 100644 --- a/python/ray/serve/tests/test_logging.py +++ b/python/ray/serve/tests/test_logging.py @@ -937,7 +937,7 @@ def test_stream_to_logger(): # Calling non-existing attribute on the StreamToLogger should still raise error. 
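# The recurring `expr` -> `_ = expr` change throughout these tests keeps an
# expression that exists only for its side effect (raising, or touching a
# lazy attribute) while silencing "useless expression" lint checks:
def raises_divide_by_zero():
    _ = 1 / 0  # still raises ZeroDivisionError; `_` marks the value as unused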
with pytest.raises(AttributeError): - stream_to_logger.i_dont_exist + _ = stream_to_logger.i_dont_exist @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index 6f64666a96ba7..93ac4ff828d33 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -487,7 +487,7 @@ def test_proxy_metrics_fields_not_found(serve_start_shutdown): # Should generate 404 responses broken_url = "http://127.0.0.1:8000/fake_route" - requests.get(broken_url).text + _ = requests.get(broken_url).text print("Sent requests to broken URL.") # Ping gRPC proxy for not existing application. @@ -540,7 +540,7 @@ def f(*args): # Deployment should generate divide-by-zero errors correct_url = "http://127.0.0.1:8000/real_route" - requests.get(correct_url).text + _ = requests.get(correct_url).text print("Sent requests to correct URL.") # Ping gPRC proxy for broken app diff --git a/python/ray/serve/tests/test_proxy.py b/python/ray/serve/tests/test_proxy.py index da86d7e2cb0f5..db678382b0889 100644 --- a/python/ray/serve/tests/test_proxy.py +++ b/python/ray/serve/tests/test_proxy.py @@ -29,7 +29,7 @@ def test_default_keep_alive_timeout_s(self, ray_shutdown): serve.start() proxy_actor = self.get_proxy_actor() assert ( - ray.get(proxy_actor._uvicorn_keep_alive.remote()) + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s == DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S ) @@ -42,7 +42,10 @@ def test_set_keep_alive_timeout_in_http_configs(self, ray_shutdown): keep_alive_timeout_s = 222 serve.start(http_options={"keep_alive_timeout_s": keep_alive_timeout_s}) proxy_actor = self.get_proxy_actor() - assert ray.get(proxy_actor._uvicorn_keep_alive.remote()) == keep_alive_timeout_s + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s + == keep_alive_timeout_s + ) @pytest.mark.parametrize( "ray_instance", @@ -59,7 +62,9 @@ def test_set_keep_alive_timeout_in_env(self, ray_instance, ray_shutdown): """ serve.start() proxy_actor = self.get_proxy_actor() - assert ray.get(proxy_actor._uvicorn_keep_alive.remote()) == 333 + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s == 333 + ) @pytest.mark.parametrize( "ray_instance", @@ -79,7 +84,9 @@ def test_set_timeout_keep_alive_in_both_config_and_env( keep_alive_timeout_s = 222 serve.start(http_options={"keep_alive_timeout_s": keep_alive_timeout_s}) proxy_actor = self.get_proxy_actor() - assert ray.get(proxy_actor._uvicorn_keep_alive.remote()) == 333 + assert ( + ray.get(proxy_actor._get_http_options.remote()).keep_alive_timeout_s == 333 + ) if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 62b698b35401d..cae25eb9854bd 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -442,7 +442,7 @@ def hello(): @pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") def test_http_proxy_fail_loudly(ray_shutdown): # Test that if the http server fail to start, serve.start should fail. 
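# Sketch of the cancellation handling added to @serve.batch's queue loop in
# the batching.py hunk above: requests whose futures were cancelled while
# queued are dropped, and an all-cancelled batch is skipped entirely. `Req`
# is an illustrative stand-in for the internal `_SingleRequest`.
import asyncio
from dataclasses import dataclass
from typing import Any

@dataclass
class Req:
    payload: Any
    future: asyncio.Future

async def process_batch(batch, func):
    # Drop requests cancelled while waiting in the queue.
    batch = [req for req in batch if not req.future.cancelled()]
    if len(batch) == 0:
        return  # every caller gave up; wait for the next batch
    results = await func([req.payload for req in batch])
    for req, result in zip(batch, results):
        req.future.set_result(result)

async def main():
    loop = asyncio.get_running_loop()
    reqs = [Req(i, loop.create_future()) for i in range(3)]
    reqs[1].future.cancel()  # simulate a caller giving up mid-queue
    await process_batch(
        reqs, lambda items: asyncio.sleep(0, result=[x * 2 for x in items])
    )
    print([r.future.result() for r in reqs if not r.future.cancelled()])  # [0, 4]

asyncio.run(main())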
- with pytest.raises(ValueError): + with pytest.raises(RuntimeError): serve.start(http_options={"host": "bad.ip.address"}) diff --git a/python/ray/serve/tests/unit/test_batching.py b/python/ray/serve/tests/unit/test_batching.py index 88ca432cea5cd..05ab17df8670a 100644 --- a/python/ray/serve/tests/unit/test_batching.py +++ b/python/ray/serve/tests/unit/test_batching.py @@ -126,14 +126,14 @@ async def test_batch_size_one_long_timeout(use_class): @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000) async def long_timeout(requests): if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests class LongTimeout: @serve.batch(max_batch_size=1, batch_wait_timeout_s=1000) async def long_timeout(self, requests): if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests cls = LongTimeout() @@ -158,7 +158,7 @@ async def test_batch_size_multiple_zero_timeout(use_class): async def zero_timeout(requests): await block_execution_event.wait() if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests class ZeroTimeout: @@ -166,7 +166,7 @@ class ZeroTimeout: async def zero_timeout(self, requests): await block_execution_event.wait() if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests cls = ZeroTimeout() @@ -262,14 +262,14 @@ async def test_batch_size_multiple_long_timeout(use_class): @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000) async def long_timeout(requests): if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests class LongTimeout: @serve.batch(max_batch_size=3, batch_wait_timeout_s=1000) async def long_timeout(self, requests): if "raise" in requests: - 1 / 0 + _ = 1 / 0 return requests cls = LongTimeout() diff --git a/python/ray/serve/tests/unit/test_config.py b/python/ray/serve/tests/unit/test_config.py index 0d031bbd7485f..1c404f821581f 100644 --- a/python/ray/serve/tests/unit/test_config.py +++ b/python/ray/serve/tests/unit/test_config.py @@ -609,14 +609,14 @@ def test_grpc_options(): grpc_servicer_functions = ["fake.service.that.does.not.exist"] with pytest.raises(ModuleNotFoundError) as exception: grpc_options = gRPCOptions(grpc_servicer_functions=grpc_servicer_functions) - grpc_options.grpc_servicer_func_callable + _ = grpc_options.grpc_servicer_func_callable assert "can't be imported!" in str(exception) # Not callable should raise ValueError. grpc_servicer_functions = ["ray.serve._private.constants.DEFAULT_HTTP_PORT"] with pytest.raises(ValueError) as exception: grpc_options = gRPCOptions(grpc_servicer_functions=grpc_servicer_functions) - grpc_options.grpc_servicer_func_callable + _ = grpc_options.grpc_servicer_func_callable assert "is not a callable function!" 
in str(exception) diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index e3b4e2db50401..e794861e711fa 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -895,7 +895,7 @@ def test_placement_groups(self): ), ) - ray.util.placement_group + _ = ray.util.placement_group scheduler.on_deployment_created(d_id1, SpreadDeploymentSchedulingPolicy()) scheduler.on_deployment_created(d_id2, SpreadDeploymentSchedulingPolicy()) diff --git a/python/ray/serve/tests/unit/test_grpc_util.py b/python/ray/serve/tests/unit/test_grpc_util.py index 2ec0b5a52f1c5..65547f9a01966 100644 --- a/python/ray/serve/tests/unit/test_grpc_util.py +++ b/python/ray/serve/tests/unit/test_grpc_util.py @@ -1,5 +1,6 @@ import pickle from typing import Callable +from unittest.mock import Mock import grpc import pytest @@ -8,7 +9,6 @@ from ray import cloudpickle from ray.serve._private.default_impl import add_grpc_address from ray.serve._private.grpc_util import ( - DummyServicer, gRPCGenericServer, ) from ray.serve._private.test_utils import FakeGrpcContext @@ -30,20 +30,6 @@ def foo() -> bytes: return foo -def test_dummy_servicer_can_take_any_methods(): - """Test an instance of DummyServicer can be called with any method name without - error. - - When dummy_servicer is called with any custom defined methods, it won't raise error. - """ - dummy_servicer = DummyServicer() - dummy_servicer.foo - dummy_servicer.bar - dummy_servicer.baz - dummy_servicer.my_method - dummy_servicer.Predict - - def test_grpc_server(): """Test `gRPCGenericServer` did the correct overrides. @@ -68,7 +54,7 @@ def add_test_servicer_to_server(servicer, server): server.add_generic_rpc_handlers((generic_handler,)) grpc_server = gRPCGenericServer(fake_service_handler_factory) - dummy_servicer = DummyServicer() + dummy_servicer = Mock() # Ensure `generic_rpc_handlers` is not populated before calling # the add_servicer_to_server function. diff --git a/python/ray/serve/tests/unit/test_local_testing_mode.py b/python/ray/serve/tests/unit/test_local_testing_mode.py index d8d47d322f68a..79a6c98cfc9ad 100644 --- a/python/ray/serve/tests/unit/test_local_testing_mode.py +++ b/python/ray/serve/tests/unit/test_local_testing_mode.py @@ -1,7 +1,8 @@ import sys import pytest - +import logging +from ray.serve._private.constants import SERVE_LOGGER_NAME from ray import serve from ray.serve.handle import DeploymentHandle @@ -87,5 +88,27 @@ async def __call__(self): h.remote().result() +def test_dictionary_logging_config_with_local_mode(): + """Test that the logging config can be passed as a dictionary. + + See: https://github.com/ray-project/ray/issues/50052 + """ + + @serve.deployment + class MyApp: + def __call__(self): + logger = logging.getLogger(SERVE_LOGGER_NAME) + return logger.level + + app = MyApp.bind() + logging_config = {"log_level": "WARNING"} + + # This should not raise exception. + h = serve.run(app, logging_config=logging_config, _local_testing_mode=True) + + # The logger should be setup with WARNING level. 
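# The fix exercised by this test (issue #50052): `serve._run` now coerces a
# plain-dict logging config into a `LoggingConfig` before configuring the
# local-testing logger. An equivalent standalone sketch of that coercion:
from ray.serve.schema import LoggingConfig

def coerce(logging_config):
    if not isinstance(logging_config, LoggingConfig):
        logging_config = LoggingConfig(**(logging_config or {}))
    return logging_config

assert coerce(None) == LoggingConfig()
assert coerce({"log_level": "WARNING"}).log_level == "WARNING"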
+ assert h.remote().result() == logging.WARNING + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_proxy_state.py b/python/ray/serve/tests/unit/test_proxy_state.py similarity index 98% rename from python/ray/serve/tests/test_proxy_state.py rename to python/ray/serve/tests/unit/test_proxy_state.py index 70dccde7cf952..949940efff2eb 100644 --- a/python/ray/serve/tests/test_proxy_state.py +++ b/python/ray/serve/tests/unit/test_proxy_state.py @@ -90,7 +90,7 @@ def _create_proxy_state_manager( ) -> (ProxyStateManager, ClusterNodeInfoCache): return ( ProxyStateManager( - config=http_options, + http_options=http_options, head_node_id=head_node_id, cluster_node_info_cache=cluster_node_info_cache, logging_config=LoggingConfig(), @@ -621,9 +621,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod proxy_state._actor_proxy_wrapper.is_ready_response = False # Capture current proxy states (prior to updating) - prev_proxy_states = { - node_id: state for node_id, state in proxy_state_manager._proxy_states.items() - } + prev_proxy_states = dict(proxy_state_manager._proxy_states) # Trigger PSM to reconcile proxy_state_manager.update(proxy_nodes=node_ids) @@ -644,9 +642,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod proxy_state._actor_proxy_wrapper.is_ready_response = True # Capture current proxy states again (prior to updating) - prev_proxy_states = { - node_id: state for node_id, state in proxy_state_manager._proxy_states.items() - } + prev_proxy_states = dict(proxy_state_manager._proxy_states) # Trigger PSM to reconcile proxy_state_manager.update(proxy_nodes=node_ids) diff --git a/python/ray/serve/tests/unit/test_schema.py b/python/ray/serve/tests/unit/test_schema.py index c7bc18788c7d7..935667d2c621b 100644 --- a/python/ray/serve/tests/unit/test_schema.py +++ b/python/ray/serve/tests/unit/test_schema.py @@ -808,6 +808,7 @@ def test_serve_instance_details_is_json_serializable(): }, }, "target_num_replicas": 0, + "required_resources": {"CPU": 1}, "replicas": [], } }, @@ -841,6 +842,7 @@ def test_serve_instance_details_is_json_serializable(): "autoscaling_config": {}, }, "target_num_replicas": 0, + "required_resources": {"CPU": 1}, "replicas": [], } }, diff --git a/python/ray/tests/accelerators/test_accelerators.py b/python/ray/tests/accelerators/test_accelerators.py index fde8483b3e78a..36ec28c2ede50 100644 --- a/python/ray/tests/accelerators/test_accelerators.py +++ b/python/ray/tests/accelerators/test_accelerators.py @@ -13,7 +13,7 @@ def test_accelerators(): AttributeError, match="module 'ray.util.accelerators' has no attribute 'NVIDIA_INVALID'", ): - accelerators.NVIDIA_INVALID + _ = accelerators.NVIDIA_INVALID with pytest.warns(RayDeprecationWarning): assert accelerators.NVIDIA_TESLA_A100 == "A100" diff --git a/python/ray/tests/accelerators/test_amd_gpu.py b/python/ray/tests/accelerators/test_amd_gpu.py index 60a5204939f57..7e96c160174fb 100644 --- a/python/ray/tests/accelerators/test_amd_gpu.py +++ b/python/ray/tests/accelerators/test_amd_gpu.py @@ -18,7 +18,7 @@ def test_visible_amd_gpu_ids(mock_get_num_accelerators, monkeypatch, shutdown_on # we call get_accelerator_manager_for_resource del get_accelerator_manager_for_resource._resource_name_to_accelerator_manager ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert ray.available_resources()["GPU"] == 3 @@ -28,7 +28,7 @@ def 
test_visible_amd_gpu_ids(mock_get_num_accelerators, monkeypatch, shutdown_on ) def test_visible_amd_gpu_type(mock_get_amd_device_ids, shutdown_only): ray.init() - mock_get_amd_device_ids.called + _ = mock_get_amd_device_ids.called assert ( AMDGPUAcceleratorManager.get_current_node_accelerator_type() == "AMD-Instinct-MI300X-OAM" @@ -41,7 +41,7 @@ def test_visible_amd_gpu_type(mock_get_amd_device_ids, shutdown_only): ) def test_visible_amd_gpu_type_bad_device_id(mock_get_num_accelerators, shutdown_only): ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert AMDGPUAcceleratorManager.get_current_node_accelerator_type() is None diff --git a/python/ray/tests/accelerators/test_hpu.py b/python/ray/tests/accelerators/test_hpu.py index e0e08fd5c9827..68a61801c6483 100644 --- a/python/ray/tests/accelerators/test_hpu.py +++ b/python/ray/tests/accelerators/test_hpu.py @@ -26,7 +26,7 @@ def test_auto_detected_more_than_visible( # Test more hpus are detected than visible. monkeypatch.setenv("HABANA_VISIBLE_MODULES", "0,1,2") ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert ray.available_resources()["HPU"] == 3 @@ -37,7 +37,7 @@ def test_auto_detected_more_than_visible( def test_auto_detect_resources(mock_get_num_accelerators, shutdown_only): # Test that ray node resources are filled with auto detected count. ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert ray.available_resources()["HPU"] == 2 diff --git a/python/ray/tests/accelerators/test_neuron.py b/python/ray/tests/accelerators/test_neuron.py index 88fadd4079b32..d81dc14a4424e 100644 --- a/python/ray/tests/accelerators/test_neuron.py +++ b/python/ray/tests/accelerators/test_neuron.py @@ -25,7 +25,7 @@ def test_auto_detected_more_than_visible( # Test more neuron_cores are detected than visible. monkeypatch.setenv("NEURON_RT_VISIBLE_CORES", "0,1,2") ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert ray.available_resources()["neuron_cores"] == 3 @@ -36,7 +36,7 @@ def test_auto_detected_more_than_visible( def test_auto_detect_resources(mock_get_num_accelerators, shutdown_only): # Test that ray node resources are filled with auto detected count. ray.init() - mock_get_num_accelerators.called + _ = mock_get_num_accelerators.called assert ray.available_resources()["neuron_cores"] == 2 diff --git a/python/ray/tests/additional_property.yaml b/python/ray/tests/additional_property.yaml index 640d534ff3fed..2e3dd4f5628d7 100644 --- a/python/ray/tests/additional_property.yaml +++ b/python/ray/tests/additional_property.yaml @@ -26,4 +26,4 @@ setup_commands: # # Command to start ray on the head node. You don't need to change this. 
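# `required_resources` (added to `DeploymentDetails` above and asserted in
# test_controller.py / test_schema.py) is surfaced per deployment in the
# Serve REST API payload. A hypothetical read path, assuming the default
# dashboard agent address and the nesting shown in the schema tests:
import requests

details = requests.get("http://localhost:52365/api/serve/applications/").json()
for app in details["applications"].values():
    for name, deployment in app["deployments"].items():
        print(name, deployment["required_resources"])  # e.g. {"CPU": 1}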
head_start_ray_commands: - ray stop - - ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml \ No newline at end of file + - ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml diff --git a/python/ray/tests/horovod/BUILD b/python/ray/tests/horovod/BUILD index 6a86aac7f25c2..7bf6b6bf053bb 100644 --- a/python/ray/tests/horovod/BUILD +++ b/python/ray/tests/horovod/BUILD @@ -5,5 +5,3 @@ py_test( deps = ["//:ray_lib"], tags = ["team:ml", "compat", "exclusive", "manual"] ) - - diff --git a/python/ray/tests/test_asyncio.py b/python/ray/tests/test_asyncio.py index bd0cd6bcaaf8c..290804e9a2710 100644 --- a/python/ray/tests/test_asyncio.py +++ b/python/ray/tests/test_asyncio.py @@ -130,7 +130,7 @@ def task(): @ray.remote def task_throws(): - 1 / 0 + _ = 1 / 0 with pytest.raises(ray.exceptions.RayTaskError): await task_throws.remote().as_future() @@ -148,7 +148,7 @@ def big_object(self): return "a" * (str_len) def throw_error(self): - 1 / 0 + _ = 1 / 0 actor = Actor.remote() diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 5437ab4fb1fa0..b146c7043008a 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -888,7 +888,7 @@ def __init__(self): self.val = ray.put(0) def method(self): - f + _ = f f = Foo() ray.put(f) diff --git a/python/ray/tests/test_cli.py b/python/ray/tests/test_cli.py index eec7b71c0c4c0..9f826b26a1aac 100644 --- a/python/ray/tests/test_cli.py +++ b/python/ray/tests/test_cli.py @@ -919,7 +919,7 @@ def test_ray_status(shutdown_only, monkeypatch, enable_v2): def output_ready(): result = runner.invoke(scripts.status) - result.stdout + _ = result.stdout if not result.exception and "memory" in result.output: return True raise RuntimeError( @@ -967,7 +967,7 @@ def test_ray_status_multinode(ray_start_cluster, enable_v2): def output_ready(): result = runner.invoke(scripts.status) - result.stdout + _ = result.stdout if not result.exception and "memory" in result.output: return True raise RuntimeError( diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index f248658d8e3a8..a48a10f5d38d9 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -878,7 +878,7 @@ def child_func(self): assert ray.get(handle.child_func.remote()) == 42 with pytest.raises(AttributeError): # We should raise attribute error when accessing a non-existent func - SomeClass.nonexistent_func + _ = SomeClass.nonexistent_func def test_serialize_client_actor_handle(call_ray_start_shared): diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 3185cacc1a99f..0136a47b96f65 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -518,7 +518,7 @@ def test_export_large_objects(ray_start_regular, error_pubsub): @ray.remote def f(): - large_object + _ = large_object # Invoke the function so that the definition is exported. 
f.remote() @@ -531,7 +531,7 @@ def f(): @ray.remote class Foo: def __init__(self): - large_object + _ = large_object Foo.remote() diff --git a/python/ray/tests/test_logging_2.py b/python/ray/tests/test_logging_2.py index dd23f831537db..6201c45ef20b3 100644 --- a/python/ray/tests/test_logging_2.py +++ b/python/ray/tests/test_logging_2.py @@ -8,7 +8,7 @@ from ray._private.ray_logging.filters import CoreContextFilter from ray._private.ray_logging.formatters import JSONFormatter, TextFormatter -from ray.job_config import LoggingConfig +from ray._private.ray_logging.logging_config import LoggingConfig from ray._private.test_utils import run_string_as_driver @@ -188,6 +188,26 @@ def test_record_with_flatten_keys_valid_dict(self, shutdown_only): assert len(record_dict) == len(should_exist) assert "exc_text" not in record_dict + def test_record_with_valid_additional_log_standard_attrs(self, shutdown_only): + formatter = JSONFormatter() + formatter.set_additional_log_standard_attrs(["name"]) + record = logging.makeLogRecord({}) + formatted = formatter.format(record) + + record_dict = json.loads(formatted) + should_exist = [ + "asctime", + "levelname", + "message", + "filename", + "lineno", + "timestamp_ns", + "name", + ] + for key in should_exist: + assert key in record_dict + assert len(record_dict) == len(should_exist) + class TestTextFormatter: def test_record_with_user_provided_context(self): @@ -211,12 +231,24 @@ def test_record_with_exception(self): for s in ["INFO", "Test message", "test.py:1000", "--"]: assert s in formatted + def test_record_with_valid_additional_log_standard_attrs(self, shutdown_only): + formatter = TextFormatter() + formatter.set_additional_log_standard_attrs(["name"]) + record = logging.makeLogRecord({}) + formatted = formatter.format(record) + assert "name=" in formatted + def test_invalid_encoding(): with pytest.raises(ValueError): LoggingConfig(encoding="INVALID") +def test_invalid_additional_log_standard_attrs(): + with pytest.raises(ValueError): + LoggingConfig(additional_log_standard_attrs=["invalid"]) + + class TestTextModeE2E: def test_text_mode_task(self, shutdown_only): script = """ @@ -224,7 +256,7 @@ def test_text_mode_task(self, shutdown_only): import logging ray.init( - logging_config=ray.LoggingConfig(encoding="TEXT") + logging_config=ray.LoggingConfig(encoding="TEXT", additional_log_standard_attrs=["name"]) ) @ray.remote @@ -244,6 +276,7 @@ def f(): "task_id", "INFO", "This is a Ray task", + "name=", ] for s in should_exist: assert s in stderr @@ -255,7 +288,7 @@ def test_text_mode_actor(self, shutdown_only): import logging ray.init( - logging_config=ray.LoggingConfig(encoding="TEXT") + logging_config=ray.LoggingConfig(encoding="TEXT", additional_log_standard_attrs=["name"]) ) @ray.remote @@ -280,6 +313,7 @@ def print_message(self): "task_id", "INFO", "This is a Ray actor", + "name=", ] for s in should_exist: assert s in stderr @@ -290,7 +324,7 @@ def test_text_mode_driver(self, shutdown_only): import logging ray.init( - logging_config=ray.LoggingConfig(encoding="TEXT") + logging_config=ray.LoggingConfig(encoding="TEXT", additional_log_standard_attrs=["name"]) ) logger = logging.getLogger() @@ -304,6 +338,7 @@ def test_text_mode_driver(self, shutdown_only): "node_id", "INFO", "This is a Ray driver", + "name=", ] for s in should_exist: assert s in stderr diff --git a/python/ray/tests/test_memory_pressure.py b/python/ray/tests/test_memory_pressure.py index 1a6f82a02896a..9dbc7a72cccfb 100644 --- a/python/ray/tests/test_memory_pressure.py +++ 
b/python/ray/tests/test_memory_pressure.py @@ -42,7 +42,6 @@ def get_local_state_client(): port = int(node["NodeManagerPort"]) runtime_env_agent_port = int(node["RuntimeEnvAgentPort"]) client.register_raylet_client(node_id, ip, port, runtime_env_agent_port) - client.register_agent_client(node_id, ip, port) return client diff --git a/python/ray/tests/test_runtime_env_validation_bad_2_schema.json b/python/ray/tests/test_runtime_env_validation_bad_2_schema.json index 059b8ae9ecb15..93aca677e9c02 100644 --- a/python/ray/tests/test_runtime_env_validation_bad_2_schema.json +++ b/python/ray/tests/test_runtime_env_validation_bad_2_schema.json @@ -2,4 +2,3 @@ "$schema": "http://json-schema.org/draft-07/schema#", "$id": "http://github.com/ray-project/ray/runtime_env/working_dir_schema.json", "type": "string" - diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index e58b0d4d4a1f5..ac38f756766ed 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -628,7 +628,7 @@ def test_numpy_ufunc(ray_start_shared_local_modes): @ray.remote def f(): # add reference to the numpy ufunc - log + _ = log ray.get(f.remote()) @@ -747,7 +747,7 @@ def test(): @ray.remote def f(): - ref + _ = ref with pytest.raises(ray.exceptions.OufOfBandObjectRefSerializationException): ray.get(f.remote()) @@ -774,7 +774,7 @@ def test(): @ray.remote def f(): - ref + _ = ref ray.get(f.remote()) diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 11f9f620cabad..156af8ac3d199 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -1665,8 +1665,6 @@ def get_addr(): ) wait_for_condition(lambda: get_addr() is not None) - ip, http_port, grpc_port = json.loads(get_addr()) - client.register_agent_client(node_id, ip, grpc_port) result = await client.get_runtime_envs_info(node_id) assert isinstance(result, GetRuntimeEnvsInfoReply) @@ -1835,8 +1833,6 @@ def get_addr(): ) wait_for_condition(lambda: get_addr() is not None) - ip, http_port, grpc_port = json.loads(get_addr()) - client.register_agent_client(node_id, ip, grpc_port) @ray.remote class Actor: diff --git a/python/ray/tests/test_state_api_log.py b/python/ray/tests/test_state_api_log.py index 2aefa941b7137..d9031ced47638 100644 --- a/python/ray/tests/test_state_api_log.py +++ b/python/ray/tests/test_state_api_log.py @@ -4,7 +4,6 @@ import asyncio from typing import List import urllib -import re from unittest.mock import MagicMock, AsyncMock import pytest @@ -45,7 +44,7 @@ from ray.dashboard.tests.conftest import * # noqa from ray.util.state import get_log, list_logs, list_nodes, list_workers from ray.util.state.common import GetLogOptions -from ray.util.state.exception import DataSourceUnavailable, RayStateApiException +from ray.util.state.exception import RayStateApiException from ray.util.state.state_manager import StateDataSourceClient @@ -441,16 +440,15 @@ async def generate_logs_stream(num_chunks: int): async def test_logs_manager_list_logs(logs_manager): logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] + async def my_list_logs(node_id, glob_filter, timeout): + if node_id != "2": + raise ValueError("Agent for node id: 3 doesn't exist.") + return generate_list_logs(["gcs_server.out"]) - logs_client.list_logs.side_effect = [ - generate_list_logs(["gcs_server.out"]), - DataSourceUnavailable(), - ] + 
logs_client.list_logs = AsyncMock() + logs_client.list_logs.side_effect = my_list_logs - # Unregistered node id should raise a DataSourceUnavailable. - with pytest.raises(DataSourceUnavailable): + with pytest.raises(ValueError): result = await logs_manager.list_logs( node_id="3", timeout=30, glob_filter="*gcs*" ) @@ -459,12 +457,8 @@ async def test_logs_manager_list_logs(logs_manager): assert len(result) == 1 assert result["gcs_server"] == ["gcs_server.out"] assert result["raylet"] == [] - logs_client.get_all_registered_log_agent_ids.assert_called() - logs_client.list_logs.assert_awaited_with("2", "*gcs*", timeout=30) - # The second call raises DataSourceUnavailable, which will - # return DataSourceUnavailable to the caller. - with pytest.raises(DataSourceUnavailable): + with pytest.raises(ValueError): result = await logs_manager.list_logs( node_id="1", timeout=30, glob_filter="*gcs*" ) @@ -477,8 +471,6 @@ async def test_logs_manager_resolve_file(logs_manager): Test filename is given. """ logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = [node_id.hex()] expected_filename = "filename" res = await logs_manager.resolve_filename( node_id=node_id.hex(), @@ -699,8 +691,6 @@ async def test_logs_manager_stream_log(logs_manager): NUM_LOG_CHUNKS = 10 logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] logs_client.ip_to_node_id = MagicMock() logs_client.stream_log.return_value = generate_logs_stream(NUM_LOG_CHUNKS) @@ -771,8 +761,6 @@ async def test_logs_manager_keepalive_no_timeout(logs_manager): NUM_LOG_CHUNKS = 10 logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] logs_client.ip_to_node_id = MagicMock() logs_client.stream_log.return_value = generate_logs_stream(NUM_LOG_CHUNKS) @@ -1011,8 +999,8 @@ def verify(): with pytest.raises(requests.HTTPError) as e: list_logs(node_id=node_id) - assert re.match( - f"Given node id {node_id} is not available", e.value.response.json()["msg"] + assert ( + f"Agent for node id: {node_id} doesn't exist." 
in e.value.response.json()["msg"] ) diff --git a/python/ray/tests/tls/README b/python/ray/tests/tls/README index 0115bebbc7a10..b2e9637e805c2 100644 --- a/python/ray/tests/tls/README +++ b/python/ray/tests/tls/README @@ -24,4 +24,4 @@ openssl req \ openssl dhparam -out {str(tmp_path)}/tls/redis.dh 2048 -See https://github.com/ray-project/ray/pull/40378/ for more details \ No newline at end of file +See https://github.com/ray-project/ray/pull/40378/ for more details diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD index 2b8f54085077e..dbce685dd8f6e 100644 --- a/python/ray/train/BUILD +++ b/python/ray/train/BUILD @@ -271,6 +271,14 @@ py_test( deps = [":train_lib", ":conftest"] ) +py_test( + name = "test_api_migrations", + size = "small", + srcs = ["tests/test_api_migrations.py"], + tags = ["team:ml", "exclusive"], + deps = [":train_lib", ":conftest"] +) + py_test( name = "test_backend", size = "large", diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index f142685caaab7..e27f26700c252 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -669,7 +669,12 @@ def wrapper(*args, **kwargs): @PublicAPI(stability="stable") @_warn_session_misuse() -def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: +def report( + metrics: Dict, + *, + checkpoint: Optional[Checkpoint] = None, + checkpoint_dir_name: Optional[str] = None, +) -> None: """Report metrics and optionally save a checkpoint. If a checkpoint is provided, it will be @@ -750,6 +755,13 @@ def train_func(config): metrics: The metrics you want to report. checkpoint: The optional checkpoint you want to report. """ + if checkpoint_dir_name is not None: + logger.warning( + "`checkpoint_dir_name` is only supported in the new Ray Train " + "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. " + "This argument will be ignored." + ) + # If we are running in a Tune function, switch to `ray.tune.report`. from ray.tune.trainable.trainable_fn_utils import _in_tune_session @@ -760,7 +772,9 @@ def train_func(config): _log_deprecation_warning( "`ray.train.report` should be switched to " "`ray.tune.report` when running in a function " - "passed to Ray Tune. This will be an error in the future." + "passed to Ray Tune. This will be an error in the future. " + "See this issue for more context: " + "https://github.com/ray-project/ray/issues/49454" ) return ray.tune.report(metrics, checkpoint=checkpoint) @@ -820,7 +834,9 @@ def train_func(config): _log_deprecation_warning( "`ray.train.get_checkpoint` should be switched to " "`ray.tune.get_checkpoint` when running in a function " - "passed to Ray Tune. This will be an error in the future." + "passed to Ray Tune. This will be an error in the future. 
" + "See this issue for more context: " + "https://github.com/ray-project/ray/issues/49454" ) return ray.tune.get_checkpoint() diff --git a/python/ray/train/_internal/syncer.py b/python/ray/train/_internal/syncer.py index 4413e92452950..8321552e5fbaa 100644 --- a/python/ray/train/_internal/syncer.py +++ b/python/ray/train/_internal/syncer.py @@ -4,11 +4,10 @@ import time import traceback from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple from ray._private.thirdparty.tabulate.tabulate import tabulate -from ray.train.constants import _DEPRECATED_VALUE -from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.annotations import Deprecated, DeveloperAPI from ray.widgets import Template logger = logging.getLogger(__name__) @@ -20,79 +19,13 @@ DEFAULT_SYNC_TIMEOUT = 1800 -@PublicAPI(stability="stable") +@Deprecated @dataclass class SyncConfig: - """Configuration object for Train/Tune file syncing to `RunConfig(storage_path)`. - - In Ray Train/Tune, here is where syncing (mainly uploading) happens: - - The experiment driver (on the head node) syncs the experiment directory to storage - (which includes experiment state such as searcher state, the list of trials - and their statuses, and trial metadata). - - It's also possible to sync artifacts from the trial directory to storage - by setting `sync_artifacts=True`. - For a Ray Tune run with many trials, each trial will upload its trial directory - to storage, which includes arbitrary files that you dumped during the run. - For a Ray Train run doing distributed training, each remote worker will similarly - upload its trial directory to storage. - - See :ref:`persistent-storage-guide` for more details and examples. - - Args: - sync_period: Minimum time in seconds to wait between two sync operations. - A smaller ``sync_period`` will have the data in storage updated more often - but introduces more syncing overhead. Defaults to 5 minutes. - sync_timeout: Maximum time in seconds to wait for a sync process - to finish running. A sync operation will run for at most this long - before raising a `TimeoutError`. Defaults to 30 minutes. - sync_artifacts: [Beta] Whether or not to sync artifacts that are saved to the - trial directory (accessed via `train.get_context().get_trial_dir()`) - to the persistent storage configured via `train.RunConfig(storage_path)`. - The trial or remote worker will try to launch an artifact syncing - operation every time `train.report` happens, subject to `sync_period` - and `sync_artifacts_on_checkpoint`. - Defaults to False -- no artifacts are persisted by default. - sync_artifacts_on_checkpoint: If True, trial/worker artifacts are - forcefully synced on every reported checkpoint. - This only has an effect if `sync_artifacts` is True. - Defaults to True. - """ - sync_period: int = DEFAULT_SYNC_PERIOD sync_timeout: int = DEFAULT_SYNC_TIMEOUT sync_artifacts: bool = False sync_artifacts_on_checkpoint: bool = True - upload_dir: Optional[str] = _DEPRECATED_VALUE - syncer: Optional[Union[str, "Syncer"]] = _DEPRECATED_VALUE - sync_on_checkpoint: bool = _DEPRECATED_VALUE - - # TODO(justinvyu): [Deprecated] Remove in 2.11. - def _deprecation_warning(self, attr_name: str, extra_msg: str): - if getattr(self, attr_name) != _DEPRECATED_VALUE: - raise DeprecationWarning( - f"`SyncConfig({attr_name})` is a deprecated configuration " - "Please remove it from your `SyncConfig`. 
" - f"{extra_msg}" - ) - - def __post_init__(self): - for attr_name, extra_msg in [ - ( - "upload_dir", - "\nPlease specify `ray.train.RunConfig(storage_path)` instead.", - ), - ( - "syncer", - "\nPlease implement custom syncing logic with a custom " - "`pyarrow.fs.FileSystem` instead, and pass it into " - "`ray.train.RunConfig(storage_filesystem)`. " - "See here: https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html#custom-storage", # noqa: E501 - ), - ("sync_on_checkpoint", ""), - ]: - self._deprecation_warning(attr_name, extra_msg) def _repr_html_(self) -> str: """Generate an HTML representation of the SyncConfig.""" diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 04cac51c1ee06..ffd215e84bd4d 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -26,8 +26,13 @@ _exists_at_fs_path, get_fs_and_path, ) -from ray.util import PublicAPI -from ray.util.annotations import DeveloperAPI +from ray.train.constants import ( + _v2_migration_warnings_enabled, + V2_MIGRATION_GUIDE_MESSAGE, +) +from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE +from ray.train.utils import _log_deprecation_warning +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI if TYPE_CHECKING: from ray.data import Dataset @@ -51,6 +56,16 @@ "https://docs.ray.io/en/master/train/user-guides/data-loading-preprocessing.html#preprocessing-structured-data " # noqa:E501 ) +_TRAINER_RESTORE_DEPRECATION_WARNING = ( + "The `restore` and `can_restore` APIs are deprecated and " + f"will be removed in a future release. {V2_MIGRATION_GUIDE_MESSAGE}" +) + +_RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING = ( + "`resume_from_checkpoint` is deprecated and will be removed in an upcoming " + f"release. {V2_MIGRATION_GUIDE_MESSAGE}" +) + @PublicAPI(stability="beta") class TrainingFailedError(RuntimeError): @@ -90,7 +105,7 @@ def _train_coordinator_fn( trainer = trainer_cls(**config) # Get the checkpoint from Tune and pass it to workers later on. - checkpoint = ray.train.get_checkpoint() + checkpoint = ray.tune.get_checkpoint() if checkpoint: # Set `starting_checkpoint` for auto-recovery fault-tolerance # as well as manual restoration. @@ -237,6 +252,12 @@ def __init__( self.datasets = datasets if datasets is not None else {} self.starting_checkpoint = resume_from_checkpoint + if _v2_migration_warnings_enabled(): + if metadata is not None: + _log_deprecation_warning(_GET_METADATA_DEPRECATION_MESSAGE) + if resume_from_checkpoint is not None: + _log_deprecation_warning(_RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING) + # These attributes should only be set through `BaseTrainer.restore` self._restore_path = None self._restore_storage_filesystem = None @@ -245,8 +266,8 @@ def __init__( air_usage.tag_air_trainer(self) - @PublicAPI(stability="alpha") @classmethod + @Deprecated(message=_TRAINER_RESTORE_DEPRECATION_WARNING) def restore( cls: Type["BaseTrainer"], path: Union[str, os.PathLike], @@ -345,6 +366,9 @@ def training_loop(self): Returns: BaseTrainer: A restored instance of the class that is calling this method. """ + if _v2_migration_warnings_enabled(): + _log_deprecation_warning(_TRAINER_RESTORE_DEPRECATION_WARNING) + if not cls.can_restore(path, storage_filesystem): raise ValueError( f"Invalid restore path: {path}. 
Make sure that this path exists and " @@ -401,8 +425,11 @@ def training_loop(self): trainer._restore_storage_filesystem = storage_filesystem return trainer - @PublicAPI(stability="alpha") @classmethod + @Deprecated( + message=_TRAINER_RESTORE_DEPRECATION_WARNING, + warning=_v2_migration_warnings_enabled(), + ) def can_restore( cls: Type["BaseTrainer"], path: Union[str, os.PathLike], @@ -418,6 +445,9 @@ def can_restore( Returns: bool: Whether this path exists and contains the trainer state to resume from """ + if _v2_migration_warnings_enabled(): + _log_deprecation_warning(_TRAINER_RESTORE_DEPRECATION_WARNING) + fs, fs_path = get_fs_and_path(path, storage_filesystem) trainer_pkl_path = Path(fs_path, _TRAINER_PKL).as_posix() return _exists_at_fs_path(fs, trainer_pkl_path) @@ -505,6 +535,58 @@ def _validate_attributes(self): f"with value `{self.starting_checkpoint}`." ) + self._log_v2_deprecation_warnings() + + def _log_v2_deprecation_warnings(self): + """Logs deprecation warnings for v2 migration. + + Log them here in the Ray Train case rather than in the configuration + constructors to avoid logging incorrect deprecation warnings when + `ray.train.RunConfig` is passed to Ray Tune. + """ + + if not _v2_migration_warnings_enabled(): + return + + from ray.train.v2._internal.migration_utils import ( + FAIL_FAST_DEPRECATION_MESSAGE, + TRAINER_RESOURCES_DEPRECATION_MESSAGE, + VERBOSE_DEPRECATION_MESSAGE, + LOG_TO_FILE_DEPRECATION_MESSAGE, + STOP_DEPRECATION_MESSAGE, + CALLBACKS_DEPRECATION_MESSAGE, + PROGRESS_REPORTER_DEPRECATION_MESSAGE, + SYNC_CONFIG_DEPRECATION_MESSAGE, + ) + + # ScalingConfig deprecations + if self.scaling_config.trainer_resources is not None: + _log_deprecation_warning(TRAINER_RESOURCES_DEPRECATION_MESSAGE) + + # FailureConfig deprecations + if self.run_config.failure_config.fail_fast: + _log_deprecation_warning(FAIL_FAST_DEPRECATION_MESSAGE) + + # RunConfig deprecations + # NOTE: _verbose is the original verbose value passed by the user + if self.run_config._verbose is not None: + _log_deprecation_warning(VERBOSE_DEPRECATION_MESSAGE) + + if self.run_config.log_to_file: + _log_deprecation_warning(LOG_TO_FILE_DEPRECATION_MESSAGE) + + if self.run_config.stop is not None: + _log_deprecation_warning(STOP_DEPRECATION_MESSAGE) + + if self.run_config.callbacks is not None: + _log_deprecation_warning(CALLBACKS_DEPRECATION_MESSAGE) + + if self.run_config.progress_reporter is not None: + _log_deprecation_warning(PROGRESS_REPORTER_DEPRECATION_MESSAGE) + + if self.run_config.sync_config != ray.train.SyncConfig(): + _log_deprecation_warning(SYNC_CONFIG_DEPRECATION_MESSAGE) + @classmethod def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig: """Returns scaling config dataclass after validating updated keys.""" diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 4e42e30f4370a..f043d43a4d3fb 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -45,6 +45,25 @@ def _get_ray_train_session_dir() -> str: # Deprecated configs can use this value to detect if the user has set it. _DEPRECATED_VALUE = "DEPRECATED" + +# ================================================== +# Train V2 constants +# ================================================== + +# Set this to 1 to enable deprecation warnings for V2 migration. 
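Note: the v2 deprecation warnings introduced in this PR are opt-in, all gated behind `_v2_migration_warnings_enabled()` and the environment variable defined on the next lines. A minimal opt-in sketch (illustrative only, not part of this diff; `env_bool` treats `"1"` as truthy):

    import os

    # Opt in before constructing a trainer; the flag is read at call time
    # by _v2_migration_warnings_enabled().
    os.environ["RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"] = "1"

    from ray.train.constants import _v2_migration_warnings_enabled

    assert _v2_migration_warnings_enabled()  # deprecation warnings now fire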
+ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"
+
+
+V2_MIGRATION_GUIDE_MESSAGE = (
+    "See this issue for more context and migration options: "
+    "https://github.com/ray-project/ray/issues/49454"
+)
+
+
+def _v2_migration_warnings_enabled() -> bool:
+    return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, False)
+
+
 # ==================================================
 # Environment Variables
 # ==================================================
@@ -91,13 +110,6 @@ def _get_ray_train_session_dir() -> str:
 # Defaults to 0
 RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"
 
-# Set this to 1 to enable deprecation warnings for V2 migration.
-ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"
-
-
-def _v2_migration_warnings_enabled() -> bool:
-    return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, False)
-
 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
diff --git a/python/ray/train/context.py b/python/ray/train/context.py
index bc447b36f2024..76bcb22e83c12 100644
--- a/python/ray/train/context.py
+++ b/python/ray/train/context.py
@@ -3,7 +3,10 @@
 
 from ray.train._internal import session
 from ray.train._internal.storage import StorageContext
-from ray.train.constants import _v2_migration_warnings_enabled
+from ray.train.constants import (
+    _v2_migration_warnings_enabled,
+    V2_MIGRATION_GUIDE_MESSAGE,
+)
 from ray.train.utils import _copy_doc, _log_deprecation_warning
 from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
 
@@ -19,15 +22,16 @@
 _GET_METADATA_DEPRECATION_MESSAGE = (
     "`get_metadata` was an experimental API that accessed the metadata passed "
     "to `Trainer(metadata=...)`. This API can be replaced by passing "
-    "the metadata directly to the training function (e.g., via `train_loop_config`)."
+    "the metadata directly to the training function (e.g., via `train_loop_config`). "
+    f"{V2_MIGRATION_GUIDE_MESSAGE}"
 )
 
 _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
     "`{}` is deprecated because the concept of a `Trial` will "
-    "soon be removed in Ray Train (see here: "
-    "https://github.com/ray-project/enhancements/pull/57). "
+    "soon be removed in Ray Train. "
     "Ray Train will no longer assume that it's running within a Ray Tune `Trial` "
-    "in the future."
+    "in the future. "
+    f"{V2_MIGRATION_GUIDE_MESSAGE}"
 )
 
 
@@ -127,7 +131,8 @@ def get_context() -> TrainContext:
         _log_deprecation_warning(
             "`ray.train.get_context()` should be switched to "
             "`ray.tune.get_context()` when running in a function "
-            "passed to Ray Tune. This will be an error in the future."
+            "passed to Ray Tune. This will be an error in the future. 
" + f"{V2_MIGRATION_GUIDE_MESSAGE}" ) return get_tune_context() diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index a14dc47d36dd3..bcfd015272fb5 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -12,9 +12,10 @@ from ray.train._internal.data_config import DataConfig from ray.train._internal.session import _TrainingResult, get_session from ray.train._internal.utils import construct_train_func, count_required_parameters +from ray.train.base_trainer import _TRAINER_RESTORE_DEPRECATION_WARNING from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING from ray.train.trainer import BaseTrainer, GenDataset -from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.annotations import Deprecated, DeveloperAPI from ray.widgets import Template from ray.widgets.util import repr_with_fallback @@ -273,8 +274,8 @@ def __init__( get_or_create_state_actor() - @PublicAPI(stability="beta") @classmethod + @Deprecated(message=_TRAINER_RESTORE_DEPRECATION_WARNING) def restore( cls: Type["DataParallelTrainer"], path: str, diff --git a/python/ray/train/examples/deepspeed/deepspeed_torch_trainer.py b/python/ray/train/examples/deepspeed/deepspeed_torch_trainer.py index 8a2e7a2500541..ed46d969d813b 100644 --- a/python/ray/train/examples/deepspeed/deepspeed_torch_trainer.py +++ b/python/ray/train/examples/deepspeed/deepspeed_torch_trainer.py @@ -180,6 +180,6 @@ def collate_fn(batch): result = trainer.fit() # Retrieve the best checkponints from results - result.best_checkpoints + _ = result.best_checkpoints # __deepspeed_torch_basic_example_end__ diff --git a/python/ray/train/examples/deepspeed/deepspeed_torch_trainer_no_raydata.py b/python/ray/train/examples/deepspeed/deepspeed_torch_trainer_no_raydata.py index 5ba8bf075c583..e78d0b3df49f5 100644 --- a/python/ray/train/examples/deepspeed/deepspeed_torch_trainer_no_raydata.py +++ b/python/ray/train/examples/deepspeed/deepspeed_torch_trainer_no_raydata.py @@ -173,6 +173,6 @@ def collate_fn(batch): result = trainer.fit() # Retrieve the best checkponints from results - result.best_checkpoints + _ = result.best_checkpoints # __deepspeed_torch_basic_example_no_raydata_end__ diff --git a/python/ray/train/tests/test_api_migrations.py b/python/ray/train/tests/test_api_migrations.py new file mode 100644 index 0000000000000..dd2a941b8eae4 --- /dev/null +++ b/python/ray/train/tests/test_api_migrations.py @@ -0,0 +1,137 @@ +import sys +import warnings + +import pytest + +import ray.train +import ray.tune +from ray.train.constants import ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR +from ray.train.data_parallel_trainer import DataParallelTrainer +from ray.util.annotations import RayDeprecationWarning + + +@pytest.fixture(autouse=True) +def enable_v2_migration_deprecation_messages(monkeypatch): + monkeypatch.setenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, "1") + yield + monkeypatch.delenv(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR) + + +def test_trainer_restore(): + with pytest.warns(RayDeprecationWarning, match="restore"): + try: + DataParallelTrainer.restore("dummy") + except Exception: + pass + + with pytest.warns(RayDeprecationWarning, match="can_restore"): + try: + DataParallelTrainer.can_restore("dummy") + except Exception: + pass + + +def test_trainer_valid_configs(tmp_path): + with warnings.catch_warnings(): + DataParallelTrainer( + lambda _: None, + scaling_config=ray.train.ScalingConfig(num_workers=1), + run_config=ray.train.RunConfig( + storage_path=tmp_path, + 
failure_config=ray.train.FailureConfig(max_failures=1), + ), + ) + + +def test_trainer_deprecated_configs(): + with pytest.warns(RayDeprecationWarning, match="metadata"): + DataParallelTrainer( + lambda _: None, + metadata={"dummy": "dummy"}, + ) + + with pytest.warns(RayDeprecationWarning, match="resume_from_checkpoint"): + DataParallelTrainer( + lambda _: None, + resume_from_checkpoint=ray.train.Checkpoint.from_directory("dummy"), + ) + + with pytest.warns(RayDeprecationWarning, match="fail_fast"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig( + failure_config=ray.train.FailureConfig(fail_fast=True) + ), + ) + + with pytest.warns(RayDeprecationWarning, match="trainer_resources"): + DataParallelTrainer( + lambda _: None, + scaling_config=ray.train.ScalingConfig(trainer_resources={"CPU": 1}), + ) + + with pytest.warns(RayDeprecationWarning, match="verbose"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig(verbose=True), + ) + + with pytest.warns(RayDeprecationWarning, match="log_to_file"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig(log_to_file=True), + ) + + with pytest.warns(RayDeprecationWarning, match="stop"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig(stop={"training_iteration": 1}), + ) + + with pytest.warns(RayDeprecationWarning, match="callbacks"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig(callbacks=[ray.tune.Callback()]), + ) + + with pytest.warns(RayDeprecationWarning, match="progress_reporter"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig( + progress_reporter=ray.tune.ProgressReporter() + ), + ) + + with pytest.warns(RayDeprecationWarning, match="sync_config"): + DataParallelTrainer( + lambda _: None, + run_config=ray.train.RunConfig( + sync_config=ray.train.SyncConfig(sync_artifacts=True) + ), + ) + + +def test_train_context_deprecations(ray_start_4_cpus, tmp_path): + def train_fn_per_worker(config): + with pytest.warns(RayDeprecationWarning, match="get_trial_dir"): + ray.train.get_context().get_trial_dir() + + with pytest.warns(RayDeprecationWarning, match="get_trial_id"): + ray.train.get_context().get_trial_id() + + with pytest.warns(RayDeprecationWarning, match="get_trial_name"): + ray.train.get_context().get_trial_name() + + with pytest.warns(RayDeprecationWarning, match="get_trial_resources"): + ray.train.get_context().get_trial_resources() + + trainer = DataParallelTrainer( + train_fn_per_worker, + scaling_config=ray.train.ScalingConfig(num_workers=1), + run_config=ray.train.RunConfig(storage_path=tmp_path), + ) + trainer.fit() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/train/tests/test_base_trainer.py b/python/ray/train/tests/test_base_trainer.py index 5ce0ae6ccda0a..1c14b6c9b8a18 100644 --- a/python/ray/train/tests/test_base_trainer.py +++ b/python/ray/train/tests/test_base_trainer.py @@ -175,7 +175,7 @@ def test_large_params(ray_start_4_cpus): huge_array = np.zeros(shape=int(1e8)) def training_loop(self): - huge_array + _ = huge_array trainer = DummyTrainer(training_loop) trainer.fit() diff --git a/python/ray/train/utils.py b/python/ray/train/utils.py index 98b11f1f6091f..395cd01ab9682 100644 --- a/python/ray/train/utils.py +++ b/python/ray/train/utils.py @@ -11,7 +11,7 @@ def wrapped(func): return wrapped -def _log_deprecation_warning(message): +def _log_deprecation_warning(message: str): warnings.warn( message, 
RayDeprecationWarning, diff --git a/python/ray/train/v2/_internal/execution/callback.py b/python/ray/train/v2/_internal/execution/callback.py index a2a6ada03348b..dba5332c4854a 100644 --- a/python/ray/train/v2/_internal/execution/callback.py +++ b/python/ray/train/v2/_internal/execution/callback.py @@ -86,7 +86,6 @@ def before_controller_execute_failure_decision( def before_controller_execute_scaling_decision( self, scaling_decision: "ScalingDecision", - worker_group_status: "WorkerGroupStatus", ): """Called before the controller executes a scaling decision.""" pass diff --git a/python/ray/train/v2/_internal/execution/context.py b/python/ray/train/v2/_internal/execution/context.py index 5b7dae52a46ad..e8d754fe9e027 100644 --- a/python/ray/train/v2/_internal/execution/context.py +++ b/python/ray/train/v2/_internal/execution/context.py @@ -81,28 +81,45 @@ class TrainContext(TrainRunContext): @_copy_doc(session.get_metadata) def get_metadata(self) -> Dict[str, Any]: - raise NotImplementedError + from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE - @_copy_doc(session.get_experiment_name) - def get_experiment_name(self) -> str: - # TODO: Resolve run_config.name if it is None - return self.run_config.name + raise DeprecationWarning(_GET_METADATA_DEPRECATION_MESSAGE) @_copy_doc(session.get_trial_name) def get_trial_name(self) -> str: - raise NotImplementedError + from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE + + raise DeprecationWarning( + _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_name") + ) @_copy_doc(session.get_trial_id) def get_trial_id(self) -> str: - raise NotImplementedError + from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE + + raise DeprecationWarning( + _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_id") + ) @_copy_doc(session.get_trial_resources) def get_trial_resources(self): - raise NotImplementedError + from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE + + raise DeprecationWarning( + _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_resources") + ) @_copy_doc(session.get_trial_dir) def get_trial_dir(self) -> str: - raise NotImplementedError + from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE + + raise DeprecationWarning( + _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_dir") + ) + + @_copy_doc(session.get_experiment_name) + def get_experiment_name(self) -> str: + return self.run_config.name @_copy_doc(session.get_world_size) def get_world_size(self) -> int: @@ -128,6 +145,7 @@ def get_node_rank(self) -> int: def get_storage(self): return self.storage_context + # TODO: Don't allow these private methods to be called from user code. 
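Note: after the v2 `TrainContext` changes above, the Tune-specific getters raise a `DeprecationWarning` (an `Exception` subclass) rather than warning and returning a value. A hedged sketch of what worker code would observe when the v2 code path is active (this assumes `RAY_TRAIN_V2_ENABLED=1` routes `get_context()` to this class):

    import ray.train

    def train_fn_per_worker(config):
        try:
            ray.train.get_context().get_trial_name()
        except DeprecationWarning as exc:
            # Raised eagerly by the v2 TrainContext.get_trial_name above.
            print(f"Tune-specific context API was removed: {exc}")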
def get_result_queue(self): return self.execution_context.result_queue diff --git a/python/ray/train/v2/_internal/execution/controller/__init__.py b/python/ray/train/v2/_internal/execution/controller/__init__.py new file mode 100644 index 0000000000000..f28117ecf45e4 --- /dev/null +++ b/python/ray/train/v2/_internal/execution/controller/__init__.py @@ -0,0 +1,3 @@ +from .controller import TrainController + +__all__ = ["TrainController"] diff --git a/python/ray/train/v2/_internal/execution/controller.py b/python/ray/train/v2/_internal/execution/controller/controller.py similarity index 57% rename from python/ray/train/v2/_internal/execution/controller.py rename to python/ray/train/v2/_internal/execution/controller/controller.py index c4b65dc6f6e2f..31fdcb7d8d99d 100644 --- a/python/ray/train/v2/_internal/execution/controller.py +++ b/python/ray/train/v2/_internal/execution/controller/controller.py @@ -1,12 +1,13 @@ import logging import os import time -from enum import Enum -from pathlib import Path +import uuid +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional +import pandas as pd + from ray._private.auto_init_hook import wrap_auto_init -from ray.train import Checkpoint from ray.train.v2._internal.constants import ( DEFAULT_HEALTH_CHECK_INTERVAL_S, HEALTH_CHECK_INTERVAL_S_ENV_VAR, @@ -30,16 +31,28 @@ ReportCallbackHandler, ) from ray.train.v2._internal.execution.context import TrainRunContext +from ray.train.v2._internal.execution.controller.state import ( + ErroredState, + FinishedState, + InitializingState, + ReschedulingState, + ResizingState, + RestartingState, + RunningState, + SchedulingState, + TrainControllerState, +) from ray.train.v2._internal.execution.failure_handling import ( FailureDecision, FailurePolicy, ) from ray.train.v2._internal.execution.scaling_policy import ( + NoopDecision, ResizeDecision, ScalingDecision, ScalingPolicy, ) -from ray.train.v2._internal.execution.storage import StorageContext, get_fs_and_path +from ray.train.v2._internal.execution.storage import StorageContext from ray.train.v2._internal.execution.worker_group import WorkerGroup, WorkerGroupStatus from ray.train.v2._internal.logging.logging import configure_controller_logger from ray.train.v2._internal.util import time_monotonic @@ -49,26 +62,24 @@ logger = logging.getLogger(__name__) -class TrainControllerState(Enum): - """The possible states that the training controller can be in - while running the main execution control loop. +@dataclass +class TrainControllerLoopIterationResult: + """The result of a single iteration of the control loop.""" - States: - RUNNING: The training controller is actively running training tasks. - RECOVERING: The training controller is in the process of recovering - from an error. - INITIALIZING: The train controller is starting up. - This is always the initial state of the controller. - ERRORED: A terminal state indicating that training has encountered - an error and cannot continue. - FINISHED: A terminal state indicating that training has completed. 
- """ + run_attempt_id: str + previous_state: TrainControllerState + next_state: TrainControllerState + training_failed_error: Optional[TrainingFailedError] = None - RUNNING = "RUNNING" - INITIALIZING = "INITIALIZING" - RECOVERING = "RECOVERING" - ERRORED = "ERRORED" - FINISHED = "FINISHED" + def __repr__(self) -> str: + return ( + f"TrainControllerLoopIterationResult(\n" + f" run_attempt_id={self.run_attempt_id},\n" + f" previous_state={self.previous_state._state_type.state_name},\n" + f" next_state={self.next_state._state_type.state_name}\n" + f" training_failed_error={self.training_failed_error}\n" + f")" + ) class TrainController: @@ -91,8 +102,6 @@ def __init__( scaling_policy: ScalingPolicy, failure_policy: FailurePolicy, callbacks: Optional[List[RayTrainCallback]] = None, - # TODO: [Deprecation] - resume_from_checkpoint: Optional[Checkpoint] = None, ): self._train_run_context = train_run_context configure_controller_logger(self._train_run_context) @@ -101,7 +110,6 @@ def __init__( self._failure_policy = failure_policy self._run_config = self._train_run_context.run_config self._callbacks = callbacks or [] - self._resume_from_checkpoint = resume_from_checkpoint self._storage_context = StorageContext( storage_path=self._run_config.storage_path, experiment_dir_name=self._run_config.name, @@ -133,47 +141,65 @@ def __init__( ) ] + self._health_check_interval_s = float( + os.getenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, DEFAULT_HEALTH_CHECK_INTERVAL_S) + ) + self._worker_group = self.worker_group_cls( train_run_context=self._train_run_context, callbacks=worker_group_callbacks_to_propagate, ) - self._state = TrainControllerState.INITIALIZING + self._state = InitializingState() + # TODO: These can be attributes of a RunAttempt? self._latest_poll_time = float("-inf") - self._health_check_interval_s = float( - os.getenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, DEFAULT_HEALTH_CHECK_INTERVAL_S) - ) - self._training_failed_error: Optional[TrainingFailedError] = None def _execute_scaling_decision( - self, decision: ScalingDecision, worker_group_status: WorkerGroupStatus - ): + self, decision: ScalingDecision + ) -> TrainControllerState: """Executes scaling decisions.""" for callback in self._controller_callbacks: - callback.before_controller_execute_scaling_decision( - decision, worker_group_status - ) + callback.before_controller_execute_scaling_decision(decision) if isinstance(decision, ResizeDecision): - self._restart_worker_group( + # TODO: Add more control over restart vs. shutdown + start. + worker_group_restarted = self._restart_worker_group( num_workers=decision.num_workers, resources_per_worker=decision.resources_per_worker, ) + if worker_group_restarted: + next_state = RunningState() + else: + next_state = ReschedulingState() + + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=self._state, + next_state=next_state, + ) + def _execute_failure_decision( self, failure_decision: FailureDecision, worker_group_status: WorkerGroupStatus - ): + ) -> TrainControllerState: """Executes failure handling decisions (ex: restart, terminate).""" assert worker_group_status.errors + controller_state = self.get_state() + for callback in self._controller_callbacks: callback.before_controller_execute_failure_decision( failure_decision, worker_group_status ) + # TODO: What should we do here? + # This currently never happens because there must be errors. 
if failure_decision == FailureDecision.NOOP: - assert self._state == TrainControllerState.RUNNING - return + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=RunningState(), + ) errors_str = "\n".join( [ @@ -188,19 +214,36 @@ def _execute_failure_decision( f"failures on {len(worker_group_status.errors)} worker(s):\n" f"{errors_str}" ) - # Shutdown the worker group so that we don't keep polling errored tasks. - self._worker_group.shutdown() - self._set_state(TrainControllerState.RECOVERING) + training_failed_error = TrainingFailedError( + worker_failures=worker_group_status.errors + ) + next_state = RestartingState( + training_failed_error=training_failed_error, + ) + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=next_state, + training_failed_error=training_failed_error, + ) elif failure_decision == FailureDecision.RAISE: logger.error( "Terminating training worker group after encountering " f"failure(s) on {len(worker_group_status.errors)} worker(s):\n" f"{errors_str}" ) - self._set_state(TrainControllerState.ERRORED) - self._training_failed_error = TrainingFailedError( + training_failed_error = TrainingFailedError( worker_failures=worker_group_status.errors ) + next_state = ErroredState( + training_failed_error=training_failed_error, + ) + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=next_state, + training_failed_error=training_failed_error, + ) else: raise ValueError(f"Unexpected failure decision: {failure_decision}") @@ -217,8 +260,14 @@ def _poll_workers(self) -> WorkerGroupStatus: self._latest_poll_time = time_monotonic() return status - def _restart_worker_group(self, num_workers: int, resources_per_worker: dict): - """Restart the worker group and launch the train function.""" + def _restart_worker_group( + self, num_workers: int, resources_per_worker: dict + ) -> bool: + """Restart the worker group and launch the train function. + + Returns: + True if the worker group was successfully restarted, False otherwise. + """ self._worker_group.shutdown() # If there's a latest checkpoint that's been committed, @@ -238,7 +287,7 @@ def _restart_worker_group(self, num_workers: int, resources_per_worker: dict): num_workers=num_workers, resources_per_worker=resources_per_worker, placement_strategy=placement_strategy, - checkpoint=latest_checkpoint or self._resume_from_checkpoint, + checkpoint=latest_checkpoint, ) except (WorkerGroupStartupTimeoutError, WorkerGroupStartupFailedError) as e: logger.error( @@ -249,11 +298,10 @@ def _restart_worker_group(self, num_workers: int, resources_per_worker: dict): # TODO: Should this logic go through the failure policy? # The current logic will always try recovering unconditionally # on startup errors without a retry limit. - self._set_state(TrainControllerState.RECOVERING) - return + return False # TODO: Consider starting the worker group asynchronously. - self._set_state(TrainControllerState.RUNNING) + return True def _start(self): for callback in self._controller_callbacks: @@ -278,78 +326,132 @@ def _set_state(self, state: TrainControllerState): for callback in self._controller_callbacks: callback.after_controller_state_update(previous_state, state) - def _run_control_loop_iteration(self): + def _step(self) -> TrainControllerLoopIterationResult: """Run a single iteration of the control loop. 
- Steps: - 1. Poll the worker group for status. - 2. If the worker group is initializing or recovering from an error, - make a scaling decision and execute it. - 3. If the worker group has finished, set the controller state to FINISHED. - 4. If the worker group has errors, make a failure decision and execute it. - 5. Otherwise, the worker group is running healthily. - Query the scaling policy for a scaling decision and execute it. + Returns: + The result of the iteration. """ - assert self.get_state() in ( - TrainControllerState.RUNNING, - TrainControllerState.RECOVERING, - TrainControllerState.INITIALIZING, - ), self.get_state() - - worker_group_status = self._poll_workers() - - if worker_group_status.finished and not worker_group_status.errors: - self._set_state(TrainControllerState.FINISHED) - return - - if self.get_state() in ( - TrainControllerState.INITIALIZING, - TrainControllerState.RECOVERING, - ): + controller_state = self.get_state() + + if isinstance(controller_state, InitializingState): scaling_decision = ( - self._scaling_policy.make_decision_for_non_running_worker_group( - worker_group_status - ) + self._scaling_policy.make_decision_for_non_running_worker_group() + ) + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=SchedulingState(scaling_decision), + ) + elif isinstance(controller_state, SchedulingState): + return self._execute_scaling_decision(controller_state.scaling_decision) + elif isinstance(controller_state, ReschedulingState): + scaling_decision = ( + self._scaling_policy.make_decision_for_non_running_worker_group() + ) + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=SchedulingState(scaling_decision), ) - self._execute_scaling_decision(scaling_decision, worker_group_status) - elif self.get_state() == TrainControllerState.RUNNING: + elif isinstance(controller_state, RunningState): + worker_group_status = self._poll_workers() + + if worker_group_status.finished and not worker_group_status.errors: + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=FinishedState(), + ) if worker_group_status.errors: failure_decision = self._failure_policy.make_decision( worker_group_status ) - self._execute_failure_decision(failure_decision, worker_group_status) + return self._execute_failure_decision( + failure_decision, worker_group_status + ) else: scaling_decision = ( self._scaling_policy.make_decision_for_running_worker_group( worker_group_status ) ) - self._execute_scaling_decision(scaling_decision, worker_group_status) + + if isinstance(scaling_decision, ResizeDecision): + next_state = ResizingState( + scaling_decision=scaling_decision, + ) + elif isinstance(scaling_decision, NoopDecision): + next_state = RunningState() + else: + raise ValueError(f"Unexpected scaling decision: {scaling_decision}") + + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=next_state, + ) + elif isinstance(controller_state, RestartingState): + scaling_decision = ( + self._scaling_policy.make_decision_for_non_running_worker_group() + ) + return TrainControllerLoopIterationResult( + run_attempt_id=self._get_run_attempt_id(), + previous_state=controller_state, + next_state=SchedulingState(scaling_decision=scaling_decision), + ) + elif 
isinstance(controller_state, ResizingState):
+            return TrainControllerLoopIterationResult(
+                run_attempt_id=self._get_run_attempt_id(),
+                previous_state=controller_state,
+                next_state=SchedulingState(
+                    scaling_decision=controller_state.scaling_decision
+                ),
+            )
+        else:
+            raise ValueError(f"Unexpected controller state: {controller_state}")
+
+    def _generate_run_attempt_id(self):
+        self._run_attempt_id = uuid.uuid4().hex
+        return self._run_attempt_id
+
+    def _get_run_attempt_id(self):
+        return self._run_attempt_id
+
+    def _run_control_loop_iteration(self):
+        """Run a single iteration of the control loop.
+
+        Steps:
+        1. Generate a new run attempt id if the current state needs to
+            start a new run attempt (INITIALIZING, RESTARTING, RESIZING).
+        2. Advance the state machine by one `_step`, which schedules workers
+            for a non-running worker group, or polls a running worker group
+            and executes scaling/failure decisions.
+        3. Transition the controller to the iteration result's `next_state`.
+        """
+        controller_state = self.get_state()
+        assert not controller_state.is_terminal()
+
+        if controller_state.needs_new_run_attempt():
+            self._generate_run_attempt_id()
+
+        result = self._step()
+
+        self._set_state(result.next_state)
 
     @wrap_auto_init
     def run(self):
         """Run the main control loop. Exits when training is finished or errored."""
         self._start()
 
-        while self.get_state() not in (
-            TrainControllerState.ERRORED,
-            TrainControllerState.FINISHED,
-        ):
+        while not self.get_state().is_terminal():
             self._run_control_loop_iteration()
 
         self._shutdown()
 
-    def get_result(self) -> Result:
-        """Get the final training result from the TrainController."""
-
-        controller_state = self.get_state()
-        if controller_state not in (
-            TrainControllerState.FINISHED,
-            TrainControllerState.ERRORED,
-        ):
-            raise ValueError(
-                f"Cannot get result when controller is in state {controller_state}"
-            )
+    def _build_result(self) -> Result:
+        storage = self._checkpoint_manager._storage_context
 
         latest_checkpoint_result = self._checkpoint_manager.latest_checkpoint_result
         latest_metrics = (
@@ -362,16 +464,43 @@
             (r.checkpoint, r.metrics)
             for r in self._checkpoint_manager.best_checkpoint_results
         ]
-        storage_filesystem, storage_fs_path = get_fs_and_path(
-            self._run_config.storage_path, self._run_config.storage_filesystem
-        )
-        experiment_fs_path = Path(storage_fs_path, self._run_config.name).as_posix()
+
+        # Provide the history of metrics attached to checkpoints as a dataframe.
+        metrics_dataframe = None
+        if best_checkpoints:
+            metrics_dataframe = pd.DataFrame([m for _, m in best_checkpoints])
 
         return Result(
             metrics=latest_metrics,
             checkpoint=latest_checkpoint,
-            error=self._training_failed_error,
-            path=experiment_fs_path,
+            error=self.get_training_failed_error(),
+            path=storage.experiment_fs_path,
             best_checkpoints=best_checkpoints,
-            _storage_filesystem=storage_filesystem,
+            metrics_dataframe=metrics_dataframe,
+            _storage_filesystem=storage.storage_filesystem,
         )
+
+    def get_result(self) -> Result:
+        """Get the final training result from the TrainController."""
+
+        controller_state = self.get_state()
+        if not controller_state.is_terminal():
+            raise ValueError(
+                f"Cannot get result when controller is in state {controller_state}"
+            )
+
+        return self._build_result()
+
+    def get_training_failed_error(self) -> Optional[TrainingFailedError]:
+        """Get the training failed error from the controller state.
+ + Returns: + The training failed error if the controller is in an errored state, + None otherwise. + """ + controller_state = self.get_state() + + if isinstance(controller_state, ErroredState): + return controller_state.training_failed_error + + return None diff --git a/python/ray/train/v2/_internal/execution/controller/state.py b/python/ray/train/v2/_internal/execution/controller/state.py new file mode 100644 index 0000000000000..9fdf255aa894c --- /dev/null +++ b/python/ray/train/v2/_internal/execution/controller/state.py @@ -0,0 +1,131 @@ +from enum import Enum + +from ray.train.v2._internal.exceptions import TrainingFailedError +from ray.train.v2._internal.execution.scaling_policy.scaling_policy import ( + ScalingDecision, +) + + +class TrainControllerStateType(Enum): + """Enum representing different states of the train controller. + + States: + INITIALIZING: The train controller is starting up. This is always the initial + state of the controller. + SCHEDULING: The training controller is in the process of scheduling a new worker + group. + RESCHEDULING: The train controller is in the process of rescheduling the worker + group. + RUNNING: The train controller is actively running training tasks. + RESTARTING: The train controller is in the process of recovering from an error. + RESIZING: The train controller is in the process of resizing a running worker + group. + ERRORED: A terminal state indicating that training has encountered an error and + cannot continue. + FINISHED: A terminal state indicating that training has completed. + + Args: + state_name: The name of the state. + is_terminal: Whether this is a terminal state that should not be further processed. + needs_new_run_attempt: Whether this state requires starting a new run attempt, where + a run attempt is a logical unit that encompasses both scheduling workers and + executing training on those workers. + """ + + INITIALIZING = ("INITIALIZING", False, True) + SCHEDULING = ("SCHEDULING", False, False) + RESCHEDULING = ("RESCHEDULING", False, False) + RUNNING = ("RUNNING", False, False) + RESTARTING = ("RESTARTING", False, True) + RESIZING = ("RESIZING", False, True) + ERRORED = ("ERRORED", True, False) + FINISHED = ("FINISHED", True, False) + + def __init__(self, state_name: str, is_terminal: bool, needs_new_run_attempt: bool): + self.state_name = state_name + self.is_terminal = is_terminal + self.needs_new_run_attempt = needs_new_run_attempt + + +class TrainControllerState: + """Base class for all train controller states. + + Methods: + get_type() -> TrainControllerStateType: Returns the type of the state. + is_terminal() -> bool: Returns whether the state is terminal. + needs_new_run_attempt() -> bool: Returns whether a new run attempt is needed. 
+ """ + + def __init__(self, state_type: TrainControllerStateType): + self._state_type = state_type + + def __repr__(self) -> str: + attrs = { + "type": self._state_type.name, + "is_terminal": self._state_type.is_terminal, + "needs_new_run_attempt": self._state_type.needs_new_run_attempt, + **{k: v for k, v in vars(self).items() if not k.startswith("_")}, + } + attrs_str = "\n ".join(f"{k}={v}" for k, v in attrs.items()) + return f"{self.__class__.__name__}(\n {attrs_str}\n)" + + def is_terminal(self) -> bool: + return self._state_type.is_terminal + + def needs_new_run_attempt(self) -> bool: + return self._state_type.needs_new_run_attempt + + +class InitializingState(TrainControllerState): + def __init__(self): + super().__init__(state_type=TrainControllerStateType.INITIALIZING) + + +class SchedulingState(TrainControllerState): + def __init__(self, scaling_decision: ScalingDecision): + super().__init__(state_type=TrainControllerStateType.SCHEDULING) + self.scaling_decision = scaling_decision + + +class ReschedulingState(TrainControllerState): + def __init__(self): + super().__init__(state_type=TrainControllerStateType.RESCHEDULING) + + +class RunningState(TrainControllerState): + # TODO: Split into multiple more granular states, or add more fields. + # For example, we may want to indicate if any health checks failed. + def __init__(self): + super().__init__(state_type=TrainControllerStateType.RUNNING) + + +class RestartingState(TrainControllerState): + def __init__( + self, + training_failed_error: TrainingFailedError, + ): + super().__init__(state_type=TrainControllerStateType.RESTARTING) + self.training_failed_error = training_failed_error + + +class ResizingState(TrainControllerState): + def __init__( + self, + scaling_decision: ScalingDecision, + ): + super().__init__(state_type=TrainControllerStateType.RESIZING) + self.scaling_decision = scaling_decision + + +class ErroredState(TrainControllerState): + def __init__( + self, + training_failed_error: TrainingFailedError, + ): + super().__init__(state_type=TrainControllerStateType.ERRORED) + self.training_failed_error = training_failed_error + + +class FinishedState(TrainControllerState): + def __init__(self): + super().__init__(state_type=TrainControllerStateType.FINISHED) diff --git a/python/ray/train/v2/_internal/execution/scaling_policy/fixed.py b/python/ray/train/v2/_internal/execution/scaling_policy/fixed.py index e3263250ad437..094876e50abae 100644 --- a/python/ray/train/v2/_internal/execution/scaling_policy/fixed.py +++ b/python/ray/train/v2/_internal/execution/scaling_policy/fixed.py @@ -9,7 +9,7 @@ class FixedScalingPolicy(ScalingPolicy): def make_decision_for_non_running_worker_group( - self, worker_group_status: WorkerGroupStatus + self, ) -> ScalingDecision: return ResizeDecision( num_workers=self.scaling_config.num_workers, diff --git a/python/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py b/python/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py index 8f2774a514ae1..629db15cb8d5c 100644 --- a/python/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py +++ b/python/ray/train/v2/_internal/execution/scaling_policy/scaling_policy.py @@ -32,13 +32,14 @@ class ScalingPolicy(abc.ABC, ControllerCallback): Upscale decisions are optional and are made when workers are healthy. """ + # TODO: Restructure these APIs to consider different TrainControllerStates + # instead of just running and non-running worker groups. 
+ def __init__(self, scaling_config: ScalingConfig): self.scaling_config = scaling_config @abc.abstractmethod - def make_decision_for_non_running_worker_group( - self, worker_group_status: WorkerGroupStatus - ) -> ScalingDecision: + def make_decision_for_non_running_worker_group(self) -> ScalingDecision: """Makes a scaling decision when the worker group is initializing or recovering from an error.""" raise NotImplementedError diff --git a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py index e03562c6e480e..fc3d25b8161b8 100644 --- a/python/ray/train/v2/_internal/execution/worker_group/worker_group.py +++ b/python/ray/train/v2/_internal/execution/worker_group/worker_group.py @@ -598,10 +598,7 @@ def poll_status(self, timeout: Optional[float] = None) -> WorkerGroupStatus: worker_group_status = WorkerGroupStatus( num_workers=len(self._workers), latest_start_time=self._latest_start_time, - worker_statuses={ - world_rank: worker_status - for world_rank, worker_status in enumerate(poll_results) - }, + worker_statuses=dict(enumerate(poll_results)), ) for callback in self._callbacks: diff --git a/python/ray/train/v2/_internal/migration_utils.py b/python/ray/train/v2/_internal/migration_utils.py new file mode 100644 index 0000000000000..20ad7bff2019e --- /dev/null +++ b/python/ray/train/v2/_internal/migration_utils.py @@ -0,0 +1,71 @@ +from ray.train.constants import V2_MIGRATION_GUIDE_MESSAGE + + +FAIL_FAST_DEPRECATION_MESSAGE = ( + "`ray.train.FailureConfig(fail_fast)` is deprecated since it is " + "only relevant in the context of multiple trials running in Ray Tune. " + "This parameter is still available in `ray.tune.FailureConfig` " + "for passing into a `ray.tune.Tuner`. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +TRAINER_RESOURCES_DEPRECATION_MESSAGE = ( + "`ray.train.ScalingConfig(trainer_resources)` is deprecated. " + "This parameter was an advanced configuration that specified " + "resources for the Ray Train driver actor, which doesn't " + "need to reserve logical resources because it doesn't perform " + "any heavy computation. " + "Only the `resources_per_worker` parameter should be used " + "to specify resources for the training workers. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +VERBOSE_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(verbose)` is deprecated. " + "This parameter controls Ray Tune logging verbosity, " + "and is only relevant when using Ray Tune. " + "This parameter is still available in `ray.tune.RunConfig` " + "for passing into a `ray.tune.Tuner`. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +LOG_TO_FILE_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(log_to_file)` is deprecated. " + "The Ray Train driver actor and the training worker actors " + "already log stdout/stderr as part of Ray's logging system. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +STOP_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(stop)` is deprecated. " + "This parameter is only relevant when using Ray Tune " + "and is still available in `ray.tune.RunConfig` " + "for passing into a `ray.tune.Tuner`. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +CALLBACKS_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(callbacks: List[ray.tune.Callback])` is deprecated. " + "Ray Train no longer accepts Ray Tune callbacks, since the Ray Train " + "execution backend is being separated from Ray Tune. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +PROGRESS_REPORTER_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(progress_reporter)` is deprecated. 
" + "This parameter controls the Ray Tune console output reporter, " + "and is only relevant when using Ray Tune. " + "This parameter is still available in `ray.tune.RunConfig` " + "for passing into a `ray.tune.Tuner`. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) + +SYNC_CONFIG_DEPRECATION_MESSAGE = ( + "`ray.train.RunConfig(sync_config)` is deprecated. " + "This configuration controls advanced syncing behavior, " + "which is either not supported or not relevant in the reworked Ray Train. " + "This parameter is still available in `ray.tune.RunConfig` " + "for passing into a `ray.tune.Tuner`. " + "The `SyncConfig` class has been moved to `ray.tune.SyncConfig`. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" +) diff --git a/python/ray/train/v2/api/config.py b/python/ray/train/v2/api/config.py index 464c186bda797..1c46583c53342 100644 --- a/python/ray/train/v2/api/config.py +++ b/python/ray/train/v2/api/config.py @@ -5,6 +5,10 @@ from ray.air.config import RunConfig as RunConfigV1 from ray.air.config import ScalingConfig as ScalingConfigV1 from ray.train.v2._internal.constants import _DEPRECATED +from ray.train.v2._internal.migration_utils import ( + FAIL_FAST_DEPRECATION_MESSAGE, + TRAINER_RESOURCES_DEPRECATION_MESSAGE, +) from ray.train.v2._internal.util import date_str if TYPE_CHECKING: @@ -61,19 +65,11 @@ class ScalingConfig(ScalingConfigV1): """ - trainer_resources: Union[Optional[dict], str] = _DEPRECATED + trainer_resources: Optional[dict] = None def __post_init__(self): - if self.trainer_resources != _DEPRECATED: - raise NotImplementedError( - "`ScalingConfig(trainer_resources)` is deprecated. " - "This parameter was an advanced configuration that specified " - "resources for the Ray Train driver actor, which doesn't " - "need to reserve logical resources because it doesn't perform " - "any heavy computation. " - "Only the `resources_per_worker` parameter is useful " - "to specify resources for the training workers." - ) + if self.trainer_resources is not None: + raise DeprecationWarning(TRAINER_RESOURCES_DEPRECATION_MESSAGE) super().__post_init__() @@ -98,17 +94,13 @@ class FailureConfig(FailureConfigV1): def __post_init__(self): # TODO(justinvyu): Add link to migration guide. if self.fail_fast != _DEPRECATED: - raise NotImplementedError("`FailureConfig(fail_fast)` is deprecated.") + raise DeprecationWarning(FAIL_FAST_DEPRECATION_MESSAGE) @dataclass class RunConfig(RunConfigV1): """Runtime configuration for training runs. - Upon resuming from a training run checkpoint, - Ray Train will automatically apply the RunConfig from - the previously checkpointed run. - Args: name: Name of the trial or experiment. If not provided, will be deduced from the Trainable. @@ -144,7 +136,9 @@ def __post_init__(self): "Ray Tune API that did not support Ray Train usage well, " "so we are dropping support going forward. " "If you heavily rely on these configurations, " - "you can run Ray Train as a single Ray Tune trial." + "you can run Ray Train as a single Ray Tune trial. 
" + "See this issue for more context: " + "https://github.com/ray-project/ray/issues/49454" ) unsupported_params = [ @@ -156,16 +150,19 @@ def __post_init__(self): ] for param in unsupported_params: if getattr(self, param) != _DEPRECATED: - raise NotImplementedError(run_config_deprecation_message.format(param)) + raise DeprecationWarning(run_config_deprecation_message.format(param)) if not self.name: self.name = f"ray_train_run-{date_str()}" self.callbacks = self.callbacks or [] - # TODO(justinvyu): Improve this error message and add a migration guide if - # the user is passing in a Tune callback. from ray.train.v2.api.callback import RayTrainCallback if not all(isinstance(cb, RayTrainCallback) for cb in self.callbacks): - raise ValueError("All callbacks must be instances of `RayTrainCallback`.") + raise ValueError( + "All callbacks must be instances of `ray.train.UserCallback`. " + "Passing in a Ray Tune callback is no longer supported. " + "See this issue for more context: " + "https://github.com/ray-project/ray/issues/49454" + ) diff --git a/python/ray/train/v2/api/data_parallel_trainer.py b/python/ray/train/v2/api/data_parallel_trainer.py index 8f7405d5225d8..852dea485a5fe 100644 --- a/python/ray/train/v2/api/data_parallel_trainer.py +++ b/python/ray/train/v2/api/data_parallel_trainer.py @@ -5,7 +5,12 @@ from ray._private.ray_constants import env_bool from ray.train import BackendConfig, Checkpoint from ray.train._internal.data_config import DataConfig +from ray.train.base_trainer import ( + _RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING, + _TRAINER_RESTORE_DEPRECATION_WARNING, +) from ray.train.constants import RAY_CHDIR_TO_TRIAL_DIR +from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE from ray.train.v2._internal.callbacks import ( AcceleratorSetupCallback, BackendSetupCallback, @@ -19,7 +24,6 @@ ) from ray.train.v2._internal.callbacks.user_callback import UserCallbackHandler from ray.train.v2._internal.constants import ( - _UNSUPPORTED, DEFAULT_RUN_CONTROLLER_AS_ACTOR, METRICS_ENABLED_ENV_VAR, RUN_CONTROLLER_AS_ACTOR_ENV_VAR, @@ -33,107 +37,13 @@ from ray.train.v2.api.callback import UserCallback from ray.train.v2.api.config import RunConfig, ScalingConfig from ray.train.v2.api.result import Result +from ray.util.annotations import Deprecated, DeveloperAPI from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy logger = logging.getLogger(__name__) -TRAINER_RESTORE_DEPRECATION_WARNING = """ -The `Trainer.restore` API is deprecated and will be removed in a future release. - -This API previously accepted a Train run directory path and loaded state such as the -training code and configurations from a pkl file, which was hard to use. -Now, Ray Train attempts to load the snapshot of reported checkpoints if it exists -in the run directory, which makes `ray.train.get_checkpoint` available as long as -you're pointing to the same run directory (i.e. the same `storage_path` and `name`). - -If you want to start a new training run without any prior checkpoint history, please -specify a new, unique `RunConfig(name)`. - -Job-level restoration can still be achieved, as shown below: - -Before -------- - -def train_fn_per_worker(config): - checkpoint = ray.train.get_checkpoint() - # Perform your training-specific checkpoint recovery here... - -storage_path = "s3://bucket/" -name = "" -run_path = f"{storage_path}/{name}" - -if TorchTrainer.can_restore(run_path): - # Some parameters are optionally re-specified. 
- trainer = TorchTrainer.restore(run_path, datasets={...}) - result = trainer.fit() -else: - trainer = TorchTrainer( - train_fn_per_worker, - datasets={...}, - scaling_config=train.ScalingConfig(num_workers=2), - run_config=train.RunConfig(storage_path=storage_path, name=name), - ) - result = trainer.fit() - -After ------ - -def train_fn_per_worker(config): - # `ray.train.get_checkpoint` will be populated as long as your run - # is pointing to the same directory. - checkpoint = ray.train.get_checkpoint() - # Perform your training-specific checkpoint recovery here... - -storage_path = "s3://bucket/" -name = "" - -trainer = TorchTrainer( - train_fn_per_worker, - datasets={...}, - scaling_config=train.ScalingConfig(num_workers=2), - run_config=train.RunConfig(storage_path=storage_path, name=name), -) -result = trainer.fit() -""" - -RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING = """ -`resume_from_checkpoint` is deprecated and will be removed in an upcoming -release, since it is conceptually confusing and can be replaced very easily. -For example: - -Before ------- - -def train_fn_per_worker(config: dict): - # This is the checkpoint passed to `resume_from_checkpoint` - # if no other checkpoints have been saved. - # Otherwise this is the latest reported checkpoint. - checkpoint = ray.train.get_checkpoint() - -trainer = TorchTrainer( - train_fn_per_worker, - ..., - resume_from_checkpoint=ray.train.Checkpoint(...) -) - -After ------ - -def train_fn_per_worker(config: dict): - # Equivalent behavior that is explicit and more flexible. - checkpoint = ( - ray.train.get_checkpoint() - or config.get("resume_from_checkpoint") - ) - -trainer = TorchTrainer( - train_fn_per_worker, - train_loop_config={"resume_from_checkpoint": ray.train.Checkpoint(...)}, -) -""" - - +@DeveloperAPI class DataParallelTrainer: def __init__( self, @@ -143,11 +53,11 @@ def __init__( backend_config: Optional[BackendConfig] = None, scaling_config: Optional[ScalingConfig] = None, run_config: Optional[RunConfig] = None, - # TODO: [Deprecated] Remove in future release - resume_from_checkpoint: Optional[Checkpoint] = None, datasets: Optional[Dict[str, GenDataset]] = None, dataset_config: Optional[DataConfig] = None, - metadata: Optional[Dict[str, Any]] = _UNSUPPORTED, + # TODO: [Deprecated] Remove in future release + resume_from_checkpoint: Optional[Checkpoint] = None, + metadata: Optional[Dict[str, Any]] = None, ): self.run_config = run_config or RunConfig() self.train_run_context = TrainRunContext(self.run_config) @@ -158,14 +68,11 @@ def __init__( self.datasets = datasets or {} self.data_config = dataset_config or DataConfig() - if resume_from_checkpoint: - logger.warning(RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING) - self.resume_from_checkpoint = resume_from_checkpoint + if resume_from_checkpoint is not None: + raise DeprecationWarning(_RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING) - # TODO: No support for below - error_msg = "The '{}' argument is not supported yet." 
- if metadata != _UNSUPPORTED: - raise NotImplementedError(error_msg.format("metadata")) + if metadata is not None: + raise DeprecationWarning(_GET_METADATA_DEPRECATION_MESSAGE) def fit(self) -> Result: train_fn = construct_train_func( @@ -219,7 +126,6 @@ def fit(self) -> Result: failure_policy=DefaultFailurePolicy(self.run_config.failure_config), train_run_context=self.train_run_context, callbacks=callbacks, - resume_from_checkpoint=self.resume_from_checkpoint, ) if result.error: @@ -254,14 +160,12 @@ def _initialize_and_run_controller(self, **controller_init_kwargs) -> Result: controller.run() return controller.get_result() + @Deprecated @classmethod def restore(cls, *args, **kwargs): - raise DeprecationWarning(TRAINER_RESTORE_DEPRECATION_WARNING) + raise DeprecationWarning(_TRAINER_RESTORE_DEPRECATION_WARNING) + @Deprecated @classmethod def can_restore(cls, *args, **kwargs): - raise DeprecationWarning( - "This API is deprecated and will be removed in a future release. " - "The trainer will be always restored automatically when an existing " - "training snapshot is detected in the run configuration path." - ) + raise DeprecationWarning(_TRAINER_RESTORE_DEPRECATION_WARNING) diff --git a/python/ray/train/v2/api/result.py b/python/ray/train/v2/api/result.py index 1c7d17a9bde0f..c7daffc09f38a 100644 --- a/python/ray/train/v2/api/result.py +++ b/python/ray/train/v2/api/result.py @@ -1,13 +1,14 @@ import logging import os -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Dict, Optional, Union -import pandas as pd import pyarrow from ray.air.result import Result as ResultV1 from ray.train.v2._internal.exceptions import TrainingFailedError +from ray.util.annotations import Deprecated + logger = logging.getLogger(__name__) @@ -15,8 +16,6 @@ @dataclass class Result(ResultV1): error: Optional[TrainingFailedError] - # The metrics dataframe will not be supported in the new Result class. - metrics_dataframe: Optional["pd.DataFrame"] = field(init=False, default=None) @classmethod def from_path( @@ -24,8 +23,12 @@ def from_path( path: Union[str, os.PathLike], storage_filesystem: Optional[pyarrow.fs.FileSystem] = None, ) -> "Result": - raise NotImplementedError + raise NotImplementedError("`Result.from_path` is not implemented yet.") + @Deprecated @property def config(self) -> Optional[Dict[str, Any]]: - raise NotImplementedError + raise DeprecationWarning( + "The `config` property for a `ray.train.Result` is deprecated, " + "since it is only relevant in the context of Ray Tune." + ) diff --git a/python/ray/train/v2/horovod/horovod_trainer.py b/python/ray/train/v2/horovod/horovod_trainer.py index 10fffd2bd846b..e9443ad9d89a3 100644 --- a/python/ray/train/v2/horovod/horovod_trainer.py +++ b/python/ray/train/v2/horovod/horovod_trainer.py @@ -10,10 +10,7 @@ @Deprecated class HorovodTrainer(DataParallelTrainer): - """A Trainer for data parallel Horovod training. - - Horovod Trainer is Deprecated. - """ + """A Trainer for data parallel Horovod training. 
HorovodTrainer is deprecated.""" def __init__( self, diff --git a/python/ray/train/v2/lightgbm/lightgbm_trainer.py b/python/ray/train/v2/lightgbm/lightgbm_trainer.py index b5a3988bd457b..f75c0f3da968b 100644 --- a/python/ray/train/v2/lightgbm/lightgbm_trainer.py +++ b/python/ray/train/v2/lightgbm/lightgbm_trainer.py @@ -5,9 +5,9 @@ from ray.train import Checkpoint from ray.train.lightgbm.config import LightGBMConfig, get_network_params # noqa from ray.train.trainer import GenDataset -from ray.train.v2._internal.constants import _UNSUPPORTED from ray.train.v2.api.config import RunConfig, ScalingConfig from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer +from ray.util.annotations import Deprecated logger = logging.getLogger(__name__) @@ -124,7 +124,8 @@ def __init__( run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, dataset_config: Optional[ray.train.DataConfig] = None, - metadata: Optional[Dict[str, Any]] = _UNSUPPORTED, + # TODO: [Deprecated] + metadata: Optional[Dict[str, Any]] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): super(LightGBMTrainer, self).__init__( @@ -139,6 +140,7 @@ def __init__( metadata=metadata, ) + @Deprecated @classmethod def get_model( cls, diff --git a/python/ray/train/v2/tensorflow/tensorflow_trainer.py b/python/ray/train/v2/tensorflow/tensorflow_trainer.py index 61da7d38f92bb..2c3fe16894bda 100644 --- a/python/ray/train/v2/tensorflow/tensorflow_trainer.py +++ b/python/ray/train/v2/tensorflow/tensorflow_trainer.py @@ -3,7 +3,6 @@ from ray.train import Checkpoint, DataConfig from ray.train.tensorflow.config import TensorflowConfig from ray.train.trainer import GenDataset -from ray.train.v2._internal.constants import _UNSUPPORTED from ray.train.v2.api.config import RunConfig, ScalingConfig from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer from ray.util import PublicAPI @@ -171,7 +170,8 @@ def __init__( dataset_config: Optional[DataConfig] = None, run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, - metadata: Optional[Dict[str, Any]] = _UNSUPPORTED, + # TODO: [Deprecated] + metadata: Optional[Dict[str, Any]] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): if not tensorflow_config: diff --git a/python/ray/train/v2/tests/test_controller.py b/python/ray/train/v2/tests/test_controller.py index 68f93c1b72a82..3b2ffd0b1d1ad 100644 --- a/python/ray/train/v2/tests/test_controller.py +++ b/python/ray/train/v2/tests/test_controller.py @@ -11,8 +11,15 @@ ) from ray.train.v2._internal.execution.callback import ControllerCallback from ray.train.v2._internal.execution.context import TrainRunContext -from ray.train.v2._internal.execution.controller import ( - TrainController, +from ray.train.v2._internal.execution.controller import TrainController +from ray.train.v2._internal.execution.controller.state import ( + InitializingState, + SchedulingState, + ReschedulingState, + RunningState, + RestartingState, + ResizingState, + ErroredState, TrainControllerState, ) from ray.train.v2._internal.execution.failure_handling import ( @@ -94,9 +101,7 @@ def __init__(self, scaling_config): super().__init__(scaling_config) - def make_decision_for_non_running_worker_group( - self, worker_group_status: WorkerGroupStatus - ) -> ScalingDecision: + def make_decision_for_non_running_worker_group(self) -> ScalingDecision: if self._recovery_decision_queue: return self._recovery_decision_queue.pop(0) return NoopDecision() @@ -158,8 +163,6 @@ def 
test_resize(): ) worker_group = controller.get_worker_group() - controller._checkpoint_handler = MagicMock() - decisions = [ NoopDecision(), ResizeDecision(num_workers=2, resources_per_worker={}), @@ -173,28 +176,46 @@ def test_resize(): ResizeDecision(num_workers=5, resources_per_worker={}), ] + assert isinstance(controller.get_state(), InitializingState) + worker_group_status = worker_group.poll_status() + assert worker_group_status.num_workers == 0 + # Start with 1 worker scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=1, resources_per_worker={}) ) controller._run_control_loop_iteration() - prev_worker_group_status = worker_group.poll_status() - assert prev_worker_group_status.num_workers == 1 + assert isinstance(controller.get_state(), SchedulingState) + worker_group_status = worker_group.poll_status() + assert worker_group_status.num_workers == 0 + + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) + worker_group_status = worker_group.poll_status() + assert worker_group_status.num_workers == 1 for decision in decisions: + prev_worker_group_status = worker_group_status + scaling_policy.queue_monitor_decision(decision) - controller._run_control_loop_iteration() - worker_group_status = worker_group.poll_status() if isinstance(decision, NoopDecision): + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) + worker_group_status = worker_group.poll_status() assert ( worker_group_status.num_workers == prev_worker_group_status.num_workers ) else: + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), ResizingState) + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), SchedulingState) + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) + worker_group_status = worker_group.poll_status() assert worker_group_status.num_workers == decision.num_workers - prev_worker_group_status = worker_group_status - def test_failure_handling(): scaling_policy = MockScalingPolicy(scaling_config=ScalingConfig()) @@ -208,30 +229,32 @@ def test_failure_handling(): ) worker_group = controller.get_worker_group() - controller._checkpoint_handler = MagicMock() - - assert controller.get_state() == TrainControllerState.INITIALIZING + assert isinstance(controller.get_state(), InitializingState) scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=2, resources_per_worker={}) ) controller._run_control_loop_iteration() - assert controller.get_state() == TrainControllerState.RUNNING + assert isinstance(controller.get_state(), SchedulingState) + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) worker_group.error_worker(1) failure_policy.queue_decision(FailureDecision.RESTART) controller._run_control_loop_iteration() - assert controller.get_state() == TrainControllerState.RECOVERING + assert isinstance(controller.get_state(), RestartingState) scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=4, resources_per_worker={}) ) controller._run_control_loop_iteration() - assert controller.get_state() == TrainControllerState.RUNNING + assert isinstance(controller.get_state(), SchedulingState) + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) worker_group.error_worker(3) failure_policy.queue_decision(FailureDecision.RAISE) controller._run_control_loop_iteration() - assert 
controller.get_state() == TrainControllerState.ERRORED + assert isinstance(controller.get_state(), ErroredState) @pytest.mark.parametrize( @@ -248,26 +271,35 @@ def test_worker_group_start_failure(error_type): scaling_policy=scaling_policy, failure_policy=failure_policy, ) - controller._checkpoint_handler = MagicMock() worker_group: DummyWorkerGroup = controller.get_worker_group() worker_group.set_start_failure(error_type) + assert isinstance(controller.get_state(), InitializingState) + scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=2, resources_per_worker={}) ) + + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), SchedulingState) + # Worker group will fail to start, but controller should not raise - # and should go into RECOVERING state. + # and should go into RESCHEDULING state. controller._run_control_loop_iteration() - assert controller.get_state() == TrainControllerState.RECOVERING + assert isinstance(controller.get_state(), ReschedulingState) # Let the worker group start successfully the 2nd time. worker_group.set_start_failure(None) scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=2, resources_per_worker={}) ) + controller._run_control_loop_iteration() - assert controller.get_state() == TrainControllerState.RUNNING + assert isinstance(controller.get_state(), SchedulingState) + + controller._run_control_loop_iteration() + assert isinstance(controller.get_state(), RunningState) def test_poll_frequency(monkeypatch): @@ -322,7 +354,6 @@ def before_controller_execute_failure_decision( def before_controller_execute_scaling_decision( self, scaling_decision: ScalingDecision, - worker_group_status: WorkerGroupStatus, ): self.scaling_decision_called = True @@ -342,7 +373,6 @@ def before_controller_shutdown(self): failure_policy=failure_policy, callbacks=[callback], ) - controller._checkpoint_handler = MagicMock() worker_group = controller.get_worker_group() controller._start() @@ -351,21 +381,25 @@ def before_controller_shutdown(self): scaling_policy.queue_recovery_decision( ResizeDecision(num_workers=2, resources_per_worker={}) ) + + controller._run_control_loop_iteration() + assert not callback.scaling_decision_called + assert isinstance(callback.latest_state_update[0], InitializingState) + assert isinstance(callback.latest_state_update[1], SchedulingState) + controller._run_control_loop_iteration() assert callback.scaling_decision_called - assert callback.latest_state_update == ( - TrainControllerState.INITIALIZING, - TrainControllerState.RUNNING, - ) + assert isinstance(callback.latest_state_update[0], SchedulingState) + assert isinstance(callback.latest_state_update[1], RunningState) worker_group.error_worker(1) failure_policy.queue_decision(FailureDecision.RAISE) + + assert not callback.failure_decision_called controller._run_control_loop_iteration() assert callback.failure_decision_called - assert callback.latest_state_update == ( - TrainControllerState.RUNNING, - TrainControllerState.ERRORED, - ) + assert isinstance(callback.latest_state_update[0], RunningState) + assert isinstance(callback.latest_state_update[1], ErroredState) controller._shutdown() assert callback.shutdown_called diff --git a/python/ray/train/v2/tests/test_persistence.py b/python/ray/train/v2/tests/test_persistence.py index 3bfa2116d5f27..f3cb9fb8b8035 100644 --- a/python/ray/train/v2/tests/test_persistence.py +++ b/python/ray/train/v2/tests/test_persistence.py @@ -24,10 +24,7 @@ ScalingConfig, ) from ray.train.v2._internal.constants import 
HEALTH_CHECK_INTERVAL_S_ENV_VAR -from ray.train.v2._internal.execution.storage import ( - _delete_fs_path, - _download_from_fs_path, -) +from ray.train.v2._internal.execution.storage import _download_from_fs_path from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer @@ -228,48 +225,6 @@ def train_fn(config): raise RuntimeError(f"Failing on iter={i}!!") -def _resume_from_checkpoint( - checkpoint: Checkpoint, - expected_state: dict, - storage_path: Optional[str] = None, - storage_filesystem: Optional[pyarrow.fs.FileSystem] = None, -): - print(f"\nStarting run with `resume_from_checkpoint`: {checkpoint}\n") - - def assert_fn(config): - checkpoint_to_check = ray.train.get_checkpoint() - with checkpoint_to_check.as_directory() as checkpoint_dir: - with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f: - state = pickle.load(f) - - print("Loaded state from `resume_from_checkpoint`:", state) - print("Expected state:", expected_state) - assert state == expected_state, (state, expected_state) - - dummy_ckpt = tempfile.mkdtemp() - with open(os.path.join(dummy_ckpt, "dummy.txt"), "w") as f: - f.write("data") - ray.train.report({"dummy": 1}, checkpoint=Checkpoint.from_directory(dummy_ckpt)) - - trainer = DataParallelTrainer( - assert_fn, - scaling_config=ScalingConfig(num_workers=2), - run_config=RunConfig( - name="test_resume_from_checkpoint", - storage_path=storage_path, - storage_filesystem=storage_filesystem, - ), - resume_from_checkpoint=checkpoint, - ) - result = trainer.fit() - - # Make sure that there is only on checkpoint recorded. - assert result.best_checkpoints and len(result.best_checkpoints) == 1 - - # Clean up this run's experiment directory immediately after. - _delete_fs_path(result.filesystem, Path(result.path).parent.as_posix()) - - def _assert_storage_contents( local_inspect_dir: Path, exp_name: str, @@ -384,7 +339,6 @@ def test_trainer( print("\nStarting initial run.\n") result = trainer.fit() - # TODO: Re-enable restoration / resume_from_checkpoint coverage print("\nStarting manually restored run.\n") restored_trainer = DataParallelTrainer( train_fn, @@ -400,11 +354,6 @@ def test_trainer( ) result = restored_trainer.fit() - _resume_from_checkpoint( - result.checkpoint, - expected_state={"iter": TestConstants.NUM_ITERATIONS - 1}, - ) - local_inspect_dir, storage_fs_path = _get_local_inspect_dir( root_local_path=tmp_path, storage_path=run_config.storage_path, diff --git a/python/ray/train/v2/tests/test_report_handler.py b/python/ray/train/v2/tests/test_report_handler.py index f5bcf813adebd..5ec74ea57b582 100644 --- a/python/ray/train/v2/tests/test_report_handler.py +++ b/python/ray/train/v2/tests/test_report_handler.py @@ -38,9 +38,7 @@ def generate_worker_group_status(num_workers, num_ckpt, num_dummy, num_none): ) random.shuffle(worker_statuses) - return WorkerGroupStatus( - num_workers, 0.0, {i: ws for i, ws in enumerate(worker_statuses)} - ) + return WorkerGroupStatus(num_workers, 0.0, dict(enumerate(worker_statuses))) @pytest.mark.parametrize( diff --git a/python/ray/train/v2/tests/test_result.py b/python/ray/train/v2/tests/test_result.py index 0bfbd3b684cf5..98043a5051eb0 100644 --- a/python/ray/train/v2/tests/test_result.py +++ b/python/ray/train/v2/tests/test_result.py @@ -1,24 +1,9 @@ -import pandas as pd import pytest from ray.train import Checkpoint from ray.train.v2.api.result import Result -def test_result_raises_with_dataframe(): - """Test that the Result init function raises an error when - metrics_dataframe is passed in as a parameter. 
- """ - with pytest.raises(TypeError): - Result( - metrics={}, - checkpoint=None, - error=None, - path=None, - metrics_dataframe=pd.DataFrame(), - ) - - def test_result_repr(): """Test that the Result __repr__ function can return a string.""" res = Result( @@ -33,8 +18,6 @@ def test_result_repr(): def test_get_best_checkpoint(): - """Test that the Result get_best_checkpoint function returns the correct""" - res = Result( metrics={}, checkpoint=None, diff --git a/python/ray/train/v2/tests/test_v2_api.py b/python/ray/train/v2/tests/test_v2_api.py index 67bf88c021ba2..6530fcfa7b844 100644 --- a/python/ray/train/v2/tests/test_v2_api.py +++ b/python/ray/train/v2/tests/test_v2_api.py @@ -19,7 +19,7 @@ ) def test_api_configs(operation, raise_error): if raise_error: - with pytest.raises(NotImplementedError): + with pytest.raises(DeprecationWarning): operation() else: try: diff --git a/python/ray/train/v2/tests/test_xgboost_trainer.py b/python/ray/train/v2/tests/test_xgboost_trainer.py index f7b8e0f5e1044..7a62fc0818b84 100644 --- a/python/ray/train/v2/tests/test_xgboost_trainer.py +++ b/python/ray/train/v2/tests/test_xgboost_trainer.py @@ -92,7 +92,7 @@ def xgboost_train_fn_per_worker( ) result = trainer.fit() with pytest.raises(DeprecationWarning): - trainer.get_model(result.checkpoint) + XGBoostTrainer.get_model(result.checkpoint) if __name__ == "__main__": diff --git a/python/ray/train/v2/torch/torch_trainer.py b/python/ray/train/v2/torch/torch_trainer.py index 171e10564aeec..b75cb1f3b6115 100644 --- a/python/ray/train/v2/torch/torch_trainer.py +++ b/python/ray/train/v2/torch/torch_trainer.py @@ -186,6 +186,7 @@ def __init__( run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, dataset_config: Optional[DataConfig] = None, + # TODO: [Deprecated] metadata: Optional[Dict[str, Any]] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): @@ -201,7 +202,6 @@ def __init__( run_config=run_config, dataset_config=dataset_config, datasets=datasets, - # TODO: Re-enable below. - # resume_from_checkpoint=resume_from_checkpoint, - # metadata=metadata, + resume_from_checkpoint=resume_from_checkpoint, + metadata=metadata, ) diff --git a/python/ray/train/v2/torch/train_loop_utils.py b/python/ray/train/v2/torch/train_loop_utils.py index e641feb21e240..26805d4090beb 100644 --- a/python/ray/train/v2/torch/train_loop_utils.py +++ b/python/ray/train/v2/torch/train_loop_utils.py @@ -18,7 +18,7 @@ import ray.train.torch from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag from ray.train.torch.train_loop_utils import _WrappedDataLoader -from ray.util.annotations import PublicAPI +from ray.util.annotations import Deprecated, PublicAPI if Version(torch.__version__) < Version("1.11.0"): FullyShardedDataParallel = None @@ -28,6 +28,17 @@ logger = logging.getLogger(__name__) +_TORCH_AMP_DEPRECATION_MESSAGE = ( + "The `accelerate`, `backward`, and `prepare_optimizer` utility methods " + "in the `ray.train.torch` module are deprecated and will be removed in a " + "future release. " + "Please use the native PyTorch mixed precision API directly, or " + "a library such as Lightning or HuggingFace Transformers/Accelerate. 
" + "See this issue for more context: " + "https://github.com/ray-project/ray/issues/49454" +) + + def prepare_model( model: torch.nn.Module, move_to_device: Union[bool, torch.device] = True, @@ -246,19 +257,19 @@ def with_sampler(loader): return data_loader -@PublicAPI(stability="beta") +@Deprecated def accelerate(amp: bool = False) -> None: - raise NotImplementedError + raise DeprecationWarning(_TORCH_AMP_DEPRECATION_MESSAGE) -@PublicAPI(stability="beta") +@Deprecated def prepare_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer: - raise NotImplementedError + raise DeprecationWarning(_TORCH_AMP_DEPRECATION_MESSAGE) -@PublicAPI(stability="beta") +@Deprecated def backward(tensor: torch.Tensor) -> None: - raise NotImplementedError + raise DeprecationWarning(_TORCH_AMP_DEPRECATION_MESSAGE) @PublicAPI(stability="stable") diff --git a/python/ray/train/v2/xgboost/xgboost_trainer.py b/python/ray/train/v2/xgboost/xgboost_trainer.py index 44efc6b2dfb7a..829966b50ce1c 100644 --- a/python/ray/train/v2/xgboost/xgboost_trainer.py +++ b/python/ray/train/v2/xgboost/xgboost_trainer.py @@ -4,10 +4,10 @@ import ray.train from ray.train import Checkpoint from ray.train.trainer import GenDataset -from ray.train.v2._internal.constants import _UNSUPPORTED from ray.train.v2.api.config import RunConfig, ScalingConfig from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer from ray.train.xgboost import XGBoostConfig +from ray.util.annotations import Deprecated logger = logging.getLogger(__name__) @@ -124,7 +124,8 @@ def __init__( run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, dataset_config: Optional[ray.train.DataConfig] = None, - metadata: Optional[Dict[str, Any]] = _UNSUPPORTED, + # TODO: [Deprecated] + metadata: Optional[Dict[str, Any]] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): super(XGBoostTrainer, self).__init__( @@ -140,14 +141,9 @@ def __init__( ) @classmethod - def get_model( - cls, - checkpoint: Checkpoint, - ): - """Retrieve the XGBoost model stored in this checkpoint. - - This API is deprecated. Use `RayTrainReportCallback.get_model` instead. - """ + @Deprecated + def get_model(cls, checkpoint: Checkpoint): + """Retrieve the XGBoost model stored in this checkpoint.""" raise DeprecationWarning( "`XGBoostTrainer.get_model` is deprecated. " "Use `RayTrainReportCallback.get_model` instead." diff --git a/python/ray/tune/context.py b/python/ray/tune/context.py index 0575a2b7af125..f905e63e9ddad 100644 --- a/python/ray/tune/context.py +++ b/python/ray/tune/context.py @@ -2,7 +2,10 @@ from typing import Any, Dict, Optional from ray.train._internal import session -from ray.train.constants import _v2_migration_warnings_enabled +from ray.train.constants import ( + _v2_migration_warnings_enabled, + V2_MIGRATION_GUIDE_MESSAGE, +) from ray.train.context import TrainContext as TrainV1Context from ray.train.utils import _copy_doc from ray.tune.execution.placement_groups import PlacementGroupFactory @@ -16,7 +19,7 @@ _TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = ( "`{}` is deprecated for Ray Tune because there is no concept of worker ranks " "for Ray Tune, so these methods only make sense to use in the context of " - "a Ray Train worker." + f"a Ray Train worker. 
{V2_MIGRATION_GUIDE_MESSAGE}"
)
diff --git a/python/ray/tune/execution/class_cache.py b/python/ray/tune/execution/class_cache.py
index 94c4b5148a4db..3042866290a85 100644
--- a/python/ray/tune/execution/class_cache.py
+++ b/python/ray/tune/execution/class_cache.py
@@ -3,13 +3,14 @@
 import ray
 from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
 from ray.train.constants import (
-    ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
     RAY_CHDIR_TO_TRIAL_DIR,
+    ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
 )
 from ray.train.v2._internal.constants import (
     ENV_VARS_TO_PROPAGATE as TRAIN_ENV_VARS_TO_PROPAGATE,
 )
+
 DEFAULT_ENV_VARS = {
     # https://github.com/ray-project/ray/issues/28197
     "PL_DISABLE_FORK": "1"
diff --git a/python/ray/tune/execution/placement_groups.py b/python/ray/tune/execution/placement_groups.py
index 0848b147878d7..a2eaf548fd35a 100644
--- a/python/ray/tune/execution/placement_groups.py
+++ b/python/ray/tune/execution/placement_groups.py
@@ -114,7 +114,7 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None):
     memory = spec.pop("memory", 0.0)

     # If there is a custom_resources key, use as base for bundle
-    bundle = {k: v for k, v in spec.pop("custom_resources", {}).items()}
+    bundle = dict(spec.pop("custom_resources", {}))

     # Otherwise, consider all other keys as custom resources
     if not bundle:
diff --git a/python/ray/tune/impl/config.py b/python/ray/tune/impl/config.py
index 22731637cc8cc..f956a02bdbc0f 100644
--- a/python/ray/tune/impl/config.py
+++ b/python/ray/tune/impl/config.py
@@ -3,7 +3,10 @@
 from ray.air.config import CheckpointConfig as _CheckpointConfig
 from ray.air.config import FailureConfig as _FailureConfig
 from ray.air.config import RunConfig as _RunConfig
-from ray.train.constants import _v2_migration_warnings_enabled
+from ray.train.constants import (
+    _v2_migration_warnings_enabled,
+    V2_MIGRATION_GUIDE_MESSAGE,
+)
 from ray.train.utils import _copy_doc, _log_deprecation_warning

 # NOTE: This is just a pass-through wrapper around `ray.train.RunConfig`
@@ -36,6 +39,7 @@ def __post_init__(self):
             _log_deprecation_warning(
                 "The `CheckpointConfig` class should be imported from `ray.tune` "
                 "when passing it to the Tuner. Please update your imports."
+                f" {V2_MIGRATION_GUIDE_MESSAGE}"
             )

         if not isinstance(self.failure_config, FailureConfig):
@@ -43,4 +47,5 @@ def __post_init__(self):
             _log_deprecation_warning(
                 "The `FailureConfig` class should be imported from `ray.tune` "
                 "when passing it to the Tuner. Please update your imports."
+                f" {V2_MIGRATION_GUIDE_MESSAGE}"
             )
diff --git a/python/ray/tune/impl/tuner_internal.py b/python/ray/tune/impl/tuner_internal.py
index 4e548a71139b7..7d94ce65d1cbc 100644
--- a/python/ray/tune/impl/tuner_internal.py
+++ b/python/ray/tune/impl/tuner_internal.py
@@ -23,7 +23,10 @@
 from ray.air._internal.usage import AirEntrypoint
 from ray.train import ScalingConfig
 from ray.train._internal.storage import StorageContext, get_fs_and_path
-from ray.train.constants import _v2_migration_warnings_enabled
+from ray.train.constants import (
+    _v2_migration_warnings_enabled,
+    V2_MIGRATION_GUIDE_MESSAGE,
+)
 from ray.train.utils import _log_deprecation_warning
 from ray.tune import (
     Experiment,
@@ -103,10 +106,10 @@ def __init__(
         if isinstance(trainable, BaseTrainer):
             if _v2_migration_warnings_enabled():
                 _log_deprecation_warning(
-                    "Passing a Trainer to the Tuner is deprecated. 
" - "See the section on hyperparameter optimization in this " - "REP for more information: " - "https://github.com/ray-project/enhancements/pull/57" + "The Ray Train + Ray Tune integration has been reworked. " + "Passing a Trainer to the Tuner is deprecated and will be removed " + "in a future release. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" ) run_config = self._choose_run_config( @@ -122,7 +125,8 @@ def __init__( if _v2_migration_warnings_enabled(): _log_deprecation_warning( "The `RunConfig` class should be imported from `ray.tune` " - "when passing it to the Tuner. Please update your imports." + "when passing it to the Tuner. Please update your imports. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" ) self._entrypoint = _entrypoint diff --git a/python/ray/tune/search/optuna/optuna_search.py b/python/ray/tune/search/optuna/optuna_search.py index 6c440b380c663..8b76a1570d00d 100644 --- a/python/ray/tune/search/optuna/optuna_search.py +++ b/python/ray/tune/search/optuna/optuna_search.py @@ -607,9 +607,7 @@ def add_evaluated_point( ot_trial_state = OptunaTrialState.PRUNED if intermediate_values: - intermediate_values_dict = { - i: value for i, value in enumerate(intermediate_values) - } + intermediate_values_dict = dict(enumerate(intermediate_values)) else: intermediate_values_dict = None diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index db1ca47e12dbe..08c42239023fd 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -1,15 +1,48 @@ +from dataclasses import dataclass import logging from ray.train._internal.syncer import SyncConfig as TrainSyncConfig -from ray.util.annotations import Deprecated +from ray.util.annotations import PublicAPI logger = logging.getLogger(__name__) -@Deprecated +@PublicAPI(stability="beta") +@dataclass class SyncConfig(TrainSyncConfig): - def __new__(cls: type, *args, **kwargs): - raise DeprecationWarning( - "`ray.tune.SyncConfig` has been moved to `ray.train.SyncConfig`. " - "Please update your code to use `ray.train.SyncConfig`." - ) + """Configuration object for Tune file syncing to `RunConfig(storage_path)`. + + In Ray Tune, here is where syncing (mainly uploading) happens: + + The experiment driver (on the head node) syncs the experiment directory to storage + (which includes experiment state such as searcher state, the list of trials + and their statuses, and trial metadata). + + It's also possible to sync artifacts from the trial directory to storage + by setting `sync_artifacts=True`. + For a Ray Tune run with many trials, each trial will upload its trial directory + to storage, which includes arbitrary files that you dumped during the run. + + See :ref:`persistent-storage-guide` for more details and examples. + + Args: + sync_period: Minimum time in seconds to wait between two sync operations. + A smaller ``sync_period`` will have the data in storage updated more often + but introduces more syncing overhead. Defaults to 5 minutes. + sync_timeout: Maximum time in seconds to wait for a sync process + to finish running. A sync operation will run for at most this long + before raising a `TimeoutError`. Defaults to 30 minutes. + sync_artifacts: [Beta] Whether or not to sync artifacts that are saved to the + trial directory (accessed via `tune.get_context().get_trial_dir()`) + to the persistent storage configured via `tune.RunConfig(storage_path)`. + The trial or remote worker will try to launch an artifact syncing + operation every time `tune.report` happens, subject to `sync_period` + and `sync_artifacts_on_checkpoint`. 
+ Defaults to False -- no artifacts are persisted by default. + sync_artifacts_on_checkpoint: If True, trial/worker artifacts are + forcefully synced on every reported checkpoint. + This only has an effect if `sync_artifacts` is True. + Defaults to True. + """ + + pass diff --git a/python/ray/tune/trainable/trainable_fn_utils.py b/python/ray/tune/trainable/trainable_fn_utils.py index 2fcb1ac529aaa..b3b2ed4e21235 100644 --- a/python/ray/tune/trainable/trainable_fn_utils.py +++ b/python/ray/tune/trainable/trainable_fn_utils.py @@ -2,7 +2,10 @@ from ray.train._checkpoint import Checkpoint as TrainCheckpoint from ray.train._internal.session import _warn_session_misuse, get_session -from ray.train.constants import _v2_migration_warnings_enabled +from ray.train.constants import ( + _v2_migration_warnings_enabled, + V2_MIGRATION_GUIDE_MESSAGE, +) from ray.train.utils import _copy_doc, _log_deprecation_warning from ray.util.annotations import PublicAPI @@ -37,8 +40,9 @@ def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: if _v2_migration_warnings_enabled(): _log_deprecation_warning( "The `Checkpoint` class should be imported from `ray.tune` " - "when passing it to `ray.tune.report` in a Tune function." - "Please update your imports." + "when passing it to `ray.tune.report` in a Tune function. " + "Please update your imports. " + f"{V2_MIGRATION_GUIDE_MESSAGE}" ) get_session().report(metrics, checkpoint=checkpoint) diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py index caf5572c69d0b..0c1b107c1f641 100644 --- a/python/ray/util/client/common.py +++ b/python/ray/util/client/common.py @@ -707,7 +707,7 @@ def _get_client_id_from_context(context: Any) -> str: Get `client_id` from gRPC metadata. If the `client_id` is not present, this function logs an error and sets the status_code. """ - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) client_id = metadata.get("client_id") or "" if client_id == "": logger.error("Client connecting with no client_id") diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 9ce816856e4df..af06b89027859 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -32,7 +32,7 @@ def _get_reconnecting_from_context(context: Any) -> bool: """ Get `reconnecting` from gRPC metadata, or False if missing. 
""" - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) val = metadata.get("reconnecting") if val is None or val not in ("True", "False"): logger.error( @@ -155,7 +155,7 @@ def Datapath(self, request_iterator, context): start_time = time.time() # set to True if client shuts down gracefully cleanup_requested = False - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) client_id = metadata.get("client_id") if client_id is None: logger.error("Client connecting with no client_id") diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 57acede6bd4d5..9d661251560b8 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -724,7 +724,7 @@ def get_cluster_info( resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata) if resp.WhichOneof("response_type") == "resource_table": # translate from a proto map to a python dict - output_dict = {k: v for k, v in resp.resource_table.table.items()} + output_dict = dict(resp.resource_table.table) return output_dict elif resp.WhichOneof("response_type") == "runtime_context": return resp.runtime_context diff --git a/python/ray/util/collective/requirements.txt b/python/ray/util/collective/requirements.txt index ce5057b221f15..7af6e36186cfe 100644 --- a/python/ray/util/collective/requirements.txt +++ b/python/ray/util/collective/requirements.txt @@ -1 +1 @@ -cupy-cuda100 \ No newline at end of file +cupy-cuda100 diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index a4afa825f4668..307bab124aea3 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -4,6 +4,7 @@ from collections import defaultdict from functools import wraps from typing import List, Optional, Tuple +import json import aiohttp import grpc @@ -11,6 +12,7 @@ import ray import ray.dashboard.modules.log.log_consts as log_consts +import ray.dashboard.consts as dashboard_consts from ray._private import ray_constants from ray._private.gcs_utils import GcsAioClient from ray._private.utils import hex_to_binary @@ -154,7 +156,6 @@ def __init__(self, gcs_channel: grpc.aio.Channel, gcs_aio_client: GcsAioClient): self.register_gcs_client(gcs_channel) self._raylet_stubs = {} self._runtime_env_agent_addresses = {} # {node_id -> url} - self._log_agent_stub = {} self._job_client = JobInfoStorageClient(gcs_aio_client) self._id_ip_map = IdToIpMap() self._gcs_aio_client = gcs_aio_client @@ -204,18 +205,6 @@ def unregister_raylet_client(self, node_id: str): self._runtime_env_agent_addresses.pop(node_id) self._id_ip_map.pop(node_id) - def register_agent_client(self, node_id, address: str, port: int): - options = _STATE_MANAGER_GRPC_OPTIONS - channel = ray._private.utils.init_grpc_channel( - f"{address}:{port}", options=options, asynchronous=True - ) - self._log_agent_stub[node_id] = LogServiceStub(channel) - self._id_ip_map.put(node_id, address) - - def unregister_agent_client(self, node_id: str): - self._log_agent_stub.pop(node_id) - self._id_ip_map.pop(node_id) - def get_all_registered_raylet_ids(self) -> List[str]: return self._raylet_stubs.keys() @@ -223,9 +212,21 @@ def get_all_registered_raylet_ids(self) -> List[str]: def get_all_registered_runtime_env_agent_ids(self) -> List[str]: return self._runtime_env_agent_addresses.keys() - # Returns all nod_ids which registered their log_agent_stub. 
- def get_all_registered_log_agent_ids(self) -> List[str]: - return self._log_agent_stub.keys() + async def get_log_service_stub(self, node_id: NodeID) -> Optional[LogServiceStub]: + """Returns None if the agent on the node is not registered in Internal KV.""" + agent_addr = await self._gcs_aio_client.internal_kv_get( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS, + ) + if not agent_addr: + return None + ip, http_port, grpc_port = json.loads(agent_addr) + options = ray_constants.GLOBAL_GRPC_OPTIONS + channel = ray._private.utils.init_grpc_channel( + f"{ip}:{grpc_port}", options=options, asynchronous=True + ) + return LogServiceStub(channel) def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]: """Return the node id that corresponds to the given ip. @@ -495,7 +496,7 @@ async def get_runtime_envs_info( async def list_logs( self, node_id: str, glob_filter: str, timeout: int = None ) -> ListLogsReply: - stub = self._log_agent_stub.get(node_id) + stub = await self.get_log_service_stub(NodeID.from_hex(node_id)) if not stub: raise ValueError(f"Agent for node id: {node_id} doesn't exist.") return await stub.ListLogs( @@ -514,7 +515,7 @@ async def stream_log( start_offset: Optional[int] = None, end_offset: Optional[int] = None, ) -> UnaryStreamCall: - stub = self._log_agent_stub.get(node_id) + stub = await self.get_log_service_stub(NodeID.from_hex(node_id)) if not stub: raise ValueError(f"Agent for node id: {node_id} doesn't exist.") diff --git a/python/ray/workflow/tests/test_basic_workflows_3.py b/python/ray/workflow/tests/test_basic_workflows_3.py index d0f02b3d82c5a..90d936620e15a 100644 --- a/python/ray/workflow/tests/test_basic_workflows_3.py +++ b/python/ray/workflow/tests/test_basic_workflows_3.py @@ -60,11 +60,8 @@ def test_run_off_main_thread(workflow_start_regular_shared): def fake_data(num: int): return list(range(num)) - succ = False - # Start new thread here ⚠️ def run(): - global succ # Setup the workflow. assert workflow.run(fake_data.bind(10), workflow_id="run") == list(range(10)) diff --git a/python/requirements/test-requirements.txt b/python/requirements/test-requirements.txt index 6089f62b5f58b..46060e23c5d33 100644 --- a/python/requirements/test-requirements.txt +++ b/python/requirements/test-requirements.txt @@ -1,5 +1,4 @@ ## Requirements for running tests -# These should all be pinned to versions that work # General test requirements async-exit-stack==1.0.1 @@ -42,7 +41,7 @@ opentelemetry-exporter-opencensus==0.20b0 pexpect==4.8.0 Pillow==10.3.0; platform_system != "Windows" proxy.py==2.4.3 -pydantic==2.5.0 +pydantic>=2.9.0 pydot==1.4.2 pyopenssl==24.2.1 pygame==2.5.2 @@ -58,7 +57,6 @@ pytest-sugar==0.9.5 pytest-lazy-fixture==0.6.3 pytest-timeout==2.1.0 pytest-virtualenv==1.7.0; python_version < "3.12" -virtualenv==20.25.3 pytest-sphinx @ git+https://github.com/ray-project/pytest-sphinx redis==4.4.2 scikit-learn==1.3.2 @@ -98,10 +96,7 @@ tensorboard tensorboard-data-server==0.7.2 h11==0.12.0 markdown-it-py -attrs==21.4.0 pytz==2022.7.1 -# Compatibility with spacy 3.5 (model en_core_web_sm) -typing-extensions==4.8.0 # Aim requires segment-analytics-python, which requires backoff~=2.10, # which conflicts with the opentelemetry-api 1.1.0. segment-analytics-python==2.2.0 @@ -115,3 +110,11 @@ numexpr==2.8.4 # For `serve run --reload` CLI. 
watchfiles==0.19.0 + +# Upgrades +typing-extensions>=4.10 +filelock>=3.16.1 +virtualenv>=20.29 +jsonschema>=4.23.0 +attrs>=22.2.0 +openapi-schema-validator>=0.6.3 diff --git a/python/requirements_compiled.txt b/python/requirements_compiled.txt index 564664ccdb0fd..8c1b9230e930a 100644 --- a/python/requirements_compiled.txt +++ b/python/requirements_compiled.txt @@ -128,7 +128,7 @@ async-generator==1.10 # via -r /ray/ci/../python/requirements/test-requirements.txt async-timeout==4.0.3 # via redis -attrs==21.4.0 +attrs==25.1.0 # via # -r /ray/ci/../python/requirements/test-requirements.txt # aiohttp @@ -137,6 +137,7 @@ attrs==21.4.0 # jsonschema # jupyter-cache # open-spiel + # referencing # sarif-om # semgrep aws-sam-translator==1.81.0 @@ -488,9 +489,10 @@ feather-format==0.4.1 # via -r /ray/ci/../python/requirements/test-requirements.txt ffmpy==0.3.1 # via gradio -filelock==3.13.1 +filelock==3.17.0 # via # -r /ray/ci/../python/requirements.txt + # -r /ray/ci/../python/requirements/test-requirements.txt # aim # datasets # huggingface-hub @@ -846,14 +848,14 @@ jsonpointer==2.4 # via # jsonpatch # jsonschema -jsonschema==4.17.3 +jsonschema==4.23.0 # via # -r /ray/ci/../python/requirements.txt + # -r /ray/ci/../python/requirements/test-requirements.txt # altair # aws-sam-translator # cfn-lint # comet-ml - # jsonschema-spec # jupyter-events # jupyterlab-server # nbformat @@ -861,8 +863,12 @@ jsonschema==4.17.3 # openapi-spec-validator # ray # semgrep -jsonschema-spec==0.1.6 +jsonschema-path==0.3.4 # via openapi-spec-validator +jsonschema-specifications==2024.10.1 + # via + # jsonschema + # openapi-schema-validator junit-xml==1.9 # via cfn-lint jupyter-cache==0.6.1 @@ -1255,9 +1261,11 @@ onnxruntime==1.18.0 ; sys_platform != "darwin" or platform_machine != "arm64" # via -r /ray/ci/../python/requirements/ml/rllib-requirements.txt open-spiel==1.4 # via -r /ray/ci/../python/requirements/ml/rllib-test-requirements.txt -openapi-schema-validator==0.4.4 - # via openapi-spec-validator -openapi-spec-validator==0.5.7 +openapi-schema-validator==0.6.3 + # via + # -r /ray/ci/../python/requirements/test-requirements.txt + # openapi-spec-validator +openapi-spec-validator==0.7.1 # via moto opencensus==0.11.3 # via -r /ray/ci/../python/requirements.txt @@ -1393,7 +1401,7 @@ path==16.14.0 path-py==12.5.0 # via pytest-shutil pathable==0.4.3 - # via jsonschema-spec + # via jsonschema-path pathspec==0.11.2 # via black patsy==0.5.3 @@ -1548,7 +1556,7 @@ pycparser==2.21 # via cffi pycurl==7.45.3 # via -r /ray/ci/../python/requirements/anyscale-requirements.txt -pydantic==2.5.0 +pydantic==2.9.2 # via # -r /ray/ci/../python/requirements.txt # -r /ray/ci/../python/requirements/test-requirements.txt @@ -1557,7 +1565,7 @@ pydantic==2.5.0 # fastapi # gradio # pyiceberg -pydantic-core==2.14.1 +pydantic-core==2.23.4 # via pydantic pydot==1.4.2 # via -r /ray/ci/../python/requirements/test-requirements.txt @@ -1628,8 +1636,6 @@ pyro-ppl==1.9.1 # via botorch pyro4==4.82 # via hpbandster -pyrsistent==0.20.0 - # via jsonschema pysocks==1.7.1 # via requests pyspark==3.4.1 @@ -1742,7 +1748,7 @@ pyyaml==6.0.1 # distributed # gradio # huggingface-hub - # jsonschema-spec + # jsonschema-path # jupyter-cache # jupyter-events # jupytext @@ -1783,6 +1789,11 @@ raydp==1.7.0b20231020.dev0 # via -r /ray/ci/../python/requirements/ml/data-test-requirements.txt redis==4.4.2 # via -r /ray/ci/../python/requirements/test-requirements.txt +referencing==0.36.2 + # via + # jsonschema + # jsonschema-path + # jsonschema-specifications regex==2024.5.15 # 
via
     #   cfn-lint
@@ -1809,7 +1820,7 @@ requests==2.31.0
     #   gradio
     #   gradio-client
     #   huggingface-hub
-    #   jsonschema-spec
+    #   jsonschema-path
     #   jupyterlab-server
     #   kubernetes
     #   mlflow
@@ -1868,6 +1879,10 @@ rich==13.3.2
     #   pyiceberg
     #   semgrep
     #   typer
+rpds-py==0.22.3
+    # via
+    #   jsonschema
+    #   referencing
 rsa==4.7.2
     # via
     #   gcs-oauth2-boto-plugin
@@ -2295,7 +2310,7 @@ types-requests==2.31.0.6
     # via -r /ray/ci/../python/requirements/lint-requirements.txt
 types-urllib3==1.26.25.14
     # via types-requests
-typing-extensions==4.8.0
+typing-extensions==4.12.2
     # via
     #   -r /ray/ci/../python/requirements/test-requirements.txt
     #   alembic
@@ -2317,6 +2332,7 @@ typing-extensions==4.8.0
     #   pydantic
     #   pydantic-core
     #   pytorch-lightning
+    #   referencing
     #   semgrep
     #   snowflake-connector-python
     #   sqlalchemy
@@ -2372,7 +2388,7 @@ uvloop==0.19.0
     #   vmc-draas-client-bindings
     #   vsphere-automation-sdk
 # via vsphere-automation-sdk
-virtualenv==20.25.3
+virtualenv==20.29.1
     # via
     #   -r /ray/ci/../python/requirements.txt
     #   -r /ray/ci/../python/requirements/test-requirements.txt
diff --git a/python/setup.py b/python/setup.py
index c235806aa8c37..4a716102ee664 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -545,7 +545,7 @@ def build(build_python, build_java, build_cpp):
     # version of Python to build packages inside the build.sh script. Note
     # that certain flags will not be passed along such as --user or sudo.
     # TODO(rkn): Fix this.
-    if not os.getenv("SKIP_THIRDPARTY_INSTALL"):
+    if not os.getenv("SKIP_THIRDPARTY_INSTALL_CONDA_FORGE"):
         pip_packages = ["psutil", "setproctitle==1.2.2", "colorama"]
         subprocess.check_call(
             [
@@ -560,19 +560,20 @@ def build(build_python, build_java, build_cpp):
             env=dict(os.environ, CC="gcc"),
         )

-    # runtime env agent dependenceis
-    runtime_env_agent_pip_packages = ["aiohttp"]
-    subprocess.check_call(
-        [
-            sys.executable,
-            "-m",
-            "pip",
-            "install",
-            "-q",
-            "--target=" + os.path.join(ROOT_DIR, RUNTIME_ENV_AGENT_THIRDPARTY_SUBDIR),
-        ]
-        + runtime_env_agent_pip_packages
-    )
+        # runtime env agent dependencies
+        runtime_env_agent_pip_packages = ["aiohttp"]
+        subprocess.check_call(
+            [
+                sys.executable,
+                "-m",
+                "pip",
+                "install",
+                "-q",
+                "--target="
+                + os.path.join(ROOT_DIR, RUNTIME_ENV_AGENT_THIRDPARTY_SUBDIR),
+            ]
+            + runtime_env_agent_pip_packages
+        )

     bazel_flags = ["--verbose_failures"]
     if BAZEL_ARGS:
diff --git a/release/air_tests/air_benchmarks/mlperf-train/process_imagenet.sh b/release/air_tests/air_benchmarks/mlperf-train/process_imagenet.sh
index 7565d2bd40014..207a481b37f28 100755
--- a/release/air_tests/air_benchmarks/mlperf-train/process_imagenet.sh
+++ b/release/air_tests/air_benchmarks/mlperf-train/process_imagenet.sh
@@ -19,4 +19,4 @@ for filename in "$INPUT_DIR"/*; do
   out_path="$OUTPUT_DIR/$class_dir/$img_path"
   echo "$out_path"
   cp "$INPUT_DIR"/"$filename" "$out_path"
-done
\ No newline at end of file
+done
diff --git a/release/benchmarks/distributed/config.yaml b/release/benchmarks/distributed/config.yaml
index 2cc52ef8ba40d..8cdcf06edd538 100644
--- a/release/benchmarks/distributed/config.yaml
+++ b/release/benchmarks/distributed/config.yaml
@@ -53,4 +53,4 @@ setup_commands:
   - pip install tqdm
   - sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 65535" >> /etc/security/limits.conf; echo "* hard nofile 65535" >> /etc/security/limits.conf;'

-idle_timeout_minutes: 1
\ No newline at end of file
+idle_timeout_minutes: 1
diff --git a/release/benchmarks/object_store/config.yaml b/release/benchmarks/object_store/config.yaml
index 951e4f324da4b..2d4869d9abc40 100644
--- 
a/release/benchmarks/object_store/config.yaml +++ b/release/benchmarks/object_store/config.yaml @@ -36,4 +36,4 @@ setup_commands: - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.4.0/6ac5e0e5ad45070e27c77aca7267bcee30cc4b4a/ray-1.4.0-cp37-cp37m-manylinux2014_x86_64.whl - pip install tqdm numpy -idle_timeout_minutes: 5 \ No newline at end of file +idle_timeout_minutes: 5 diff --git a/release/benchmarks/single_node/config.yaml b/release/benchmarks/single_node/config.yaml index b29ea70e24512..6ce222d5595c4 100644 --- a/release/benchmarks/single_node/config.yaml +++ b/release/benchmarks/single_node/config.yaml @@ -34,4 +34,4 @@ setup_commands: - pip install numpy tqdm - sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1000000" >> /etc/security/limits.conf; echo "* hard nofile 1000000" >> /etc/security/limits.conf;' -idle_timeout_minutes: 5 \ No newline at end of file +idle_timeout_minutes: 5 diff --git a/release/long_running_tests/README.rst b/release/long_running_tests/README.rst index 80dfd6f7fbed1..484567fd259bb 100644 --- a/release/long_running_tests/README.rst +++ b/release/long_running_tests/README.rst @@ -58,4 +58,4 @@ Adding a Workload To create a new workload, simply add a new Python file under ``workloads/`` and add the workload in the run command in `ray-project/project.yaml`. -.. _`Releaser`: https://github.com/ray-project/releaser \ No newline at end of file +.. _`Releaser`: https://github.com/ray-project/releaser diff --git a/release/ml_user_tests/horovod/driver_requirements.txt b/release/ml_user_tests/horovod/driver_requirements.txt index 6c8cb758d5659..1eccd80821e7b 100755 --- a/release/ml_user_tests/horovod/driver_requirements.txt +++ b/release/ml_user_tests/horovod/driver_requirements.txt @@ -5,4 +5,4 @@ -c ../../../python/requirements/ml/dl-cpu-requirements.txt torch -torchvision \ No newline at end of file +torchvision diff --git a/release/ml_user_tests/horovod/driver_setup_latest.sh b/release/ml_user_tests/horovod/driver_setup_latest.sh index 91741f74ba37c..a4de5f3b690aa 100755 --- a/release/ml_user_tests/horovod/driver_setup_latest.sh +++ b/release/ml_user_tests/horovod/driver_setup_latest.sh @@ -9,4 +9,4 @@ pip install cmake pip install -U -r ./driver_requirements.txt -HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install horovod \ No newline at end of file +HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install horovod diff --git a/release/ml_user_tests/horovod/driver_setup_master.sh b/release/ml_user_tests/horovod/driver_setup_master.sh index cfa0b2bf635bc..f5b06738f1da6 100755 --- a/release/ml_user_tests/horovod/driver_setup_master.sh +++ b/release/ml_user_tests/horovod/driver_setup_master.sh @@ -8,4 +8,4 @@ pip install cmake pip install -U -r ./driver_requirements.txt -HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install -U git+https://github.com/horovod/horovod.git \ No newline at end of file +HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install -U git+https://github.com/horovod/horovod.git diff --git a/release/ml_user_tests/train/driver_requirements.txt b/release/ml_user_tests/train/driver_requirements.txt index df4c7e35f75eb..8135917df4193 100755 --- a/release/ml_user_tests/train/driver_requirements.txt +++ 
b/release/ml_user_tests/train/driver_requirements.txt
@@ -5,4 +5,4 @@
 -c ../../../python/requirements/ml/dl-cpu-requirements.txt

 torch
-tensorflow
\ No newline at end of file
+tensorflow
diff --git a/release/ml_user_tests/tune_rllib/driver_requirements.txt b/release/ml_user_tests/tune_rllib/driver_requirements.txt
index 56a9ab752ce29..2b379dbb7847c 100755
--- a/release/ml_user_tests/tune_rllib/driver_requirements.txt
+++ b/release/ml_user_tests/tune_rllib/driver_requirements.txt
@@ -7,4 +7,4 @@
 tensorflow
 torch
 # Need this library to unpickle errors
-tblib
\ No newline at end of file
+tblib
diff --git a/release/nightly_tests/dataset/batch_inference_benchmark.py b/release/nightly_tests/dataset/batch_inference_benchmark.py
new file mode 100644
index 0000000000000..9eae5416ca5cc
--- /dev/null
+++ b/release/nightly_tests/dataset/batch_inference_benchmark.py
@@ -0,0 +1,120 @@
+import argparse
+import io
+import uuid
+from typing import Any, Dict
+
+import boto3
+import numpy as np
+import pandas as pd
+import torch
+from benchmark import Benchmark
+from PIL import Image
+from torchvision.models import ResNet50_Weights, resnet50
+
+import ray
+from ray.data import ActorPoolStrategy
+
+BUCKET = "anyscale-imagenet"
+# This Parquet file contains the keys of images in the 'anyscale-imagenet' bucket.
+METADATA_PATH = "s3://anyscale-imagenet/metadata.parquet"
+
+# Largest batch that can fit on a T4.
+BATCH_SIZE = 900
+
+WRITE_PATH = f"s3://ray-data-write-benchmark/{uuid.uuid4().hex}"
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--sf",
+        dest="scale_factor",
+        type=int,
+        default=1,
+        help=(
+            "The number of copies of ImageNet to read. Use this to simulate a larger "
+            "dataset."
+        ),
+    )
+    return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+    benchmark = Benchmark()
+
+    metadata = pd.read_parquet(METADATA_PATH)
+    # Repeat the metadata 'scale_factor' times to simulate a larger dataset.
+    metadata = pd.concat([metadata] * args.scale_factor, ignore_index=True)
+
+    def benchmark_fn():
+        weights = ResNet50_Weights.DEFAULT
+        model = resnet50(weights=weights)
+        model_ref = ray.put(model)
+
+        # Get the preprocessing transforms from the pre-trained weights.
+        transform = weights.transforms()
+
+        (
+            ray.data.from_pandas(metadata)
+            # TODO: There should be a way to specify "use as many actors as possible"
+            # with the now-recommended `concurrency` parameter.
+            .map(LoadImage, compute=ActorPoolStrategy(min_size=1))
+            # Preprocess the images using standard preprocessing
+            .map(ApplyTransform(transform))
+            .map_batches(
+                Predictor,
+                batch_size=BATCH_SIZE,
+                compute=ActorPoolStrategy(min_size=1),
+                num_gpus=1,
+                fn_constructor_kwargs={"model": model_ref, "device": "cuda"},
+            )
+            .write_parquet(WRITE_PATH)
+        )
+
+    benchmark.run_fn("main", benchmark_fn)
+    benchmark.write_result()
+
+
+class LoadImage:
+    def __init__(self):
+        self._client = boto3.client("s3")
+
+    def __call__(self, row):
+        data = io.BytesIO()
+        self._client.download_fileobj(BUCKET, row["key"], data)
+        image = Image.open(data).convert("RGB")
+        return {"image": np.array(image)}
+
+
+class ApplyTransform:
+    def __init__(self, transform):
+        self._transform = transform
+
+    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
+        # 'row["image"]' isn't writeable, and Torch only supports writeable tensors, so
+        # we need to make a copy to prevent Torch from complaining.
+        tensor_batch = torch.as_tensor(np.copy(row["image"]), dtype=torch.float)
+        # (H, W, C) -> (C, H, W). 
This is required for the torchvision transform. + # https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights # noqa: E501 + tensor_batch = tensor_batch.permute(2, 0, 1) + transformed_batch = self._transform(tensor_batch).numpy() + return {"image": transformed_batch} + + +class Predictor: + def __init__(self, model, device): + self._model = ray.get(model) + self._model.eval() + self._model.to(device) + + self._device = device + + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + with torch.inference_mode(): + output = self._model(torch.as_tensor(batch["image"], device=self._device)) + return {"predictions": output.cpu().numpy()} + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/release/nightly_tests/dataset/multi_node_train_16_workers.yaml b/release/nightly_tests/dataset/multi_node_train_16_workers.yaml index 62aca9ccf8e84..1b19d0ee7b443 100644 --- a/release/nightly_tests/dataset/multi_node_train_16_workers.yaml +++ b/release/nightly_tests/dataset/multi_node_train_16_workers.yaml @@ -12,4 +12,4 @@ worker_node_types: instance_type: g4dn.4xlarge max_workers: 15 min_workers: 15 - use_spot: false \ No newline at end of file + use_spot: false diff --git a/release/nightly_tests/dataset/multi_node_train_4_workers.yaml b/release/nightly_tests/dataset/multi_node_train_4_workers.yaml index abe912e55fff1..4f61aa63efdeb 100644 --- a/release/nightly_tests/dataset/multi_node_train_4_workers.yaml +++ b/release/nightly_tests/dataset/multi_node_train_4_workers.yaml @@ -12,4 +12,4 @@ worker_node_types: instance_type: g4dn.4xlarge max_workers: 3 min_workers: 3 - use_spot: false \ No newline at end of file + use_spot: false diff --git a/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh b/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh index a1251da3b75c3..d9f91c5fe1770 100755 --- a/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh +++ b/release/nightly_tests/dataset/run_image_loader_microbenchmark.sh @@ -22,4 +22,4 @@ aws s3 sync s3://imagenetmini1000/1gb-tfrecords $TFRECORDS_DIR # Preprocess parquet and mosaic files. 
python preprocess_images.py --data-root "$DIR" --mosaic-data-root "$MOSAIC_DIR" --parquet-data-root "$PARQUET_DIR" -python image_loader_microbenchmark.py --data-root "$DIR" --mosaic-data-root "$MOSAIC_DIR" --parquet-data-root "$PARQUET_DIR" --tf-data-root "$TFRECORDS_DIR"/train \ No newline at end of file +python image_loader_microbenchmark.py --data-root "$DIR" --mosaic-data-root "$MOSAIC_DIR" --parquet-data-root "$PARQUET_DIR" --tf-data-root "$TFRECORDS_DIR"/train diff --git a/release/nightly_tests/stress_tests/stress_tests_single_node_oom_compute.yaml b/release/nightly_tests/stress_tests/stress_tests_single_node_oom_compute.yaml index 32a636a50ae34..51382cfd6b2b1 100644 --- a/release/nightly_tests/stress_tests/stress_tests_single_node_oom_compute.yaml +++ b/release/nightly_tests/stress_tests/stress_tests_single_node_oom_compute.yaml @@ -7,4 +7,4 @@ head_node_type: name: head_node instance_type: m5.xlarge -worker_node_types: [] \ No newline at end of file +worker_node_types: [] diff --git a/release/perf_metrics/benchmarks/many_actors.json b/release/perf_metrics/benchmarks/many_actors.json index 55488c64ff277..c6f2983d65dda 100644 --- a/release/perf_metrics/benchmarks/many_actors.json +++ b/release/perf_metrics/benchmarks/many_actors.json @@ -1,32 +1,32 @@ { - "_dashboard_memory_usage_mb": 425.73824, + "_dashboard_memory_usage_mb": 444.891136, "_dashboard_test_success": true, - "_peak_memory": 3.78, - "_peak_process_memory": "PID\tMEM\tCOMMAND\n1110\t8.12GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n3495\t1.79GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n5107\t0.89GiB\tpython distributed/test_many_actors.py\n3611\t0.37GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dashboa\n2744\t0.35GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n583\t0.14GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n3121\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n3805\t0.09GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/agen\n3807\t0.07GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/runti\n3835\t0.07GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/bin/jupyter-lab --allow-root --ip=127.0.0.1 --no-", - "actors_per_second": 605.527621831549, + "_peak_memory": 3.82, + "_peak_process_memory": "PID\tMEM\tCOMMAND\n3469\t1.85GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n5358\t0.85GiB\tpython distributed/test_many_actors.py\n2835\t0.4GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n3585\t0.38GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dash\n1099\t0.23GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n586\t0.14GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n2828\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n3779\t0.09GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/a\n3781\t0.07GiB\t/home/ray/anaconda3/bin/python3.9 -u 
/home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/ru\n3805\t0.07GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/bin/jupyter-lab --allow-root --ip=127.0.0.1 --", + "actors_per_second": 581.3621620015273, "num_actors": 10000, "perf_metrics": [ { "perf_metric_name": "actors_per_second", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 605.527621831549 + "perf_metric_value": 581.3621620015273 }, { "perf_metric_name": "dashboard_p50_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 138.675 + "perf_metric_value": 50.674 }, { "perf_metric_name": "dashboard_p95_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 3213.567 + "perf_metric_value": 2727.324 }, { "perf_metric_name": "dashboard_p99_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 4703.106 + "perf_metric_value": 3492.239 } ], "success": "1", - "time": 16.51452326774597 + "time": 17.200981855392456 } diff --git a/release/perf_metrics/benchmarks/many_nodes.json b/release/perf_metrics/benchmarks/many_nodes.json index b254d61d019c3..8a7e3090143c0 100644 --- a/release/perf_metrics/benchmarks/many_nodes.json +++ b/release/perf_metrics/benchmarks/many_nodes.json @@ -1,14 +1,14 @@ { - "_dashboard_memory_usage_mb": 225.501184, + "_dashboard_memory_usage_mb": 227.344384, "_dashboard_test_success": true, "_peak_memory": 1.67, - "_peak_process_memory": "PID\tMEM\tCOMMAND\n6321\t0.52GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n1088\t0.27GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n2914\t0.26GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n6437\t0.21GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dashboa\n7449\t0.16GiB\tpython distributed/test_many_tasks.py --num-tasks=1000\n6628\t0.1GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/agen\n2881\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n7718\t0.08GiB\tray::StateAPIGeneratorActor.start\n6630\t0.07GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/runti\n6721\t0.07GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/bin/jupyter-lab --allow-root --ip=127.0.0.1 --no-", + "_peak_process_memory": "PID\tMEM\tCOMMAND\n3423\t0.51GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n2972\t0.25GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n2059\t0.24GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n3539\t0.21GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dash\n4401\t0.16GiB\tpython distributed/test_many_tasks.py --num-tasks=1000\n3740\t0.1GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/a\n2873\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n4603\t0.08GiB\tray::StateAPIGeneratorActor.start\n3742\t0.07GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/ru\n4180\t0.07GiB\tray::JobSupervisor", "num_tasks": 1000, "perf_metrics": [ { "perf_metric_name": "tasks_per_second", "perf_metric_type": 
"THROUGHPUT", - "perf_metric_value": 471.21245501533576 + "perf_metric_value": 427.6684054263659 }, { "perf_metric_name": "used_cpus_by_deadline", @@ -18,21 +18,21 @@ { "perf_metric_name": "dashboard_p50_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 4.347 + "perf_metric_value": 4.297 }, { "perf_metric_name": "dashboard_p95_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 39.196 + "perf_metric_value": 36.499 }, { "perf_metric_name": "dashboard_p99_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 89.057 + "perf_metric_value": 89.739 } ], "success": "1", - "tasks_per_second": 471.21245501533576, - "time": 302.12218499183655, + "tasks_per_second": 427.6684054263659, + "time": 302.3382601737976, "used_cpus": 250.0 } diff --git a/release/perf_metrics/benchmarks/many_pgs.json b/release/perf_metrics/benchmarks/many_pgs.json index 5ea334288c5b0..b32ae3db5a343 100644 --- a/release/perf_metrics/benchmarks/many_pgs.json +++ b/release/perf_metrics/benchmarks/many_pgs.json @@ -1,32 +1,32 @@ { - "_dashboard_memory_usage_mb": 171.487232, + "_dashboard_memory_usage_mb": 172.720128, "_dashboard_test_success": true, "_peak_memory": 2.16, - "_peak_process_memory": "PID\tMEM\tCOMMAND\n1107\t7.43GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n3511\t0.96GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n4510\t0.36GiB\tpython distributed/test_many_pgs.py\n2880\t0.27GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n585\t0.14GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n3629\t0.14GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dashboa\n3053\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n3822\t0.09GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/agen\n3824\t0.07GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/runti\n3848\t0.07GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/bin/jupyter-lab --allow-root --ip=127.0.0.1 --no-", + "_peak_process_memory": "PID\tMEM\tCOMMAND\n3483\t0.96GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n2824\t0.41GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n4997\t0.36GiB\tpython distributed/test_many_pgs.py\n1098\t0.23GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n3601\t0.13GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dash\n583\t0.13GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n2720\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n3794\t0.09GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/a\n3796\t0.07GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/_private/ru\n3818\t0.07GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/bin/jupyter-lab --allow-root --ip=127.0.0.1 --", "num_pgs": 1000, "perf_metrics": [ { "perf_metric_name": "pgs_per_second", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 22.635892345065848 + 
"perf_metric_value": 22.74259799982883 }, { "perf_metric_name": "dashboard_p50_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 3.253 + "perf_metric_value": 3.777 }, { "perf_metric_name": "dashboard_p95_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 10.134 + "perf_metric_value": 10.692 }, { "perf_metric_name": "dashboard_p99_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 832.35 + "perf_metric_value": 396.497 } ], - "pgs_per_second": 22.635892345065848, + "pgs_per_second": 22.74259799982883, "success": "1", - "time": 44.177626609802246 + "time": 43.97035026550293 } diff --git a/release/perf_metrics/benchmarks/many_tasks.json b/release/perf_metrics/benchmarks/many_tasks.json index 77be5629e27d1..e975fdd312f14 100644 --- a/release/perf_metrics/benchmarks/many_tasks.json +++ b/release/perf_metrics/benchmarks/many_tasks.json @@ -1,14 +1,14 @@ { - "_dashboard_memory_usage_mb": 768.884736, + "_dashboard_memory_usage_mb": 741.86752, "_dashboard_test_success": true, - "_peak_memory": 3.56, - "_peak_process_memory": "PID\tMEM\tCOMMAND\n3463\t1.13GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n3579\t0.84GiB\t/home/ray/anaconda3/bin/python /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dashboa\n4464\t0.74GiB\tpython distributed/test_many_tasks.py --num-tasks=10000\n2768\t0.24GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n1076\t0.24GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n580\t0.12GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n3772\t0.09GiB\t/home/ray/anaconda3/bin/python -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/agen\n2955\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n4609\t0.08GiB\tray::DashboardTester.run\n4665\t0.08GiB\tray::StateAPIGeneratorActor.start", + "_peak_memory": 3.58, + "_peak_process_memory": "PID\tMEM\tCOMMAND\n3537\t1.16GiB\t/home/ray/anaconda3/lib/python3.9/site-packages/ray/core/src/ray/gcs/gcs_server --log_dir=/tmp/ray/s\n3653\t0.87GiB\t/home/ray/anaconda3/bin/python3.9 /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/dash\n4619\t0.74GiB\tpython distributed/test_many_tasks.py --num-tasks=10000\n2985\t0.23GiB\tvector --watch-config --log-format json --config-yaml /etc/vector/vector.yaml\n1104\t0.22GiB\t/app/product/go/infra/anyscaled/anyscaled_/anyscaled startv2 --control_plane_url=https://console.any\n584\t0.13GiB\t/app/go/infra/anyscaled/anyscaled_/anyscaled_shim --cloud_provider=aws\n3846\t0.09GiB\t/home/ray/anaconda3/bin/python3.9 -u /home/ray/anaconda3/lib/python3.9/site-packages/ray/dashboard/a\n3061\t0.09GiB\t/usr/bin/python3 /app/infra/dataplane/webterminal/webterminal_sidecar_image.binary.runfiles/product/\n4833\t0.08GiB\tray::DashboardTester.run\n4889\t0.08GiB\tray::StateAPIGeneratorActor.start", "num_tasks": 10000, "perf_metrics": [ { "perf_metric_name": "tasks_per_second", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 552.9063643030292 + "perf_metric_value": 588.1874987887217 }, { "perf_metric_name": "used_cpus_by_deadline", @@ -18,21 +18,21 @@ { "perf_metric_name": "dashboard_p50_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 148.169 + "perf_metric_value": 115.202 }, { "perf_metric_name": "dashboard_p95_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 626.809 + 
"perf_metric_value": 522.805 }, { "perf_metric_name": "dashboard_p99_latency_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 873.471 + "perf_metric_value": 849.598 } ], "success": "1", - "tasks_per_second": 552.9063643030292, - "time": 318.08624505996704, + "tasks_per_second": 588.1874987887217, + "time": 317.0013813972473, "used_cpus": 2500.0 } diff --git a/release/perf_metrics/metadata.json b/release/perf_metrics/metadata.json index cc2c465aae131..6c032dbb0137d 100644 --- a/release/perf_metrics/metadata.json +++ b/release/perf_metrics/metadata.json @@ -1 +1 @@ -{"release_version": "2.41.0"} +{"release_version": "2.42.0"} diff --git a/release/perf_metrics/microbenchmark.json b/release/perf_metrics/microbenchmark.json index 365e3380d8872..454406d643e14 100644 --- a/release/perf_metrics/microbenchmark.json +++ b/release/perf_metrics/microbenchmark.json @@ -1,283 +1,283 @@ { "1_1_actor_calls_async": [ - 8398.56885282787, - 90.64470042092343 + 8107.035947364004, + 194.70645316452402 ], "1_1_actor_calls_concurrent": [ - 5268.769788618396, - 118.51034431437533 + 5218.943213086157, + 123.36432993082248 ], "1_1_actor_calls_sync": [ - 2071.6501933724253, - 66.54602398051257 + 1985.8445748508243, + 15.604868913363081 ], "1_1_async_actor_calls_async": [ - 4594.0039367756, - 180.69800094596656 + 4669.450952977499, + 320.07551524119316 ], "1_1_async_actor_calls_sync": [ - 1507.4722826901059, - 38.639742834270045 + 1474.7069154202795, + 36.43494597335915 ], "1_1_async_actor_calls_with_args_async": [ - 2906.3799497984073, - 162.39581991600068 + 2953.9546990180343, + 128.10840371596203 ], "1_n_actor_calls_async": [ - 8087.002681858739, - 179.1488833569923 + 8136.711770399841, + 99.07793497695116 ], "1_n_async_actor_calls_async": [ - 7747.258573610084, - 159.28066572462012 + 7488.95904603153, + 32.91861129793621 ], "client__1_1_actor_calls_async": [ - 1041.3583204574459, - 11.15719859730251 + 1010.2913968814835, + 10.885467967202615 ], "client__1_1_actor_calls_concurrent": [ - 1040.0631707535902, - 5.885662397510047 + 1005.657835173811, + 12.425189876213183 ], "client__1_1_actor_calls_sync": [ - 529.5868989134584, - 4.437050021111481 + 522.5361491166506, + 6.471488244048678 ], "client__get_calls": [ - 984.938248602108, - 19.39903981382506 + 969.0161835588581, + 23.107992379350016 ], "client__put_calls": [ - 782.2028385276735, - 22.37727916683932 + 785.9139426536512, + 14.074198087763435 ], "client__put_gigabytes": [ - 0.15273001910830053, - 0.0003966330505155439 + 0.15408690555366225, + 0.0007793485636983514 ], "client__tasks_and_get_batch": [ - 0.9748684492189577, - 0.013917827149089729 + 0.838835573987595, + 0.050083575264214385 ], "client__tasks_and_put_batch": [ - 14510.26499948894, - 142.94582485457454 + 14255.363968554979, + 77.33604494108931 ], "multi_client_put_calls_Plasma_Store": [ - 16476.91701973554, - 153.5268006228045 + 15931.811977493457, + 203.1678701142663 ], "multi_client_put_gigabytes": [ - 45.59421490337185, - 4.252528426368044 + 47.38817873082456, + 4.689430307355956 ], "multi_client_tasks_async": [ - 23754.393304342077, - 3937.6238945605232 + 22745.167201851888, + 3116.4036637471227 ], "n_n_actor_calls_async": [ - 27627.813056199215, - 851.4280591592274 + 26441.672940245888, + 632.5223605512703 ], "n_n_actor_calls_with_arg_async": [ - 2707.177086616588, - 18.84646363376401 + 2732.074477927061, + 27.55949571801805 ], "n_n_async_actor_calls_async": [ - 23879.523989612106, - 487.9621758881768 + 23390.23156817461, + 636.7923320872748 ], "perf_metrics": [ { "perf_metric_name": 
"single_client_get_calls_Plasma_Store", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 10641.803495982997 + "perf_metric_value": 10611.609624248378 }, { "perf_metric_name": "single_client_put_calls_Plasma_Store", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 4953.332414166378 + "perf_metric_value": 4866.041059585032 }, { "perf_metric_name": "multi_client_put_calls_Plasma_Store", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 16476.91701973554 + "perf_metric_value": 15931.811977493457 }, { "perf_metric_name": "single_client_put_gigabytes", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 17.025876066565033 + "perf_metric_value": 18.521529437957106 }, { "perf_metric_name": "single_client_tasks_and_get_batch", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 8.246266578072571 + "perf_metric_value": 7.56687497857023 }, { "perf_metric_name": "multi_client_put_gigabytes", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 45.59421490337185 + "perf_metric_value": 47.38817873082456 }, { "perf_metric_name": "single_client_get_object_containing_10k_refs", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 13.400670287257304 + "perf_metric_value": 12.987017792446045 }, { "perf_metric_name": "single_client_wait_1k_refs", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 5.5615049020254625 + "perf_metric_value": 5.424179804481537 }, { "perf_metric_name": "single_client_tasks_sync", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 1010.2085841645662 + "perf_metric_value": 1013.1673399687909 }, { "perf_metric_name": "single_client_tasks_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 7963.424588687801 + "perf_metric_value": 8032.409007811969 }, { "perf_metric_name": "multi_client_tasks_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 23754.393304342077 + "perf_metric_value": 22745.167201851888 }, { "perf_metric_name": "1_1_actor_calls_sync", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 2071.6501933724253 + "perf_metric_value": 1985.8445748508243 }, { "perf_metric_name": "1_1_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 8398.56885282787 + "perf_metric_value": 8107.035947364004 }, { "perf_metric_name": "1_1_actor_calls_concurrent", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 5268.769788618396 + "perf_metric_value": 5218.943213086157 }, { "perf_metric_name": "1_n_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 8087.002681858739 + "perf_metric_value": 8136.711770399841 }, { "perf_metric_name": "n_n_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 27627.813056199215 + "perf_metric_value": 26441.672940245888 }, { "perf_metric_name": "n_n_actor_calls_with_arg_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 2707.177086616588 + "perf_metric_value": 2732.074477927061 }, { "perf_metric_name": "1_1_async_actor_calls_sync", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 1507.4722826901059 + "perf_metric_value": 1474.7069154202795 }, { "perf_metric_name": "1_1_async_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 4594.0039367756 + "perf_metric_value": 4669.450952977499 }, { "perf_metric_name": "1_1_async_actor_calls_with_args_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 2906.3799497984073 + "perf_metric_value": 2953.9546990180343 }, { "perf_metric_name": "1_n_async_actor_calls_async", "perf_metric_type": "THROUGHPUT", - 
"perf_metric_value": 7747.258573610084 + "perf_metric_value": 7488.95904603153 }, { "perf_metric_name": "n_n_async_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 23879.523989612106 + "perf_metric_value": 23390.23156817461 }, { "perf_metric_name": "placement_group_create/removal", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 758.764762704843 + "perf_metric_value": 749.1128148700998 }, { "perf_metric_name": "client__get_calls", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 984.938248602108 + "perf_metric_value": 969.0161835588581 }, { "perf_metric_name": "client__put_calls", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 782.2028385276735 + "perf_metric_value": 785.9139426536512 }, { "perf_metric_name": "client__put_gigabytes", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 0.15273001910830053 + "perf_metric_value": 0.15408690555366225 }, { "perf_metric_name": "client__tasks_and_put_batch", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 14510.26499948894 + "perf_metric_value": 14255.363968554979 }, { "perf_metric_name": "client__1_1_actor_calls_sync", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 529.5868989134584 + "perf_metric_value": 522.5361491166506 }, { "perf_metric_name": "client__1_1_actor_calls_async", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 1041.3583204574459 + "perf_metric_value": 1010.2913968814835 }, { "perf_metric_name": "client__1_1_actor_calls_concurrent", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 1040.0631707535902 + "perf_metric_value": 1005.657835173811 }, { "perf_metric_name": "client__tasks_and_get_batch", "perf_metric_type": "THROUGHPUT", - "perf_metric_value": 0.9748684492189577 + "perf_metric_value": 0.838835573987595 } ], "placement_group_create/removal": [ - 758.764762704843, - 2.5049732038490395 + 749.1128148700998, + 14.03142021121613 ], "single_client_get_calls_Plasma_Store": [ - 10641.803495982997, - 150.88664147721323 + 10611.609624248378, + 213.62210043960678 ], "single_client_get_object_containing_10k_refs": [ - 13.400670287257304, - 0.3357627983063862 + 12.987017792446045, + 0.03521768912488595 ], "single_client_put_calls_Plasma_Store": [ - 4953.332414166378, - 28.875351193249827 + 4866.041059585032, + 52.743326922210485 ], "single_client_put_gigabytes": [ - 17.025876066565033, - 10.326642986853164 + 18.521529437957106, + 8.62315762744488 ], "single_client_tasks_and_get_batch": [ - 8.246266578072571, - 0.522110999548341 + 7.56687497857023, + 0.3939166243754313 ], "single_client_tasks_async": [ - 7963.424588687801, - 461.55034927417944 + 8032.409007811969, + 398.4370160720034 ], "single_client_tasks_sync": [ - 1010.2085841645662, - 6.557428364155709 + 1013.1673399687909, + 11.015801941113976 ], "single_client_wait_1k_refs": [ - 5.5615049020254625, - 0.1315180945429515 + 5.424179804481537, + 0.07320253629955845 ] } diff --git a/release/perf_metrics/scalability/object_store.json b/release/perf_metrics/scalability/object_store.json index 5cba7a0a1ac9c..5b27eefe7eb9c 100644 --- a/release/perf_metrics/scalability/object_store.json +++ b/release/perf_metrics/scalability/object_store.json @@ -1,12 +1,12 @@ { - "broadcast_time": 16.12803921599999, + "broadcast_time": 14.082180108999992, "num_nodes": 50, "object_size": 1073741824, "perf_metrics": [ { "perf_metric_name": "time_to_broadcast_1073741824_bytes_to_50_nodes", "perf_metric_type": "LATENCY", - "perf_metric_value": 16.12803921599999 + "perf_metric_value": 14.082180108999992 } ], 
"success": "1" diff --git a/release/perf_metrics/scalability/single_node.json b/release/perf_metrics/scalability/single_node.json index 28c077d631980..86b5ab9f0f578 100644 --- a/release/perf_metrics/scalability/single_node.json +++ b/release/perf_metrics/scalability/single_node.json @@ -1,8 +1,8 @@ { - "args_time": 18.440583357999998, - "get_time": 24.405157770000002, + "args_time": 17.283816822999995, + "get_time": 23.877336194999998, "large_object_size": 107374182400, - "large_object_time": 33.19615447799998, + "large_object_time": 30.33785850800001, "num_args": 10000, "num_get_args": 10000, "num_queued": 1000000, @@ -11,30 +11,30 @@ { "perf_metric_name": "10000_args_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 18.440583357999998 + "perf_metric_value": 17.283816822999995 }, { "perf_metric_name": "3000_returns_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 5.834724514999991 + "perf_metric_value": 5.807459645999998 }, { "perf_metric_name": "10000_get_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 24.405157770000002 + "perf_metric_value": 23.877336194999998 }, { "perf_metric_name": "1000000_queued_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 186.31882492000003 + "perf_metric_value": 192.979472547 }, { "perf_metric_name": "107374182400_large_object_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 33.19615447799998 + "perf_metric_value": 30.33785850800001 } ], - "queued_time": 186.31882492000003, - "returns_time": 5.834724514999991, + "queued_time": 192.979472547, + "returns_time": 5.807459645999998, "success": "1" } diff --git a/release/perf_metrics/stress_tests/stress_test_dead_actors.json b/release/perf_metrics/stress_tests/stress_test_dead_actors.json index ea96e9c8d13ab..90ddacdd75455 100644 --- a/release/perf_metrics/stress_tests/stress_test_dead_actors.json +++ b/release/perf_metrics/stress_tests/stress_test_dead_actors.json @@ -1,14 +1,14 @@ { - "avg_iteration_time": 0.7329249548912048, - "max_iteration_time": 1.9340107440948486, - "min_iteration_time": 0.049217939376831055, + "avg_iteration_time": 0.7432218527793885, + "max_iteration_time": 3.088859796524048, + "min_iteration_time": 0.06358766555786133, "perf_metrics": [ { "perf_metric_name": "avg_iteration_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 0.7329249548912048 + "perf_metric_value": 0.7432218527793885 } ], "success": 1, - "total_time": 73.29262828826904 + "total_time": 74.32231259346008 } diff --git a/release/perf_metrics/stress_tests/stress_test_many_tasks.json b/release/perf_metrics/stress_tests/stress_test_many_tasks.json index 167f6db0feb15..fb2d89e7e3326 100644 --- a/release/perf_metrics/stress_tests/stress_test_many_tasks.json +++ b/release/perf_metrics/stress_tests/stress_test_many_tasks.json @@ -3,45 +3,45 @@ { "perf_metric_name": "stage_0_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 4.95395827293396 + "perf_metric_value": 4.678402662277222 }, { "perf_metric_name": "stage_1_avg_iteration_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 12.305748534202575 + "perf_metric_value": 12.529016280174256 }, { "perf_metric_name": "stage_2_avg_iteration_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 39.20367703437805 + "perf_metric_value": 39.98830094337463 }, { "perf_metric_name": "stage_3_creation_time", "perf_metric_type": "LATENCY", - "perf_metric_value": 1.1791002750396729 + "perf_metric_value": 1.1666333675384521 }, { "perf_metric_name": "stage_3_time", "perf_metric_type": "LATENCY", - 
"perf_metric_value": 1931.609727859497 + "perf_metric_value": 1823.5409452915192 }, { "perf_metric_name": "stage_4_spread", "perf_metric_type": "LATENCY", - "perf_metric_value": 0.2312385877525104 + "perf_metric_value": 0.20731175121781154 } ], - "stage_0_time": 4.95395827293396, - "stage_1_avg_iteration_time": 12.305748534202575, - "stage_1_max_iteration_time": 13.314747333526611, - "stage_1_min_iteration_time": 10.843807458877563, - "stage_1_time": 123.05753707885742, - "stage_2_avg_iteration_time": 39.20367703437805, - "stage_2_max_iteration_time": 39.455753803253174, - "stage_2_min_iteration_time": 39.00661754608154, - "stage_2_time": 196.01891422271729, - "stage_3_creation_time": 1.1791002750396729, - "stage_3_time": 1931.609727859497, - "stage_4_spread": 0.2312385877525104, + "stage_0_time": 4.678402662277222, + "stage_1_avg_iteration_time": 12.529016280174256, + "stage_1_max_iteration_time": 13.275079488754272, + "stage_1_min_iteration_time": 10.937284469604492, + "stage_1_time": 125.29022645950317, + "stage_2_avg_iteration_time": 39.98830094337463, + "stage_2_max_iteration_time": 40.26270246505737, + "stage_2_min_iteration_time": 39.52923345565796, + "stage_2_time": 199.94206047058105, + "stage_3_creation_time": 1.1666333675384521, + "stage_3_time": 1823.5409452915192, + "stage_4_spread": 0.20731175121781154, "success": 1 } diff --git a/release/perf_metrics/stress_tests/stress_test_placement_group.json b/release/perf_metrics/stress_tests/stress_test_placement_group.json index 65cc9b9a89bef..4408fb9e6e852 100644 --- a/release/perf_metrics/stress_tests/stress_test_placement_group.json +++ b/release/perf_metrics/stress_tests/stress_test_placement_group.json @@ -1,16 +1,16 @@ { - "avg_pg_create_time_ms": 1.5571090375378522, - "avg_pg_remove_time_ms": 1.248110743243383, + "avg_pg_create_time_ms": 1.470244198198109, + "avg_pg_remove_time_ms": 1.2978452747749973, "perf_metrics": [ { "perf_metric_name": "avg_pg_create_time_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 1.5571090375378522 + "perf_metric_value": 1.470244198198109 }, { "perf_metric_name": "avg_pg_remove_time_ms", "perf_metric_type": "LATENCY", - "perf_metric_value": 1.248110743243383 + "perf_metric_value": 1.2978452747749973 } ], "success": 1 diff --git a/release/ray_release/byod/requirements_byod_3.9.in b/release/ray_release/byod/requirements_byod_3.9.in index cca2b941f7eec..a60a833a81890 100644 --- a/release/ray_release/byod/requirements_byod_3.9.in +++ b/release/ray_release/byod/requirements_byod_3.9.in @@ -13,7 +13,7 @@ gsutil gymnasium gymnasium[atari] importlib-metadata -jsonschema==4.17.3 +jsonschema lightgbm locust==2.18.0 memray @@ -35,7 +35,7 @@ tensorflow trueskill tqdm typer -typing-extensions==4.8.0 +typing-extensions xarray xgboost zarr diff --git a/release/ray_release/byod/requirements_byod_3.9.txt b/release/ray_release/byod/requirements_byod_3.9.txt index f1dcf9ee13a84..4623ebe37fe49 100644 --- a/release/ray_release/byod/requirements_byod_3.9.txt +++ b/release/ray_release/byod/requirements_byod_3.9.txt @@ -178,13 +178,14 @@ async-timeout==4.0.3 \ # via # -c release/ray_release/byod/requirements_compiled.txt # aiohttp -attrs==21.4.0 \ - --hash=sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4 \ - --hash=sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd +attrs==25.1.0 \ + --hash=sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e \ + --hash=sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a # via # -c 
release/ray_release/byod/requirements_compiled.txt # aiohttp # jsonschema + # referencing bokeh==2.4.3 \ --hash=sha256:104d2f0a4ca7774ee4b11e545aa34ff76bf3e2ad6de0d33944361981b65da420 \ --hash=sha256:ef33801161af379665ab7a34684f2209861e3aefd5c803a21fbbb99d94874b03 @@ -1368,12 +1369,18 @@ joblib==1.2.0 \ # via # -c release/ray_release/byod/requirements_compiled.txt # scikit-learn -jsonschema==4.17.3 \ - --hash=sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d \ - --hash=sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6 +jsonschema==4.23.0 \ + --hash=sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4 \ + --hash=sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566 # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_byod_3.9.in +jsonschema-specifications==2024.10.1 \ + --hash=sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272 \ + --hash=sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema keras==2.15.0 \ --hash=sha256:2dcc6d2e30cf9c951064b63c1f4c404b966c59caf09e01f3549138ec8ee0dd1f \ --hash=sha256:81871d298c064dc4ac6b58440fdae67bfcf47c8d7ad28580fab401834c06a575 @@ -2092,115 +2099,103 @@ pycparser==2.21 \ # via # -c release/ray_release/byod/requirements_compiled.txt # cffi -pydantic==2.5.0 \ - --hash=sha256:69bd6fb62d2d04b7055f59a396993486a2ee586c43a0b89231ce0000de07627c \ - --hash=sha256:7ce6e766c456ad026fe5712f7bcf036efc34bd5d107b3e669ef7ea01b3a9050c +pydantic==2.9.2 \ + --hash=sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f \ + --hash=sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12 # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_byod_3.9.in # fastapi -pydantic-core==2.14.1 \ - --hash=sha256:023b6d7ec4e97890b28eb2ee24413e69a6d48de4e8b75123957edd5432f4eeb3 \ - --hash=sha256:052d8731aaf844f91fe4cd3faf28983b109a5865b3a256ec550b80a5689ead87 \ - --hash=sha256:0a8c8daf4e3aa3aeb98e3638fc3d58a359738f3d12590b2474c6bb64031a0764 \ - --hash=sha256:0d82a6ee815388a362885186e431fac84c7a06623bc136f508e9f88261d8cadb \ - --hash=sha256:101df420e954966868b8bc992aefed5fa71dd1f2755104da62ee247abab28e2f \ - --hash=sha256:102ac85a775e77821943ae38da9634ddd774b37a8d407181b4f7b05cdfb36b55 \ - --hash=sha256:1185548665bc61bbab0dc78f10c8eafa0db0aa1e920fe9a451b77782b10a65cc \ - --hash=sha256:12163197fec7c95751a3c71b36dcc1909eed9959f011ffc79cc8170a6a74c826 \ - --hash=sha256:130e49aa0cb316f743bc7792c36aefa39fc2221312f1d4b333b19edbdd71f2b1 \ - --hash=sha256:132b40e479cb5cebbbb681f77aaceabbc8355df16c9124cff1d4060ada83cde2 \ - --hash=sha256:144f2c1d5579108b6ed1193fcc9926124bd4142b0f7020a7744980d1235c8a40 \ - --hash=sha256:16f4a7e1ec6b3ea98a1e108a2739710cd659d68b33fbbeaba066202cab69c7b6 \ - --hash=sha256:184ff7b30c3f60e1b775378c060099285fd4b5249271046c9005f8b247b39377 \ - --hash=sha256:1bfb63821ada76719ffcd703fc40dd57962e0d8c253e3c565252e6de6d3e0bc6 \ - --hash=sha256:1e7208946ea9b27a8cef13822c339d4ae96e45952cc01fc4a91c7f1cb0ae2861 \ - --hash=sha256:217dcbfaf429a9b8f1d54eb380908b9c778e78f31378283b30ba463c21e89d5d \ - --hash=sha256:2459cc06572730e079ec1e694e8f68c99d977b40d98748ae72ff11ef21a56b0b \ - --hash=sha256:24ba48f9d0b8d64fc5e42e1600366c3d7db701201294989aebdaca23110c02ab \ - --hash=sha256:26242e3593d4929123615bd9365dd86ef79b7b0592d64a96cd11fd83c69c9f34 \ - 
--hash=sha256:2871daf5b2823bf77bf7d3d43825e5d904030c155affdf84b21a00a2e00821d2 \ - --hash=sha256:28734bcfb8fc5b03293dec5eb5ea73b32ff767f6ef79a31f6e41dad2f5470270 \ - --hash=sha256:2a7d08b39fac97540fba785fce3b21ee01a81f081a07a4d031efd791da6666f9 \ - --hash=sha256:2be018a84995b6be1bbd40d6064395dbf71592a981169cf154c0885637f5f54a \ - --hash=sha256:3303113fdfaca927ef11e0c5f109e2ec196c404f9d7ba5f8ddb63cdf287ea159 \ - --hash=sha256:36c3bf96f803e207a80dbcb633d82b98ff02a9faa76dd446e969424dec8e2b9f \ - --hash=sha256:3d5b2a4b3c10cad0615670cab99059441ff42e92cf793a0336f4bc611e895204 \ - --hash=sha256:3f48d4afd973abbd65266ac24b24de1591116880efc7729caf6b6b94a9654c9e \ - --hash=sha256:42d5d0e9bbb50481a049bd0203224b339d4db04006b78564df2b782e2fd16ebc \ - --hash=sha256:443dc5eede7fa76b2370213e0abe881eb17c96f7d694501853c11d5d56916602 \ - --hash=sha256:49ee28d65f506b2858a60745cc974ed005298ebab12693646b97641dd7c99c35 \ - --hash=sha256:4f0788699a92d604f348e9c1ac5e97e304e97127ba8325c7d0af88dcc7d35bd3 \ - --hash=sha256:51506e7652a2ef1d1cf763c4b51b972ff4568d1dddc96ca83931a6941f5e6389 \ - --hash=sha256:53efe03cc383a83660cfdda6a3cb40ee31372cedea0fde0b2a2e55e838873ab6 \ - --hash=sha256:55713d155da1e508083c4b08d0b1ad2c3054f68b8ef7eb3d3864822e456f0bb5 \ - --hash=sha256:581bb606a31749a00796f5257947a0968182d7fe91e1dada41f06aeb6bfbc91a \ - --hash=sha256:5879ac4791508d8f0eb7dec71ff8521855180688dac0c55f8c99fc4d1a939845 \ - --hash=sha256:587d75aec9ae50d0d63788cec38bf13c5128b3fc1411aa4b9398ebac884ab179 \ - --hash=sha256:59fa83873223f856d898452c6162a390af4297756f6ba38493a67533387d85d9 \ - --hash=sha256:5a1570875eb0d1479fb2270ed80c88c231aaaf68b0c3f114f35e7fb610435e4f \ - --hash=sha256:5b45b7be9f99991405ecd6f6172fb6798908a8097106ae78d5cc5cc15121bad9 \ - --hash=sha256:6015beb28deb5306049ecf2519a59627e9e050892927850a884df6d5672f8c7d \ - --hash=sha256:6590ed9d13eb51b28ea17ddcc6c8dbd6050b4eb589d497105f0e13339f223b72 \ - --hash=sha256:66dc0e63349ec39c1ea66622aa5c2c1f84382112afd3ab2fa0cca4fb01f7db39 \ - --hash=sha256:679cc4e184f213c8227862e57340d12fd4d4d19dc0e3ddb0f653f86f01e90f94 \ - --hash=sha256:69cd74e55a5326d920e7b46daa2d81c2bdb8bcf588eafb2330d981297b742ddc \ - --hash=sha256:69df82892ff00491d673b1929538efb8c8d68f534fdc6cb7fd3ac8a5852b9034 \ - --hash=sha256:72c2ef3787c3b577e5d6225d73a77167b942d12cef3c1fbd5e74e55b7f881c36 \ - --hash=sha256:744b807fe2733b6da3b53e8ad93e8b3ea3ee3dfc3abece4dd2824cc1f39aa343 \ - --hash=sha256:7977e261cac5f99873dc2c6f044315d09b19a71c4246560e1e67593889a90978 \ - --hash=sha256:798590d38c9381f07c48d13af1f1ef337cebf76ee452fcec5deb04aceced51c7 \ - --hash=sha256:812beca1dcb2b722cccc7e9c620bd972cbc323321194ec2725eab3222e6ac573 \ - --hash=sha256:8276bbab68a9dbe721da92d19cbc061f76655248fe24fb63969d0c3e0e5755e7 \ - --hash=sha256:85bb66d661be51b2cba9ca06759264b3469d2dbb53c3e6effb3f05fec6322be6 \ - --hash=sha256:871c641a83719caaa856a11dcc61c5e5b35b0db888e1a0d338fe67ce744575e2 \ - --hash=sha256:893bf4fb9bfb9c4639bc12f3de323325ada4c6d60e478d5cded65453e9364890 \ - --hash=sha256:8d927d042c0ef04607ee7822828b208ab045867d20477ec6593d612156798547 \ - --hash=sha256:8e17f0c3ba4cb07faa0038a59ce162de584ed48ba645c8d05a5de1e40d4c21e7 \ - --hash=sha256:9486e27bb3f137f33e2315be2baa0b0b983dae9e2f5f5395240178ad8e644728 \ - --hash=sha256:94cf6d0274eb899d39189144dcf52814c67f9b0fd196f211420d9aac793df2da \ - --hash=sha256:97246f896b4df7fd84caa8a75a67abb95f94bc0b547665bf0889e3262b060399 \ - --hash=sha256:9d59e0d7cdfe8ed1d4fcd28aad09625c715dc18976c7067e37d8a11b06f4be3e \ - 
--hash=sha256:a15f6e5588f7afb7f6fc4b0f4ff064749e515d34f34c666ed6e37933873d8ad8 \ - --hash=sha256:a2ccdc53cb88e51c7d47d74c59630d7be844428f6b8d463055ffad6f0392d8da \ - --hash=sha256:a68a36d71c7f638dda6c9e6b67f6aabf3fa1471b198d246457bfdc7c777cdeb7 \ - --hash=sha256:a7991f25b98038252363a03e6a9fe92e60fe390fda2631d238dc3b0e396632f8 \ - --hash=sha256:aadf74a40a7ae49c3c1aa7d32334fe94f4f968e21dd948e301bb4ed431fb2412 \ - --hash=sha256:abae6fd5504e5e438e4f6f739f8364fd9ff5a5cdca897e68363e2318af90bc28 \ - --hash=sha256:ac417312bf6b7a0223ba73fb12e26b2854c93bf5b1911f7afef6d24c379b22aa \ - --hash=sha256:ad9ea86f5fc50f1b62c31184767fe0cacaa13b54fe57d38898c3776d30602411 \ - --hash=sha256:b4ff385a525017f5adf6066d7f9fb309f99ade725dcf17ed623dc7dce1f85d9f \ - --hash=sha256:b89821a2c77cc1b8f2c1fc3aacd6a3ecc5df8f7e518dc3f18aef8c4dcf66003d \ - --hash=sha256:b8ff0302518dcd001bd722bbe342919c29e5066c7eda86828fe08cdc112668b8 \ - --hash=sha256:b91b5ec423e88caa16777094c4b2b97f11453283e7a837e5e5e1b886abba1251 \ - --hash=sha256:ba55d73a2df4771b211d0bcdea8b79454980a81ed34a1d77a19ddcc81f98c895 \ - --hash=sha256:bb1c6ecb53e4b907ee8486f453dd940b8cbb509946e2b671e3bf807d310a96fc \ - --hash=sha256:bc6a4ea9f88a810cb65ccae14404da846e2a02dd5c0ad21dee712ff69d142638 \ - --hash=sha256:c36987f5eb2a7856b5f5feacc3be206b4d1852a6ce799f6799dd9ffb0cba56ae \ - --hash=sha256:c6e98227eb02623d57e1fd061788837834b68bb995a869565211b9abf3de4bf4 \ - --hash=sha256:c7411cd06afeb263182e38c6ca5b4f5fe4f20d91466ad7db0cd6af453a02edec \ - --hash=sha256:c8c466facec2ccdf025b0b1455b18f2c3d574d5f64d24df905d3d7b8f05d5f4e \ - --hash=sha256:c964c0cc443d6c08a2347c0e5c1fc2d85a272dc66c1a6f3cde4fc4843882ada4 \ - --hash=sha256:ca942a2dc066ca5e04c27feaa8dfb9d353ddad14c6641660c565149186095343 \ - --hash=sha256:cb2fd3ab67558eb16aecfb4f2db4febb4d37dc74e6b8613dc2e7160fb58158a9 \ - --hash=sha256:d312ad20e3c6d179cb97c42232b53111bcd8dcdd5c1136083db9d6bdd489bc73 \ - --hash=sha256:d965bdb50725a805b083f5f58d05669a85705f50a6a864e31b545c589290ee31 \ - --hash=sha256:d983222223f63e323a5f497f5b85e211557a5d8fb670dc88f343784502b466ba \ - --hash=sha256:dee4682bd7947afc682d342a8d65ad1834583132383f8e801601a8698cb8d17a \ - --hash=sha256:e2be646a5155d408e68b560c0553e8a83dc7b9f90ec6e5a2fc3ff216719385db \ - --hash=sha256:e2c689439f262c29cf3fcd5364da1e64d8600facecf9eabea8643b8755d2f0de \ - --hash=sha256:e5a111f9158555582deadd202a60bd7803b6c68f406391b7cf6905adf0af6811 \ - --hash=sha256:e905014815687d88cbb14bbc0496420526cf20d49f20606537d87646b70f1046 \ - --hash=sha256:ebc79120e105e4bcd7865f369e3b9dbabb0d492d221e1a7f62a3e8e292550278 \ - --hash=sha256:f1a30eef060e21af22c7d23349f1028de0611f522941c80efa51c05a63142c62 \ - --hash=sha256:f483467c046f549572f8aca3b7128829e09ae3a9fe933ea421f7cb7c58120edb \ - --hash=sha256:f523e116879bc6714e61d447ce934676473b068069dce6563ea040381dc7a257 \ - --hash=sha256:f53a3ccdc30234cb4342cec541e3e6ed87799c7ca552f0b5f44e3967a5fed526 \ - --hash=sha256:fb290491f1f0786a7da4585250f1feee200fc17ff64855bdd7c42fb54526fa29 \ - --hash=sha256:fc3227408808ba7df8e95eb1d8389f4ba2203bed8240b308de1d7ae66d828f24 \ - --hash=sha256:fd80a2d383940eec3db6a5b59d1820f947317acc5c75482ff8d79bf700f8ad6a \ - --hash=sha256:fd937733bf2fe7d6a8bf208c12741f1f730b7bf5636033877767a75093c29b8a \ - --hash=sha256:ffba979801e3931a19cd30ed2049450820effe8f152aaa317e2fd93795d318d7 +pydantic-core==2.23.4 \ + --hash=sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36 \ + --hash=sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05 \ + 
--hash=sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071 \ + --hash=sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327 \ + --hash=sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c \ + --hash=sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36 \ + --hash=sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29 \ + --hash=sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744 \ + --hash=sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d \ + --hash=sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec \ + --hash=sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e \ + --hash=sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e \ + --hash=sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577 \ + --hash=sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232 \ + --hash=sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863 \ + --hash=sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6 \ + --hash=sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368 \ + --hash=sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480 \ + --hash=sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2 \ + --hash=sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2 \ + --hash=sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6 \ + --hash=sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769 \ + --hash=sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d \ + --hash=sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2 \ + --hash=sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84 \ + --hash=sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166 \ + --hash=sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271 \ + --hash=sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5 \ + --hash=sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb \ + --hash=sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13 \ + --hash=sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323 \ + --hash=sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556 \ + --hash=sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665 \ + --hash=sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef \ + --hash=sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb \ + --hash=sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119 \ + --hash=sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126 \ + --hash=sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510 \ + --hash=sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b \ + --hash=sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87 \ + --hash=sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f \ + --hash=sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc \ + --hash=sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8 \ + --hash=sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21 \ + 
--hash=sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f \ + --hash=sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6 \ + --hash=sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658 \ + --hash=sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b \ + --hash=sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3 \ + --hash=sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb \ + --hash=sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59 \ + --hash=sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24 \ + --hash=sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9 \ + --hash=sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3 \ + --hash=sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd \ + --hash=sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753 \ + --hash=sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55 \ + --hash=sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad \ + --hash=sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a \ + --hash=sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605 \ + --hash=sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e \ + --hash=sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b \ + --hash=sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433 \ + --hash=sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8 \ + --hash=sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07 \ + --hash=sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728 \ + --hash=sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0 \ + --hash=sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327 \ + --hash=sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555 \ + --hash=sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64 \ + --hash=sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6 \ + --hash=sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea \ + --hash=sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b \ + --hash=sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df \ + --hash=sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e \ + --hash=sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd \ + --hash=sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068 \ + --hash=sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3 \ + --hash=sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040 \ + --hash=sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12 \ + --hash=sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916 \ + --hash=sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f \ + --hash=sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f \ + --hash=sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801 \ + --hash=sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231 \ + --hash=sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5 \ + 
--hash=sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8 \ + --hash=sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee \ + --hash=sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607 # via # -c release/ray_release/byod/requirements_compiled.txt # pydantic @@ -2223,42 +2218,6 @@ pyparsing==3.1.1 \ # via # -c release/ray_release/byod/requirements_compiled.txt # httplib2 -pyrsistent==0.20.0 \ - --hash=sha256:0724c506cd8b63c69c7f883cc233aac948c1ea946ea95996ad8b1380c25e1d3f \ - --hash=sha256:09848306523a3aba463c4b49493a760e7a6ca52e4826aa100ee99d8d39b7ad1e \ - --hash=sha256:0f3b1bcaa1f0629c978b355a7c37acd58907390149b7311b5db1b37648eb6958 \ - --hash=sha256:21cc459636983764e692b9eba7144cdd54fdec23ccdb1e8ba392a63666c60c34 \ - --hash=sha256:2e14c95c16211d166f59c6611533d0dacce2e25de0f76e4c140fde250997b3ca \ - --hash=sha256:2e2c116cc804d9b09ce9814d17df5edf1df0c624aba3b43bc1ad90411487036d \ - --hash=sha256:4021a7f963d88ccd15b523787d18ed5e5269ce57aa4037146a2377ff607ae87d \ - --hash=sha256:4c48f78f62ab596c679086084d0dd13254ae4f3d6c72a83ffdf5ebdef8f265a4 \ - --hash=sha256:4f5c2d012671b7391803263419e31b5c7c21e7c95c8760d7fc35602353dee714 \ - --hash=sha256:58b8f6366e152092194ae68fefe18b9f0b4f89227dfd86a07770c3d86097aebf \ - --hash=sha256:59a89bccd615551391f3237e00006a26bcf98a4d18623a19909a2c48b8e986ee \ - --hash=sha256:5cdd7ef1ea7a491ae70d826b6cc64868de09a1d5ff9ef8d574250d0940e275b8 \ - --hash=sha256:6288b3fa6622ad8a91e6eb759cfc48ff3089e7c17fb1d4c59a919769314af224 \ - --hash=sha256:6d270ec9dd33cdb13f4d62c95c1a5a50e6b7cdd86302b494217137f760495b9d \ - --hash=sha256:79ed12ba79935adaac1664fd7e0e585a22caa539dfc9b7c7c6d5ebf91fb89054 \ - --hash=sha256:7d29c23bdf6e5438c755b941cef867ec2a4a172ceb9f50553b6ed70d50dfd656 \ - --hash=sha256:8441cf9616d642c475684d6cf2520dd24812e996ba9af15e606df5f6fd9d04a7 \ - --hash=sha256:881bbea27bbd32d37eb24dd320a5e745a2a5b092a17f6debc1349252fac85423 \ - --hash=sha256:8c3aba3e01235221e5b229a6c05f585f344734bd1ad42a8ac51493d74722bbce \ - --hash=sha256:a14798c3005ec892bbada26485c2eea3b54109cb2533713e355c806891f63c5e \ - --hash=sha256:b14decb628fac50db5e02ee5a35a9c0772d20277824cfe845c8a8b717c15daa3 \ - --hash=sha256:b318ca24db0f0518630e8b6f3831e9cba78f099ed5c1d65ffe3e023003043ba0 \ - --hash=sha256:c1beb78af5423b879edaf23c5591ff292cf7c33979734c99aa66d5914ead880f \ - --hash=sha256:c55acc4733aad6560a7f5f818466631f07efc001fd023f34a6c203f8b6df0f0b \ - --hash=sha256:ca52d1ceae015859d16aded12584c59eb3825f7b50c6cfd621d4231a6cc624ce \ - --hash=sha256:cae40a9e3ce178415040a0383f00e8d68b569e97f31928a3a8ad37e3fde6df6a \ - --hash=sha256:e78d0c7c1e99a4a45c99143900ea0546025e41bb59ebc10182e947cf1ece9174 \ - --hash=sha256:ef3992833fbd686ee783590639f4b8343a57f1f75de8633749d984dc0eb16c86 \ - --hash=sha256:f058a615031eea4ef94ead6456f5ec2026c19fb5bd6bfe86e9665c4158cf802f \ - --hash=sha256:f5ac696f02b3fc01a710427585c855f65cd9c640e14f52abe52020722bb4906b \ - --hash=sha256:f920385a11207dc372a028b3f1e1038bb244b3ec38d448e6d8e43c6b3ba20e98 \ - --hash=sha256:fed2c3216a605dc9a6ea50c7e84c82906e3684c4e80d2908208f662a6cbf9022 - # via - # -c release/ray_release/byod/requirements_compiled.txt - # jsonschema pyspark==3.4.1 \ --hash=sha256:72cd66ab8cf61a75854e5a753f75bea35ee075c3a96f9de4e2a66d02ec7fc652 # via @@ -2437,6 +2396,13 @@ pyzmq==26.0.3 \ # -c release/ray_release/byod/requirements_compiled.txt # locust # petastorm +referencing==0.36.2 \ + --hash=sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa \ + 
--hash=sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0 + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema + # jsonschema-specifications requests==2.31.0 \ --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 @@ -2472,6 +2438,114 @@ rich==13.3.2 \ roundrobin==0.0.4 \ --hash=sha256:7e9d19a5bd6123d99993fb935fa86d25c88bb2096e493885f61737ed0f5e9abd # via locust +rpds-py==0.22.3 \ + --hash=sha256:009de23c9c9ee54bf11303a966edf4d9087cd43a6003672e6aa7def643d06518 \ + --hash=sha256:02fbb9c288ae08bcb34fb41d516d5eeb0455ac35b5512d03181d755d80810059 \ + --hash=sha256:0a0461200769ab3b9ab7e513f6013b7a97fdeee41c29b9db343f3c5a8e2b9e61 \ + --hash=sha256:0b09865a9abc0ddff4e50b5ef65467cd94176bf1e0004184eb915cbc10fc05c5 \ + --hash=sha256:0b8db6b5b2d4491ad5b6bdc2bc7c017eec108acbf4e6785f42a9eb0ba234f4c9 \ + --hash=sha256:0c150c7a61ed4a4f4955a96626574e9baf1adf772c2fb61ef6a5027e52803543 \ + --hash=sha256:0f3cec041684de9a4684b1572fe28c7267410e02450f4561700ca5a3bc6695a2 \ + --hash=sha256:1352ae4f7c717ae8cba93421a63373e582d19d55d2ee2cbb184344c82d2ae55a \ + --hash=sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d \ + --hash=sha256:1978d0021e943aae58b9b0b196fb4895a25cc53d3956b8e35e0b7682eefb6d56 \ + --hash=sha256:1a60bce91f81ddaac922a40bbb571a12c1070cb20ebd6d49c48e0b101d87300d \ + --hash=sha256:1aef18820ef3e4587ebe8b3bc9ba6e55892a6d7b93bac6d29d9f631a3b4befbd \ + --hash=sha256:1e9663daaf7a63ceccbbb8e3808fe90415b0757e2abddbfc2e06c857bf8c5e2b \ + --hash=sha256:20070c65396f7373f5df4005862fa162db5d25d56150bddd0b3e8214e8ef45b4 \ + --hash=sha256:214b7a953d73b5e87f0ebece4a32a5bd83c60a3ecc9d4ec8f1dca968a2d91e99 \ + --hash=sha256:22bebe05a9ffc70ebfa127efbc429bc26ec9e9b4ee4d15a740033efda515cf3d \ + --hash=sha256:24e8abb5878e250f2eb0d7859a8e561846f98910326d06c0d51381fed59357bd \ + --hash=sha256:26fd7cac7dd51011a245f29a2cc6489c4608b5a8ce8d75661bb4a1066c52dfbe \ + --hash=sha256:27b1d3b3915a99208fee9ab092b8184c420f2905b7d7feb4aeb5e4a9c509b8a1 \ + --hash=sha256:27e98004595899949bd7a7b34e91fa7c44d7a97c40fcaf1d874168bb652ec67e \ + --hash=sha256:2b8f60e1b739a74bab7e01fcbe3dddd4657ec685caa04681df9d562ef15b625f \ + --hash=sha256:2de29005e11637e7a2361fa151f780ff8eb2543a0da1413bb951e9f14b699ef3 \ + --hash=sha256:2e8b55d8517a2fda8d95cb45d62a5a8bbf9dd0ad39c5b25c8833efea07b880ca \ + --hash=sha256:2fa4331c200c2521512595253f5bb70858b90f750d39b8cbfd67465f8d1b596d \ + --hash=sha256:3445e07bf2e8ecfeef6ef67ac83de670358abf2996916039b16a218e3d95e97e \ + --hash=sha256:3453e8d41fe5f17d1f8e9c383a7473cd46a63661628ec58e07777c2fff7196dc \ + --hash=sha256:378753b4a4de2a7b34063d6f95ae81bfa7b15f2c1a04a9518e8644e81807ebea \ + --hash=sha256:3af6e48651c4e0d2d166dc1b033b7042ea3f871504b6805ba5f4fe31581d8d38 \ + --hash=sha256:3dfcbc95bd7992b16f3f7ba05af8a64ca694331bd24f9157b49dadeeb287493b \ + --hash=sha256:3f21f0495edea7fdbaaa87e633a8689cd285f8f4af5c869f27bc8074638ad69c \ + --hash=sha256:4041711832360a9b75cfb11b25a6a97c8fb49c07b8bd43d0d02b45d0b499a4ff \ + --hash=sha256:44d61b4b7d0c2c9ac019c314e52d7cbda0ae31078aabd0f22e583af3e0d79723 \ + --hash=sha256:4617e1915a539a0d9a9567795023de41a87106522ff83fbfaf1f6baf8e85437e \ + --hash=sha256:4b232061ca880db21fa14defe219840ad9b74b6158adb52ddf0e87bead9e8493 \ + --hash=sha256:5246b14ca64a8675e0a7161f7af68fe3e910e6b90542b4bfb5439ba752191df6 \ + --hash=sha256:5725dd9cc02068996d4438d397e255dcb1df776b7ceea3b9cb972bdb11260a83 \ + 
--hash=sha256:583f6a1993ca3369e0f80ba99d796d8e6b1a3a2a442dd4e1a79e652116413091 \ + --hash=sha256:59259dc58e57b10e7e18ce02c311804c10c5a793e6568f8af4dead03264584d1 \ + --hash=sha256:593eba61ba0c3baae5bc9be2f5232430453fb4432048de28399ca7376de9c627 \ + --hash=sha256:59f4a79c19232a5774aee369a0c296712ad0e77f24e62cad53160312b1c1eaa1 \ + --hash=sha256:5f0e260eaf54380380ac3808aa4ebe2d8ca28b9087cf411649f96bad6900c728 \ + --hash=sha256:62d9cfcf4948683a18a9aff0ab7e1474d407b7bab2ca03116109f8464698ab16 \ + --hash=sha256:64607d4cbf1b7e3c3c8a14948b99345eda0e161b852e122c6bb71aab6d1d798c \ + --hash=sha256:655ca44a831ecb238d124e0402d98f6212ac527a0ba6c55ca26f616604e60a45 \ + --hash=sha256:666ecce376999bf619756a24ce15bb14c5bfaf04bf00abc7e663ce17c3f34fe7 \ + --hash=sha256:68049202f67380ff9aa52f12e92b1c30115f32e6895cd7198fa2a7961621fc5a \ + --hash=sha256:69803198097467ee7282750acb507fba35ca22cc3b85f16cf45fb01cb9097730 \ + --hash=sha256:6c7b99ca52c2c1752b544e310101b98a659b720b21db00e65edca34483259967 \ + --hash=sha256:6dd9412824c4ce1aca56c47b0991e65bebb7ac3f4edccfd3f156150c96a7bf25 \ + --hash=sha256:70eb60b3ae9245ddea20f8a4190bd79c705a22f8028aaf8bbdebe4716c3fab24 \ + --hash=sha256:70fb28128acbfd264eda9bf47015537ba3fe86e40d046eb2963d75024be4d055 \ + --hash=sha256:7b2513ba235829860b13faa931f3b6846548021846ac808455301c23a101689d \ + --hash=sha256:7ef9d9da710be50ff6809fed8f1963fecdfecc8b86656cadfca3bc24289414b0 \ + --hash=sha256:81e69b0a0e2537f26d73b4e43ad7bc8c8efb39621639b4434b76a3de50c6966e \ + --hash=sha256:8633e471c6207a039eff6aa116e35f69f3156b3989ea3e2d755f7bc41754a4a7 \ + --hash=sha256:8bd7c8cfc0b8247c8799080fbff54e0b9619e17cdfeb0478ba7295d43f635d7c \ + --hash=sha256:9253fc214112405f0afa7db88739294295f0e08466987f1d70e29930262b4c8f \ + --hash=sha256:99b37292234e61325e7a5bb9689e55e48c3f5f603af88b1642666277a81f1fbd \ + --hash=sha256:9bd7228827ec7bb817089e2eb301d907c0d9827a9e558f22f762bb690b131652 \ + --hash=sha256:9beeb01d8c190d7581a4d59522cd3d4b6887040dcfc744af99aa59fef3e041a8 \ + --hash=sha256:a63cbdd98acef6570c62b92a1e43266f9e8b21e699c363c0fef13bd530799c11 \ + --hash=sha256:a76e42402542b1fae59798fab64432b2d015ab9d0c8c47ba7addddbaf7952333 \ + --hash=sha256:ac0a03221cdb5058ce0167ecc92a8c89e8d0decdc9e99a2ec23380793c4dcb96 \ + --hash=sha256:b0b4136a252cadfa1adb705bb81524eee47d9f6aab4f2ee4fa1e9d3cd4581f64 \ + --hash=sha256:b25bc607423935079e05619d7de556c91fb6adeae9d5f80868dde3468657994b \ + --hash=sha256:b3d504047aba448d70cf6fa22e06cb09f7cbd761939fdd47604f5e007675c24e \ + --hash=sha256:bb47271f60660803ad11f4c61b42242b8c1312a31c98c578f79ef9387bbde21c \ + --hash=sha256:bbb232860e3d03d544bc03ac57855cd82ddf19c7a07651a7c0fdb95e9efea8b9 \ + --hash=sha256:bc27863442d388870c1809a87507727b799c8460573cfbb6dc0eeaef5a11b5ec \ + --hash=sha256:bc51abd01f08117283c5ebf64844a35144a0843ff7b2983e0648e4d3d9f10dbb \ + --hash=sha256:be2eb3f2495ba669d2a985f9b426c1797b7d48d6963899276d22f23e33d47e37 \ + --hash=sha256:bf9db5488121b596dbfc6718c76092fda77b703c1f7533a226a5a9f65248f8ad \ + --hash=sha256:c58e2339def52ef6b71b8f36d13c3688ea23fa093353f3a4fee2556e62086ec9 \ + --hash=sha256:cfbc454a2880389dbb9b5b398e50d439e2e58669160f27b60e5eca11f68ae17c \ + --hash=sha256:cff63a0272fcd259dcc3be1657b07c929c466b067ceb1c20060e8d10af56f5bf \ + --hash=sha256:d115bffdd417c6d806ea9069237a4ae02f513b778e3789a359bc5856e0404cc4 \ + --hash=sha256:d20cfb4e099748ea39e6f7b16c91ab057989712d31761d3300d43134e26e165f \ + --hash=sha256:d48424e39c2611ee1b84ad0f44fb3b2b53d473e65de061e3f460fc0be5f1939d \ + 
--hash=sha256:e0fa2d4ec53dc51cf7d3bb22e0aa0143966119f42a0c3e4998293a3dd2856b09 \ + --hash=sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d \ + --hash=sha256:e35ba67d65d49080e8e5a1dd40101fccdd9798adb9b050ff670b7d74fa41c566 \ + --hash=sha256:e3fb866d9932a3d7d0c82da76d816996d1667c44891bd861a0f97ba27e84fc74 \ + --hash=sha256:e61b02c3f7a1e0b75e20c3978f7135fd13cb6cf551bf4a6d29b999a88830a338 \ + --hash=sha256:e67ba3c290821343c192f7eae1d8fd5999ca2dc99994114643e2f2d3e6138b15 \ + --hash=sha256:e79dd39f1e8c3504be0607e5fc6e86bb60fe3584bec8b782578c3b0fde8d932c \ + --hash=sha256:e89391e6d60251560f0a8f4bd32137b077a80d9b7dbe6d5cab1cd80d2746f648 \ + --hash=sha256:ea7433ce7e4bfc3a85654aeb6747babe3f66eaf9a1d0c1e7a4435bbdf27fea84 \ + --hash=sha256:eaf16ae9ae519a0e237a0f528fd9f0197b9bb70f40263ee57ae53c2b8d48aeb3 \ + --hash=sha256:eb0c341fa71df5a4595f9501df4ac5abfb5a09580081dffbd1ddd4654e6e9123 \ + --hash=sha256:f276b245347e6e36526cbd4a266a417796fc531ddf391e43574cf6466c492520 \ + --hash=sha256:f47ad3d5f3258bd7058d2d506852217865afefe6153a36eb4b6928758041d831 \ + --hash=sha256:f56a6b404f74ab372da986d240e2e002769a7d7102cc73eb238a4f72eec5284e \ + --hash=sha256:f5cf2a0c2bdadf3791b5c205d55a37a54025c6e18a71c71f82bb536cf9a454bf \ + --hash=sha256:f5d36399a1b96e1a5fdc91e0522544580dbebeb1f77f27b2b0ab25559e103b8b \ + --hash=sha256:f60bd8423be1d9d833f230fdbccf8f57af322d96bcad6599e5a771b151398eb2 \ + --hash=sha256:f612463ac081803f243ff13cccc648578e2279295048f2a8d5eb430af2bae6e3 \ + --hash=sha256:f73d3fef726b3243a811121de45193c0ca75f6407fe66f3f4e183c983573e130 \ + --hash=sha256:f82a116a1d03628a8ace4859556fb39fd1424c933341a08ea3ed6de1edb0283b \ + --hash=sha256:fb0ba113b4983beac1a2eb16faffd76cb41e176bf58c4afe3e14b9c681f702de \ + --hash=sha256:fb4f868f712b2dd4bcc538b0a0c1f63a2b1d584c925e69a224d759e7070a12d5 \ + --hash=sha256:fb6116dfb8d1925cbdb52595560584db42a7f664617a1f7d7f6e32f138cdf37d \ + --hash=sha256:fda7cb070f442bf80b642cd56483b5548e43d366fe3f39b98e67cce780cded00 \ + --hash=sha256:feea821ee2a9273771bae61194004ee2fc33f8ec7db08117ef9147d4bbcbca8e + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema + # referencing rsa==4.7.2 \ --hash=sha256:78f9a9bf4e7be0c5ded4583326e7461e3a3c5aae24073648b4bdfa797d78c9d2 \ --hash=sha256:9d689e6ca1b3038bc82bf8d23e944b6b6037bc02301a574935b2dd946e0353b9 @@ -2767,9 +2841,9 @@ typer==0.12.3 \ # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_byod_3.9.in -typing-extensions==4.8.0 \ - --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ - --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef +typing-extensions==4.12.2 \ + --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ + --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_byod_3.9.in @@ -2780,6 +2854,7 @@ typing-extensions==4.8.0 \ # gymnasium # pydantic # pydantic-core + # referencing # starlette # tensorflow # typer diff --git a/release/ray_release/byod/requirements_debian_byod.txt b/release/ray_release/byod/requirements_debian_byod.txt index a68169fa51ea4..6ef99fc4166fa 100644 --- a/release/ray_release/byod/requirements_debian_byod.txt +++ b/release/ray_release/byod/requirements_debian_byod.txt @@ -11,4 +11,4 @@ libosmesa6-dev patchelf unzip zip -libaio1 \ No newline at end of file +libaio1 diff --git 
a/release/ray_release/byod/requirements_ml_byod_3.9.txt b/release/ray_release/byod/requirements_ml_byod_3.9.txt index cea51f87c9501..4cf5ff6f7bbe5 100644 --- a/release/ray_release/byod/requirements_ml_byod_3.9.txt +++ b/release/ray_release/byod/requirements_ml_byod_3.9.txt @@ -160,14 +160,15 @@ async-timeout==4.0.3 \ # via # -c release/ray_release/byod/requirements_compiled.txt # aiohttp -attrs==21.4.0 \ - --hash=sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4 \ - --hash=sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd +attrs==25.1.0 \ + --hash=sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e \ + --hash=sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a # via # -c release/ray_release/byod/requirements_compiled.txt # aiohttp # jsonlines # jsonschema + # referencing backcall==0.2.0 \ --hash=sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e \ --hash=sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255 @@ -777,9 +778,9 @@ fastjsonschema==2.19.0 \ # via # -c release/ray_release/byod/requirements_compiled.txt # nbformat -filelock==3.13.1 \ - --hash=sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e \ - --hash=sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c +filelock==3.17.0 \ + --hash=sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338 \ + --hash=sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_ml_byod_3.9.in @@ -1414,12 +1415,18 @@ jsonlines==4.0.0 \ --hash=sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74 \ --hash=sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55 # via lm-eval -jsonschema==4.17.3 \ - --hash=sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d \ - --hash=sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6 +jsonschema==4.23.0 \ + --hash=sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4 \ + --hash=sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566 # via # -c release/ray_release/byod/requirements_compiled.txt # nbformat +jsonschema-specifications==2024.10.1 \ + --hash=sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272 \ + --hash=sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema jupyter-core==5.5.0 \ --hash=sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3 \ --hash=sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805 @@ -2614,116 +2621,104 @@ pycparser==2.21 \ # via # -c release/ray_release/byod/requirements_compiled.txt # cffi -pydantic==2.5.0 \ - --hash=sha256:69bd6fb62d2d04b7055f59a396993486a2ee586c43a0b89231ce0000de07627c \ - --hash=sha256:7ce6e766c456ad026fe5712f7bcf036efc34bd5d107b3e669ef7ea01b3a9050c +pydantic==2.9.2 \ + --hash=sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f \ + --hash=sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12 # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_ml_byod_3.9.in # deepspeed # fastapi -pydantic-core==2.14.1 \ - --hash=sha256:023b6d7ec4e97890b28eb2ee24413e69a6d48de4e8b75123957edd5432f4eeb3 \ - 
--hash=sha256:052d8731aaf844f91fe4cd3faf28983b109a5865b3a256ec550b80a5689ead87 \ - --hash=sha256:0a8c8daf4e3aa3aeb98e3638fc3d58a359738f3d12590b2474c6bb64031a0764 \ - --hash=sha256:0d82a6ee815388a362885186e431fac84c7a06623bc136f508e9f88261d8cadb \ - --hash=sha256:101df420e954966868b8bc992aefed5fa71dd1f2755104da62ee247abab28e2f \ - --hash=sha256:102ac85a775e77821943ae38da9634ddd774b37a8d407181b4f7b05cdfb36b55 \ - --hash=sha256:1185548665bc61bbab0dc78f10c8eafa0db0aa1e920fe9a451b77782b10a65cc \ - --hash=sha256:12163197fec7c95751a3c71b36dcc1909eed9959f011ffc79cc8170a6a74c826 \ - --hash=sha256:130e49aa0cb316f743bc7792c36aefa39fc2221312f1d4b333b19edbdd71f2b1 \ - --hash=sha256:132b40e479cb5cebbbb681f77aaceabbc8355df16c9124cff1d4060ada83cde2 \ - --hash=sha256:144f2c1d5579108b6ed1193fcc9926124bd4142b0f7020a7744980d1235c8a40 \ - --hash=sha256:16f4a7e1ec6b3ea98a1e108a2739710cd659d68b33fbbeaba066202cab69c7b6 \ - --hash=sha256:184ff7b30c3f60e1b775378c060099285fd4b5249271046c9005f8b247b39377 \ - --hash=sha256:1bfb63821ada76719ffcd703fc40dd57962e0d8c253e3c565252e6de6d3e0bc6 \ - --hash=sha256:1e7208946ea9b27a8cef13822c339d4ae96e45952cc01fc4a91c7f1cb0ae2861 \ - --hash=sha256:217dcbfaf429a9b8f1d54eb380908b9c778e78f31378283b30ba463c21e89d5d \ - --hash=sha256:2459cc06572730e079ec1e694e8f68c99d977b40d98748ae72ff11ef21a56b0b \ - --hash=sha256:24ba48f9d0b8d64fc5e42e1600366c3d7db701201294989aebdaca23110c02ab \ - --hash=sha256:26242e3593d4929123615bd9365dd86ef79b7b0592d64a96cd11fd83c69c9f34 \ - --hash=sha256:2871daf5b2823bf77bf7d3d43825e5d904030c155affdf84b21a00a2e00821d2 \ - --hash=sha256:28734bcfb8fc5b03293dec5eb5ea73b32ff767f6ef79a31f6e41dad2f5470270 \ - --hash=sha256:2a7d08b39fac97540fba785fce3b21ee01a81f081a07a4d031efd791da6666f9 \ - --hash=sha256:2be018a84995b6be1bbd40d6064395dbf71592a981169cf154c0885637f5f54a \ - --hash=sha256:3303113fdfaca927ef11e0c5f109e2ec196c404f9d7ba5f8ddb63cdf287ea159 \ - --hash=sha256:36c3bf96f803e207a80dbcb633d82b98ff02a9faa76dd446e969424dec8e2b9f \ - --hash=sha256:3d5b2a4b3c10cad0615670cab99059441ff42e92cf793a0336f4bc611e895204 \ - --hash=sha256:3f48d4afd973abbd65266ac24b24de1591116880efc7729caf6b6b94a9654c9e \ - --hash=sha256:42d5d0e9bbb50481a049bd0203224b339d4db04006b78564df2b782e2fd16ebc \ - --hash=sha256:443dc5eede7fa76b2370213e0abe881eb17c96f7d694501853c11d5d56916602 \ - --hash=sha256:49ee28d65f506b2858a60745cc974ed005298ebab12693646b97641dd7c99c35 \ - --hash=sha256:4f0788699a92d604f348e9c1ac5e97e304e97127ba8325c7d0af88dcc7d35bd3 \ - --hash=sha256:51506e7652a2ef1d1cf763c4b51b972ff4568d1dddc96ca83931a6941f5e6389 \ - --hash=sha256:53efe03cc383a83660cfdda6a3cb40ee31372cedea0fde0b2a2e55e838873ab6 \ - --hash=sha256:55713d155da1e508083c4b08d0b1ad2c3054f68b8ef7eb3d3864822e456f0bb5 \ - --hash=sha256:581bb606a31749a00796f5257947a0968182d7fe91e1dada41f06aeb6bfbc91a \ - --hash=sha256:5879ac4791508d8f0eb7dec71ff8521855180688dac0c55f8c99fc4d1a939845 \ - --hash=sha256:587d75aec9ae50d0d63788cec38bf13c5128b3fc1411aa4b9398ebac884ab179 \ - --hash=sha256:59fa83873223f856d898452c6162a390af4297756f6ba38493a67533387d85d9 \ - --hash=sha256:5a1570875eb0d1479fb2270ed80c88c231aaaf68b0c3f114f35e7fb610435e4f \ - --hash=sha256:5b45b7be9f99991405ecd6f6172fb6798908a8097106ae78d5cc5cc15121bad9 \ - --hash=sha256:6015beb28deb5306049ecf2519a59627e9e050892927850a884df6d5672f8c7d \ - --hash=sha256:6590ed9d13eb51b28ea17ddcc6c8dbd6050b4eb589d497105f0e13339f223b72 \ - --hash=sha256:66dc0e63349ec39c1ea66622aa5c2c1f84382112afd3ab2fa0cca4fb01f7db39 \ - 
--hash=sha256:679cc4e184f213c8227862e57340d12fd4d4d19dc0e3ddb0f653f86f01e90f94 \ - --hash=sha256:69cd74e55a5326d920e7b46daa2d81c2bdb8bcf588eafb2330d981297b742ddc \ - --hash=sha256:69df82892ff00491d673b1929538efb8c8d68f534fdc6cb7fd3ac8a5852b9034 \ - --hash=sha256:72c2ef3787c3b577e5d6225d73a77167b942d12cef3c1fbd5e74e55b7f881c36 \ - --hash=sha256:744b807fe2733b6da3b53e8ad93e8b3ea3ee3dfc3abece4dd2824cc1f39aa343 \ - --hash=sha256:7977e261cac5f99873dc2c6f044315d09b19a71c4246560e1e67593889a90978 \ - --hash=sha256:798590d38c9381f07c48d13af1f1ef337cebf76ee452fcec5deb04aceced51c7 \ - --hash=sha256:812beca1dcb2b722cccc7e9c620bd972cbc323321194ec2725eab3222e6ac573 \ - --hash=sha256:8276bbab68a9dbe721da92d19cbc061f76655248fe24fb63969d0c3e0e5755e7 \ - --hash=sha256:85bb66d661be51b2cba9ca06759264b3469d2dbb53c3e6effb3f05fec6322be6 \ - --hash=sha256:871c641a83719caaa856a11dcc61c5e5b35b0db888e1a0d338fe67ce744575e2 \ - --hash=sha256:893bf4fb9bfb9c4639bc12f3de323325ada4c6d60e478d5cded65453e9364890 \ - --hash=sha256:8d927d042c0ef04607ee7822828b208ab045867d20477ec6593d612156798547 \ - --hash=sha256:8e17f0c3ba4cb07faa0038a59ce162de584ed48ba645c8d05a5de1e40d4c21e7 \ - --hash=sha256:9486e27bb3f137f33e2315be2baa0b0b983dae9e2f5f5395240178ad8e644728 \ - --hash=sha256:94cf6d0274eb899d39189144dcf52814c67f9b0fd196f211420d9aac793df2da \ - --hash=sha256:97246f896b4df7fd84caa8a75a67abb95f94bc0b547665bf0889e3262b060399 \ - --hash=sha256:9d59e0d7cdfe8ed1d4fcd28aad09625c715dc18976c7067e37d8a11b06f4be3e \ - --hash=sha256:a15f6e5588f7afb7f6fc4b0f4ff064749e515d34f34c666ed6e37933873d8ad8 \ - --hash=sha256:a2ccdc53cb88e51c7d47d74c59630d7be844428f6b8d463055ffad6f0392d8da \ - --hash=sha256:a68a36d71c7f638dda6c9e6b67f6aabf3fa1471b198d246457bfdc7c777cdeb7 \ - --hash=sha256:a7991f25b98038252363a03e6a9fe92e60fe390fda2631d238dc3b0e396632f8 \ - --hash=sha256:aadf74a40a7ae49c3c1aa7d32334fe94f4f968e21dd948e301bb4ed431fb2412 \ - --hash=sha256:abae6fd5504e5e438e4f6f739f8364fd9ff5a5cdca897e68363e2318af90bc28 \ - --hash=sha256:ac417312bf6b7a0223ba73fb12e26b2854c93bf5b1911f7afef6d24c379b22aa \ - --hash=sha256:ad9ea86f5fc50f1b62c31184767fe0cacaa13b54fe57d38898c3776d30602411 \ - --hash=sha256:b4ff385a525017f5adf6066d7f9fb309f99ade725dcf17ed623dc7dce1f85d9f \ - --hash=sha256:b89821a2c77cc1b8f2c1fc3aacd6a3ecc5df8f7e518dc3f18aef8c4dcf66003d \ - --hash=sha256:b8ff0302518dcd001bd722bbe342919c29e5066c7eda86828fe08cdc112668b8 \ - --hash=sha256:b91b5ec423e88caa16777094c4b2b97f11453283e7a837e5e5e1b886abba1251 \ - --hash=sha256:ba55d73a2df4771b211d0bcdea8b79454980a81ed34a1d77a19ddcc81f98c895 \ - --hash=sha256:bb1c6ecb53e4b907ee8486f453dd940b8cbb509946e2b671e3bf807d310a96fc \ - --hash=sha256:bc6a4ea9f88a810cb65ccae14404da846e2a02dd5c0ad21dee712ff69d142638 \ - --hash=sha256:c36987f5eb2a7856b5f5feacc3be206b4d1852a6ce799f6799dd9ffb0cba56ae \ - --hash=sha256:c6e98227eb02623d57e1fd061788837834b68bb995a869565211b9abf3de4bf4 \ - --hash=sha256:c7411cd06afeb263182e38c6ca5b4f5fe4f20d91466ad7db0cd6af453a02edec \ - --hash=sha256:c8c466facec2ccdf025b0b1455b18f2c3d574d5f64d24df905d3d7b8f05d5f4e \ - --hash=sha256:c964c0cc443d6c08a2347c0e5c1fc2d85a272dc66c1a6f3cde4fc4843882ada4 \ - --hash=sha256:ca942a2dc066ca5e04c27feaa8dfb9d353ddad14c6641660c565149186095343 \ - --hash=sha256:cb2fd3ab67558eb16aecfb4f2db4febb4d37dc74e6b8613dc2e7160fb58158a9 \ - --hash=sha256:d312ad20e3c6d179cb97c42232b53111bcd8dcdd5c1136083db9d6bdd489bc73 \ - --hash=sha256:d965bdb50725a805b083f5f58d05669a85705f50a6a864e31b545c589290ee31 \ - 
--hash=sha256:d983222223f63e323a5f497f5b85e211557a5d8fb670dc88f343784502b466ba \ - --hash=sha256:dee4682bd7947afc682d342a8d65ad1834583132383f8e801601a8698cb8d17a \ - --hash=sha256:e2be646a5155d408e68b560c0553e8a83dc7b9f90ec6e5a2fc3ff216719385db \ - --hash=sha256:e2c689439f262c29cf3fcd5364da1e64d8600facecf9eabea8643b8755d2f0de \ - --hash=sha256:e5a111f9158555582deadd202a60bd7803b6c68f406391b7cf6905adf0af6811 \ - --hash=sha256:e905014815687d88cbb14bbc0496420526cf20d49f20606537d87646b70f1046 \ - --hash=sha256:ebc79120e105e4bcd7865f369e3b9dbabb0d492d221e1a7f62a3e8e292550278 \ - --hash=sha256:f1a30eef060e21af22c7d23349f1028de0611f522941c80efa51c05a63142c62 \ - --hash=sha256:f483467c046f549572f8aca3b7128829e09ae3a9fe933ea421f7cb7c58120edb \ - --hash=sha256:f523e116879bc6714e61d447ce934676473b068069dce6563ea040381dc7a257 \ - --hash=sha256:f53a3ccdc30234cb4342cec541e3e6ed87799c7ca552f0b5f44e3967a5fed526 \ - --hash=sha256:fb290491f1f0786a7da4585250f1feee200fc17ff64855bdd7c42fb54526fa29 \ - --hash=sha256:fc3227408808ba7df8e95eb1d8389f4ba2203bed8240b308de1d7ae66d828f24 \ - --hash=sha256:fd80a2d383940eec3db6a5b59d1820f947317acc5c75482ff8d79bf700f8ad6a \ - --hash=sha256:fd937733bf2fe7d6a8bf208c12741f1f730b7bf5636033877767a75093c29b8a \ - --hash=sha256:ffba979801e3931a19cd30ed2049450820effe8f152aaa317e2fd93795d318d7 +pydantic-core==2.23.4 \ + --hash=sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36 \ + --hash=sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05 \ + --hash=sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071 \ + --hash=sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327 \ + --hash=sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c \ + --hash=sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36 \ + --hash=sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29 \ + --hash=sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744 \ + --hash=sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d \ + --hash=sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec \ + --hash=sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e \ + --hash=sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e \ + --hash=sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577 \ + --hash=sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232 \ + --hash=sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863 \ + --hash=sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6 \ + --hash=sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368 \ + --hash=sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480 \ + --hash=sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2 \ + --hash=sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2 \ + --hash=sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6 \ + --hash=sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769 \ + --hash=sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d \ + --hash=sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2 \ + --hash=sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84 \ + --hash=sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166 \ + 
--hash=sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271 \ + --hash=sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5 \ + --hash=sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb \ + --hash=sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13 \ + --hash=sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323 \ + --hash=sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556 \ + --hash=sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665 \ + --hash=sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef \ + --hash=sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb \ + --hash=sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119 \ + --hash=sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126 \ + --hash=sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510 \ + --hash=sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b \ + --hash=sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87 \ + --hash=sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f \ + --hash=sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc \ + --hash=sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8 \ + --hash=sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21 \ + --hash=sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f \ + --hash=sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6 \ + --hash=sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658 \ + --hash=sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b \ + --hash=sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3 \ + --hash=sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb \ + --hash=sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59 \ + --hash=sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24 \ + --hash=sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9 \ + --hash=sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3 \ + --hash=sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd \ + --hash=sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753 \ + --hash=sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55 \ + --hash=sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad \ + --hash=sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a \ + --hash=sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605 \ + --hash=sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e \ + --hash=sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b \ + --hash=sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433 \ + --hash=sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8 \ + --hash=sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07 \ + --hash=sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728 \ + --hash=sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0 \ + --hash=sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327 \ + 
--hash=sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555 \ + --hash=sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64 \ + --hash=sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6 \ + --hash=sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea \ + --hash=sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b \ + --hash=sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df \ + --hash=sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e \ + --hash=sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd \ + --hash=sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068 \ + --hash=sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3 \ + --hash=sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040 \ + --hash=sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12 \ + --hash=sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916 \ + --hash=sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f \ + --hash=sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f \ + --hash=sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801 \ + --hash=sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231 \ + --hash=sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5 \ + --hash=sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8 \ + --hash=sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee \ + --hash=sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607 # via # -c release/ray_release/byod/requirements_compiled.txt # pydantic @@ -2754,42 +2749,6 @@ pyparsing==3.1.1 \ # -c release/ray_release/byod/requirements_compiled.txt # httplib2 # matplotlib -pyrsistent==0.20.0 \ - --hash=sha256:0724c506cd8b63c69c7f883cc233aac948c1ea946ea95996ad8b1380c25e1d3f \ - --hash=sha256:09848306523a3aba463c4b49493a760e7a6ca52e4826aa100ee99d8d39b7ad1e \ - --hash=sha256:0f3b1bcaa1f0629c978b355a7c37acd58907390149b7311b5db1b37648eb6958 \ - --hash=sha256:21cc459636983764e692b9eba7144cdd54fdec23ccdb1e8ba392a63666c60c34 \ - --hash=sha256:2e14c95c16211d166f59c6611533d0dacce2e25de0f76e4c140fde250997b3ca \ - --hash=sha256:2e2c116cc804d9b09ce9814d17df5edf1df0c624aba3b43bc1ad90411487036d \ - --hash=sha256:4021a7f963d88ccd15b523787d18ed5e5269ce57aa4037146a2377ff607ae87d \ - --hash=sha256:4c48f78f62ab596c679086084d0dd13254ae4f3d6c72a83ffdf5ebdef8f265a4 \ - --hash=sha256:4f5c2d012671b7391803263419e31b5c7c21e7c95c8760d7fc35602353dee714 \ - --hash=sha256:58b8f6366e152092194ae68fefe18b9f0b4f89227dfd86a07770c3d86097aebf \ - --hash=sha256:59a89bccd615551391f3237e00006a26bcf98a4d18623a19909a2c48b8e986ee \ - --hash=sha256:5cdd7ef1ea7a491ae70d826b6cc64868de09a1d5ff9ef8d574250d0940e275b8 \ - --hash=sha256:6288b3fa6622ad8a91e6eb759cfc48ff3089e7c17fb1d4c59a919769314af224 \ - --hash=sha256:6d270ec9dd33cdb13f4d62c95c1a5a50e6b7cdd86302b494217137f760495b9d \ - --hash=sha256:79ed12ba79935adaac1664fd7e0e585a22caa539dfc9b7c7c6d5ebf91fb89054 \ - --hash=sha256:7d29c23bdf6e5438c755b941cef867ec2a4a172ceb9f50553b6ed70d50dfd656 \ - --hash=sha256:8441cf9616d642c475684d6cf2520dd24812e996ba9af15e606df5f6fd9d04a7 \ - --hash=sha256:881bbea27bbd32d37eb24dd320a5e745a2a5b092a17f6debc1349252fac85423 \ - --hash=sha256:8c3aba3e01235221e5b229a6c05f585f344734bd1ad42a8ac51493d74722bbce \ - 
--hash=sha256:a14798c3005ec892bbada26485c2eea3b54109cb2533713e355c806891f63c5e \ - --hash=sha256:b14decb628fac50db5e02ee5a35a9c0772d20277824cfe845c8a8b717c15daa3 \ - --hash=sha256:b318ca24db0f0518630e8b6f3831e9cba78f099ed5c1d65ffe3e023003043ba0 \ - --hash=sha256:c1beb78af5423b879edaf23c5591ff292cf7c33979734c99aa66d5914ead880f \ - --hash=sha256:c55acc4733aad6560a7f5f818466631f07efc001fd023f34a6c203f8b6df0f0b \ - --hash=sha256:ca52d1ceae015859d16aded12584c59eb3825f7b50c6cfd621d4231a6cc624ce \ - --hash=sha256:cae40a9e3ce178415040a0383f00e8d68b569e97f31928a3a8ad37e3fde6df6a \ - --hash=sha256:e78d0c7c1e99a4a45c99143900ea0546025e41bb59ebc10182e947cf1ece9174 \ - --hash=sha256:ef3992833fbd686ee783590639f4b8343a57f1f75de8633749d984dc0eb16c86 \ - --hash=sha256:f058a615031eea4ef94ead6456f5ec2026c19fb5bd6bfe86e9665c4158cf802f \ - --hash=sha256:f5ac696f02b3fc01a710427585c855f65cd9c640e14f52abe52020722bb4906b \ - --hash=sha256:f920385a11207dc372a028b3f1e1038bb244b3ec38d448e6d8e43c6b3ba20e98 \ - --hash=sha256:fed2c3216a605dc9a6ea50c7e84c82906e3684c4e80d2908208f662a6cbf9022 - # via - # -c release/ray_release/byod/requirements_compiled.txt - # jsonschema pyspark==3.4.1 \ --hash=sha256:72cd66ab8cf61a75854e5a753f75bea35ee075c3a96f9de4e2a66d02ec7fc652 # via @@ -2993,6 +2952,13 @@ qpd==0.4.4 \ # via # -c release/ray_release/byod/requirements_compiled.txt # fugue +referencing==0.36.2 \ + --hash=sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa \ + --hash=sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0 + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema + # jsonschema-specifications regex==2024.5.15 \ --hash=sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649 \ --hash=sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35 \ @@ -3125,6 +3091,114 @@ rouge-score==0.1.2 \ roundrobin==0.0.4 \ --hash=sha256:7e9d19a5bd6123d99993fb935fa86d25c88bb2096e493885f61737ed0f5e9abd # via locust +rpds-py==0.22.3 \ + --hash=sha256:009de23c9c9ee54bf11303a966edf4d9087cd43a6003672e6aa7def643d06518 \ + --hash=sha256:02fbb9c288ae08bcb34fb41d516d5eeb0455ac35b5512d03181d755d80810059 \ + --hash=sha256:0a0461200769ab3b9ab7e513f6013b7a97fdeee41c29b9db343f3c5a8e2b9e61 \ + --hash=sha256:0b09865a9abc0ddff4e50b5ef65467cd94176bf1e0004184eb915cbc10fc05c5 \ + --hash=sha256:0b8db6b5b2d4491ad5b6bdc2bc7c017eec108acbf4e6785f42a9eb0ba234f4c9 \ + --hash=sha256:0c150c7a61ed4a4f4955a96626574e9baf1adf772c2fb61ef6a5027e52803543 \ + --hash=sha256:0f3cec041684de9a4684b1572fe28c7267410e02450f4561700ca5a3bc6695a2 \ + --hash=sha256:1352ae4f7c717ae8cba93421a63373e582d19d55d2ee2cbb184344c82d2ae55a \ + --hash=sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d \ + --hash=sha256:1978d0021e943aae58b9b0b196fb4895a25cc53d3956b8e35e0b7682eefb6d56 \ + --hash=sha256:1a60bce91f81ddaac922a40bbb571a12c1070cb20ebd6d49c48e0b101d87300d \ + --hash=sha256:1aef18820ef3e4587ebe8b3bc9ba6e55892a6d7b93bac6d29d9f631a3b4befbd \ + --hash=sha256:1e9663daaf7a63ceccbbb8e3808fe90415b0757e2abddbfc2e06c857bf8c5e2b \ + --hash=sha256:20070c65396f7373f5df4005862fa162db5d25d56150bddd0b3e8214e8ef45b4 \ + --hash=sha256:214b7a953d73b5e87f0ebece4a32a5bd83c60a3ecc9d4ec8f1dca968a2d91e99 \ + --hash=sha256:22bebe05a9ffc70ebfa127efbc429bc26ec9e9b4ee4d15a740033efda515cf3d \ + --hash=sha256:24e8abb5878e250f2eb0d7859a8e561846f98910326d06c0d51381fed59357bd \ + --hash=sha256:26fd7cac7dd51011a245f29a2cc6489c4608b5a8ce8d75661bb4a1066c52dfbe \ + 
--hash=sha256:27b1d3b3915a99208fee9ab092b8184c420f2905b7d7feb4aeb5e4a9c509b8a1 \ + --hash=sha256:27e98004595899949bd7a7b34e91fa7c44d7a97c40fcaf1d874168bb652ec67e \ + --hash=sha256:2b8f60e1b739a74bab7e01fcbe3dddd4657ec685caa04681df9d562ef15b625f \ + --hash=sha256:2de29005e11637e7a2361fa151f780ff8eb2543a0da1413bb951e9f14b699ef3 \ + --hash=sha256:2e8b55d8517a2fda8d95cb45d62a5a8bbf9dd0ad39c5b25c8833efea07b880ca \ + --hash=sha256:2fa4331c200c2521512595253f5bb70858b90f750d39b8cbfd67465f8d1b596d \ + --hash=sha256:3445e07bf2e8ecfeef6ef67ac83de670358abf2996916039b16a218e3d95e97e \ + --hash=sha256:3453e8d41fe5f17d1f8e9c383a7473cd46a63661628ec58e07777c2fff7196dc \ + --hash=sha256:378753b4a4de2a7b34063d6f95ae81bfa7b15f2c1a04a9518e8644e81807ebea \ + --hash=sha256:3af6e48651c4e0d2d166dc1b033b7042ea3f871504b6805ba5f4fe31581d8d38 \ + --hash=sha256:3dfcbc95bd7992b16f3f7ba05af8a64ca694331bd24f9157b49dadeeb287493b \ + --hash=sha256:3f21f0495edea7fdbaaa87e633a8689cd285f8f4af5c869f27bc8074638ad69c \ + --hash=sha256:4041711832360a9b75cfb11b25a6a97c8fb49c07b8bd43d0d02b45d0b499a4ff \ + --hash=sha256:44d61b4b7d0c2c9ac019c314e52d7cbda0ae31078aabd0f22e583af3e0d79723 \ + --hash=sha256:4617e1915a539a0d9a9567795023de41a87106522ff83fbfaf1f6baf8e85437e \ + --hash=sha256:4b232061ca880db21fa14defe219840ad9b74b6158adb52ddf0e87bead9e8493 \ + --hash=sha256:5246b14ca64a8675e0a7161f7af68fe3e910e6b90542b4bfb5439ba752191df6 \ + --hash=sha256:5725dd9cc02068996d4438d397e255dcb1df776b7ceea3b9cb972bdb11260a83 \ + --hash=sha256:583f6a1993ca3369e0f80ba99d796d8e6b1a3a2a442dd4e1a79e652116413091 \ + --hash=sha256:59259dc58e57b10e7e18ce02c311804c10c5a793e6568f8af4dead03264584d1 \ + --hash=sha256:593eba61ba0c3baae5bc9be2f5232430453fb4432048de28399ca7376de9c627 \ + --hash=sha256:59f4a79c19232a5774aee369a0c296712ad0e77f24e62cad53160312b1c1eaa1 \ + --hash=sha256:5f0e260eaf54380380ac3808aa4ebe2d8ca28b9087cf411649f96bad6900c728 \ + --hash=sha256:62d9cfcf4948683a18a9aff0ab7e1474d407b7bab2ca03116109f8464698ab16 \ + --hash=sha256:64607d4cbf1b7e3c3c8a14948b99345eda0e161b852e122c6bb71aab6d1d798c \ + --hash=sha256:655ca44a831ecb238d124e0402d98f6212ac527a0ba6c55ca26f616604e60a45 \ + --hash=sha256:666ecce376999bf619756a24ce15bb14c5bfaf04bf00abc7e663ce17c3f34fe7 \ + --hash=sha256:68049202f67380ff9aa52f12e92b1c30115f32e6895cd7198fa2a7961621fc5a \ + --hash=sha256:69803198097467ee7282750acb507fba35ca22cc3b85f16cf45fb01cb9097730 \ + --hash=sha256:6c7b99ca52c2c1752b544e310101b98a659b720b21db00e65edca34483259967 \ + --hash=sha256:6dd9412824c4ce1aca56c47b0991e65bebb7ac3f4edccfd3f156150c96a7bf25 \ + --hash=sha256:70eb60b3ae9245ddea20f8a4190bd79c705a22f8028aaf8bbdebe4716c3fab24 \ + --hash=sha256:70fb28128acbfd264eda9bf47015537ba3fe86e40d046eb2963d75024be4d055 \ + --hash=sha256:7b2513ba235829860b13faa931f3b6846548021846ac808455301c23a101689d \ + --hash=sha256:7ef9d9da710be50ff6809fed8f1963fecdfecc8b86656cadfca3bc24289414b0 \ + --hash=sha256:81e69b0a0e2537f26d73b4e43ad7bc8c8efb39621639b4434b76a3de50c6966e \ + --hash=sha256:8633e471c6207a039eff6aa116e35f69f3156b3989ea3e2d755f7bc41754a4a7 \ + --hash=sha256:8bd7c8cfc0b8247c8799080fbff54e0b9619e17cdfeb0478ba7295d43f635d7c \ + --hash=sha256:9253fc214112405f0afa7db88739294295f0e08466987f1d70e29930262b4c8f \ + --hash=sha256:99b37292234e61325e7a5bb9689e55e48c3f5f603af88b1642666277a81f1fbd \ + --hash=sha256:9bd7228827ec7bb817089e2eb301d907c0d9827a9e558f22f762bb690b131652 \ + --hash=sha256:9beeb01d8c190d7581a4d59522cd3d4b6887040dcfc744af99aa59fef3e041a8 \ + 
--hash=sha256:a63cbdd98acef6570c62b92a1e43266f9e8b21e699c363c0fef13bd530799c11 \ + --hash=sha256:a76e42402542b1fae59798fab64432b2d015ab9d0c8c47ba7addddbaf7952333 \ + --hash=sha256:ac0a03221cdb5058ce0167ecc92a8c89e8d0decdc9e99a2ec23380793c4dcb96 \ + --hash=sha256:b0b4136a252cadfa1adb705bb81524eee47d9f6aab4f2ee4fa1e9d3cd4581f64 \ + --hash=sha256:b25bc607423935079e05619d7de556c91fb6adeae9d5f80868dde3468657994b \ + --hash=sha256:b3d504047aba448d70cf6fa22e06cb09f7cbd761939fdd47604f5e007675c24e \ + --hash=sha256:bb47271f60660803ad11f4c61b42242b8c1312a31c98c578f79ef9387bbde21c \ + --hash=sha256:bbb232860e3d03d544bc03ac57855cd82ddf19c7a07651a7c0fdb95e9efea8b9 \ + --hash=sha256:bc27863442d388870c1809a87507727b799c8460573cfbb6dc0eeaef5a11b5ec \ + --hash=sha256:bc51abd01f08117283c5ebf64844a35144a0843ff7b2983e0648e4d3d9f10dbb \ + --hash=sha256:be2eb3f2495ba669d2a985f9b426c1797b7d48d6963899276d22f23e33d47e37 \ + --hash=sha256:bf9db5488121b596dbfc6718c76092fda77b703c1f7533a226a5a9f65248f8ad \ + --hash=sha256:c58e2339def52ef6b71b8f36d13c3688ea23fa093353f3a4fee2556e62086ec9 \ + --hash=sha256:cfbc454a2880389dbb9b5b398e50d439e2e58669160f27b60e5eca11f68ae17c \ + --hash=sha256:cff63a0272fcd259dcc3be1657b07c929c466b067ceb1c20060e8d10af56f5bf \ + --hash=sha256:d115bffdd417c6d806ea9069237a4ae02f513b778e3789a359bc5856e0404cc4 \ + --hash=sha256:d20cfb4e099748ea39e6f7b16c91ab057989712d31761d3300d43134e26e165f \ + --hash=sha256:d48424e39c2611ee1b84ad0f44fb3b2b53d473e65de061e3f460fc0be5f1939d \ + --hash=sha256:e0fa2d4ec53dc51cf7d3bb22e0aa0143966119f42a0c3e4998293a3dd2856b09 \ + --hash=sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d \ + --hash=sha256:e35ba67d65d49080e8e5a1dd40101fccdd9798adb9b050ff670b7d74fa41c566 \ + --hash=sha256:e3fb866d9932a3d7d0c82da76d816996d1667c44891bd861a0f97ba27e84fc74 \ + --hash=sha256:e61b02c3f7a1e0b75e20c3978f7135fd13cb6cf551bf4a6d29b999a88830a338 \ + --hash=sha256:e67ba3c290821343c192f7eae1d8fd5999ca2dc99994114643e2f2d3e6138b15 \ + --hash=sha256:e79dd39f1e8c3504be0607e5fc6e86bb60fe3584bec8b782578c3b0fde8d932c \ + --hash=sha256:e89391e6d60251560f0a8f4bd32137b077a80d9b7dbe6d5cab1cd80d2746f648 \ + --hash=sha256:ea7433ce7e4bfc3a85654aeb6747babe3f66eaf9a1d0c1e7a4435bbdf27fea84 \ + --hash=sha256:eaf16ae9ae519a0e237a0f528fd9f0197b9bb70f40263ee57ae53c2b8d48aeb3 \ + --hash=sha256:eb0c341fa71df5a4595f9501df4ac5abfb5a09580081dffbd1ddd4654e6e9123 \ + --hash=sha256:f276b245347e6e36526cbd4a266a417796fc531ddf391e43574cf6466c492520 \ + --hash=sha256:f47ad3d5f3258bd7058d2d506852217865afefe6153a36eb4b6928758041d831 \ + --hash=sha256:f56a6b404f74ab372da986d240e2e002769a7d7102cc73eb238a4f72eec5284e \ + --hash=sha256:f5cf2a0c2bdadf3791b5c205d55a37a54025c6e18a71c71f82bb536cf9a454bf \ + --hash=sha256:f5d36399a1b96e1a5fdc91e0522544580dbebeb1f77f27b2b0ab25559e103b8b \ + --hash=sha256:f60bd8423be1d9d833f230fdbccf8f57af322d96bcad6599e5a771b151398eb2 \ + --hash=sha256:f612463ac081803f243ff13cccc648578e2279295048f2a8d5eb430af2bae6e3 \ + --hash=sha256:f73d3fef726b3243a811121de45193c0ca75f6407fe66f3f4e183c983573e130 \ + --hash=sha256:f82a116a1d03628a8ace4859556fb39fd1424c933341a08ea3ed6de1edb0283b \ + --hash=sha256:fb0ba113b4983beac1a2eb16faffd76cb41e176bf58c4afe3e14b9c681f702de \ + --hash=sha256:fb4f868f712b2dd4bcc538b0a0c1f63a2b1d584c925e69a224d759e7070a12d5 \ + --hash=sha256:fb6116dfb8d1925cbdb52595560584db42a7f664617a1f7d7f6e32f138cdf37d \ + --hash=sha256:fda7cb070f442bf80b642cd56483b5548e43d366fe3f39b98e67cce780cded00 \ + 
--hash=sha256:feea821ee2a9273771bae61194004ee2fc33f8ec7db08117ef9147d4bbcbca8e + # via + # -c release/ray_release/byod/requirements_compiled.txt + # jsonschema + # referencing rsa==4.7.2 \ --hash=sha256:78f9a9bf4e7be0c5ded4583326e7461e3a3c5aae24073648b4bdfa797d78c9d2 \ --hash=sha256:9d689e6ca1b3038bc82bf8d23e944b6b6037bc02301a574935b2dd946e0353b9 @@ -3918,9 +3992,9 @@ typer==0.12.3 \ # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_ml_byod_3.9.in -typing-extensions==4.8.0 \ - --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ - --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef +typing-extensions==4.12.2 \ + --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ + --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 # via # -c release/ray_release/byod/requirements_compiled.txt # -r release/ray_release/byod/requirements_ml_byod_3.9.in @@ -3931,6 +4005,7 @@ typing-extensions==4.8.0 \ # pydantic # pydantic-core # pytorch-lightning + # referencing # starlette # torch # typer diff --git a/release/release_data_tests.yaml b/release/release_data_tests.yaml index 13887c7b93bc2..c09255cd1da5c 100644 --- a/release/release_data_tests.yaml +++ b/release/release_data_tests.yaml @@ -354,6 +354,17 @@ python gpu_batch_inference.py --data-directory 300G-image-data-synthetic-raw-parquet --data-format parquet +- name: batch_inference_from_metadata + # This benchmark errors because of the issues described in PLAN-383. + frequency: manual + + cluster: + cluster_compute: autoscaling_hetero_compute.yaml + + run: + timeout: 1800 + script: python batch_inference_benchmark.py + - name: batch_inference_chaos stable: False # Don't use 'nightly_tests/dataset' as the working directory because we need to run diff --git a/release/release_logs/0.8.2/microbenchmark.txt b/release/release_logs/0.8.2/microbenchmark.txt index dddb272787c15..bc9daf6824538 100644 --- a/release/release_logs/0.8.2/microbenchmark.txt +++ b/release/release_logs/0.8.2/microbenchmark.txt @@ -15,4 +15,4 @@ multi client tasks async per second 42285.81 +- 238.55 1:1 actor calls concurrent per second 6167.01 +- 75.67 1:n actor calls async per second 12241.67 +- 62.13 n:n actor calls async per second 41766.33 +- 672.14 -n:n actor calls with arg async per second 13134.22 +- 71.68 \ No newline at end of file +n:n actor calls with arg async per second 13134.22 +- 71.68 diff --git a/release/release_logs/0.8.2/stress_tests/application_stress_test.txt b/release/release_logs/0.8.2/stress_tests/application_stress_test.txt index 62d1aa7c45ed1..6e6aadb519914 100644 --- a/release/release_logs/0.8.2/stress_tests/application_stress_test.txt +++ b/release/release_logs/0.8.2/stress_tests/application_stress_test.txt @@ -11,4 +11,4 @@ Number of trials: 4 (4 TERMINATED) | IMPALA_BeamRiderNoFrameskip-v4_2565e804 | TERMINATED | | BeamRiderNoFrameskip-v4 | 3124.8 | 24121.2 | 30057000 | 408 | | IMPALA_QbertNoFrameskip-v4_256671de | TERMINATED | | QbertNoFrameskip-v4 | 8388.25 | 25163.5 | 30080000 | 453 | | IMPALA_SpaceInvadersNoFrameskip-v4_256725ac | TERMINATED | | SpaceInvadersNoFrameskip-v4 | 780.65 | 23148.1 | 30026500 | 384 | -+---------------------------------------------+------------+-------+-----------------------------+----------+------------------+----------+--------+ \ No newline at end of file 
++---------------------------------------------+------------+-------+-----------------------------+----------+------------------+----------+--------+ diff --git a/release/release_logs/0.8.2/stress_tests/test_dead_actors.txt b/release/release_logs/0.8.2/stress_tests/test_dead_actors.txt index cab48a589b91b..644e1d7a13117 100644 --- a/release/release_logs/0.8.2/stress_tests/test_dead_actors.txt +++ b/release/release_logs/0.8.2/stress_tests/test_dead_actors.txt @@ -1,4 +1,4 @@ Finished in: 98.49777579307556s Average iteration time: 0.9849753308296204s Max iteration time: 2.9459526538848877s -Min iteration time: 0.08075928688049316s \ No newline at end of file +Min iteration time: 0.08075928688049316s diff --git a/release/release_logs/0.8.2/stress_tests/test_many_tasks.txt b/release/release_logs/0.8.2/stress_tests/test_many_tasks.txt index be7b57a14c673..42208ee8994d8 100644 --- a/release/release_logs/0.8.2/stress_tests/test_many_tasks.txt +++ b/release/release_logs/0.8.2/stress_tests/test_many_tasks.txt @@ -12,4 +12,4 @@ Stage 2 results: Min iteration time: 121.44297170639038 Stage 3 results: Actor creation time: 0.0635519027709961 - Total time: 3464.0461547374725 \ No newline at end of file + Total time: 3464.0461547374725 diff --git a/release/release_logs/0.8.3/rllib_regression.txt b/release/release_logs/0.8.3/rllib_regression.txt index 5a9072949f99a..cec11ccdf2d62 100644 --- a/release/release_logs/0.8.3/rllib_regression.txt +++ b/release/release_logs/0.8.3/rllib_regression.txt @@ -37,4 +37,3 @@ Number of trials: 24 (24 TERMINATED) | DQN_BreakoutNoFrameskip-v4_00022 | TERMINATED | | 27 | 3669.77 | 270000 | 15.45 | | DQN_BreakoutNoFrameskip-v4_00023 | TERMINATED | | 27 | 3688.08 | 270000 | 12.25 | +-------------------------------------+------------+-------+--------+------------------+---------+----------+ - diff --git a/release/release_logs/0.8.4/rllib_regression.txt b/release/release_logs/0.8.4/rllib_regression.txt index 7f288268828fa..d38951b6413f7 100644 --- a/release/release_logs/0.8.4/rllib_regression.txt +++ b/release/release_logs/0.8.4/rllib_regression.txt @@ -37,11 +37,3 @@ Number of trials: 24 (8 ERROR, 16 TERMINATED) | DQN_BreakoutNoFrameskip-v4_00022 | TERMINATED | | 33 | 3679.57 | 330000 | 17.47 | | DQN_BreakoutNoFrameskip-v4_00023 | TERMINATED | | 33 | 3657.31 | 330000 | 17.47 | +-------------------------------------+------------+-------+--------+------------------+---------+-----------+ - - - - - - - - diff --git a/release/release_logs/0.8.5/stress_tests/application_stress_test.txt b/release/release_logs/0.8.5/stress_tests/application_stress_test.txt index 55b5b907c1990..f2f9d627a6690 100644 --- a/release/release_logs/0.8.5/stress_tests/application_stress_test.txt +++ b/release/release_logs/0.8.5/stress_tests/application_stress_test.txt @@ -12,4 +12,3 @@ Number of trials: 4 (4 TERMINATED) | IMPALA_QbertNoFrameskip-v4_00002 | TERMINATED | | QbertNoFrameskip-v4 | 356 | 5897.19 | 30077500 | 5230.25 | | IMPALA_SpaceInvadersNoFrameskip-v4_00003 | TERMINATED | | SpaceInvadersNoFrameskip-v4 | 361 | 5990.74 | 30103500 | 806.043 | +------------------------------------------+------------+-------+-----------------------------+--------+------------------+----------+----------+ - diff --git a/release/release_logs/0.8.5/stress_tests/test_dead_actors.txt b/release/release_logs/0.8.5/stress_tests/test_dead_actors.txt index e81addf7c1397..34a95a84d82e2 100644 --- a/release/release_logs/0.8.5/stress_tests/test_dead_actors.txt +++ b/release/release_logs/0.8.5/stress_tests/test_dead_actors.txt @@ 
-1,4 +1,4 @@ Finished in: 98.98898577690125s Average iteration time: 0.9898880553245545s Max iteration time: 2.3596835136413574s -Min iteration time: 0.1039724349975586s \ No newline at end of file +Min iteration time: 0.1039724349975586s diff --git a/release/release_logs/0.8.6/stress_tests/test_many_tasks.txt b/release/release_logs/0.8.6/stress_tests/test_many_tasks.txt index cb618992e5e0d..f4f38ff99c3ea 100644 --- a/release/release_logs/0.8.6/stress_tests/test_many_tasks.txt +++ b/release/release_logs/0.8.6/stress_tests/test_many_tasks.txt @@ -13,4 +13,4 @@ Stage 2 results: Stage 3 results: Actor creation time: 0.09029483795166016 Total time: 3138.855129480362 - \ No newline at end of file + diff --git a/release/release_logs/0.8.7/microbenchmark.txt b/release/release_logs/0.8.7/microbenchmark.txt index 8d91eca24e8e0..f96be906efd2b 100644 --- a/release/release_logs/0.8.7/microbenchmark.txt +++ b/release/release_logs/0.8.7/microbenchmark.txt @@ -19,4 +19,4 @@ n:n actor calls with arg async per second 10603.63 +- 95.77 1:1 async-actor calls async per second 3755.41 +- 53.77 1:1 async-actor calls with args async per second 2340.26 +- 62.86 1:n async-actor calls async per second 13353.84 +- 446.48 -n:n async-actor calls async per second 29600.63 +- 284.26 \ No newline at end of file +n:n async-actor calls async per second 29600.63 +- 284.26 diff --git a/release/release_logs/0.8.7/rllib_regression.txt b/release/release_logs/0.8.7/rllib_regression.txt index 61b432ae5c86c..83759daecb341 100644 --- a/release/release_logs/0.8.7/rllib_regression.txt +++ b/release/release_logs/0.8.7/rllib_regression.txt @@ -37,5 +37,3 @@ Number of trials: 24 (24 TERMINATED) | DQN_BreakoutNoFrameskip-v4_b13b3_00022 | TERMINATED | | 24 | 3677.74 | 250000 | 11.62 | | DQN_BreakoutNoFrameskip-v4_b13b3_00023 | TERMINATED | | 24 | 3664.08 | 250000 | 7.47 | +-------------------------------------------+------------+-------+--------+------------------+---------+----------+ - - diff --git a/release/release_logs/0.8.7/stress_tests/test_dead_actors.txt b/release/release_logs/0.8.7/stress_tests/test_dead_actors.txt index e5ffb1dd94740..b1aa978b3d49f 100644 --- a/release/release_logs/0.8.7/stress_tests/test_dead_actors.txt +++ b/release/release_logs/0.8.7/stress_tests/test_dead_actors.txt @@ -1,4 +1,4 @@ Finished in: 456.6799967288971s Average iteration time: 4.566798083782196s Max iteration time: 10.93368911743164s -Min iteration time: 1.0819299221038818s \ No newline at end of file +Min iteration time: 1.0819299221038818s diff --git a/release/release_logs/1.0.1/rllib_regression.txt b/release/release_logs/1.0.1/rllib_regression.txt index 3328abe5f88a6..091f24fe1ded1 100644 --- a/release/release_logs/1.0.1/rllib_regression.txt +++ b/release/release_logs/1.0.1/rllib_regression.txt @@ -44,4 +44,4 @@ Number of trials: 24/24 (24 TERMINATED) 21.07 | 47 | 8 | 2586.87 | +-------------------------------------------+------------+-------+--------+------------------+---------+ ----------+----------------------+----------------------+--------------------+ 2020-11-05 02:48:15,548 INFO tune.py:439 -- Total run time: 14646.45 seconds (14645.94 seconds for the t -uning loop). \ No newline at end of file +uning loop). 
diff --git a/release/release_logs/1.1.0/microbenchmark.txt b/release/release_logs/1.1.0/microbenchmark.txt index 293540326eaac..d50e9816893b2 100644 --- a/release/release_logs/1.1.0/microbenchmark.txt +++ b/release/release_logs/1.1.0/microbenchmark.txt @@ -19,4 +19,4 @@ n:n actor calls with arg async per second 13447.94 +- 66.48 1:1 async-actor calls async per second 4028.29 +- 149.49 1:1 async-actor calls with args async per second 2833.48 +- 118.95 1:n async-actor calls async per second 14861.94 +- 1054.37 -n:n async-actor calls async per second 39168.35 +- 870.17 \ No newline at end of file +n:n async-actor calls async per second 39168.35 +- 870.17 diff --git a/release/release_logs/1.1.0/rllib_regression_tf.txt b/release/release_logs/1.1.0/rllib_regression_tf.txt index 30b393e8e8c0b..a4f8d90f7919f 100644 --- a/release/release_logs/1.1.0/rllib_regression_tf.txt +++ b/release/release_logs/1.1.0/rllib_regression_tf.txt @@ -13,4 +13,4 @@ | PPO_BreakoutNoFrameskip-v4_ed940_00009 | TERMINATED | | 1308 | 3601.78 | 6540000 | 38.9 | 243 | 9 | 2943.55 | | SAC_HalfCheetahBulletEnv-v0_ed940_00010 | TERMINATED | | 80 | 3609.46 | 89000 | 590.35 | 693.03 | 514.213 | 1000 | | SAC_HalfCheetahBulletEnv-v0_ed940_00011 | TERMINATED | | 81 | 3629.88 | 90000 | 699.174 | 724.238 | 676.54 | 1000 | -+-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--------------------+ \ No newline at end of file ++-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--------------------+ diff --git a/release/release_logs/1.1.0/rllib_regression_torch.txt b/release/release_logs/1.1.0/rllib_regression_torch.txt index dd119c806254f..8b9317b458bc1 100644 --- a/release/release_logs/1.1.0/rllib_regression_torch.txt +++ b/release/release_logs/1.1.0/rllib_regression_torch.txt @@ -13,4 +13,4 @@ | PPO_BreakoutNoFrameskip-v4_dba56_00009 | TERMINATED | | 786 | 3601.17 | 3930000 | 71.25 | 394 | 10 | 2852.6 | | SAC_HalfCheetahBulletEnv-v0_dba56_00010 | TERMINATED | | 62 | 3606.85 | 71000 | 743.684 | 920.212 | 524.94 | 1000 | | SAC_HalfCheetahBulletEnv-v0_dba56_00011 | TERMINATED | | 63 | 3643.78 | 72000 | 598.945 | 646.041 | 531.009 | 1000 | -+-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--------------------+ \ No newline at end of file ++-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--------------------+ diff --git a/release/release_logs/1.1.0/stress_tests/test_dead_actors.txt b/release/release_logs/1.1.0/stress_tests/test_dead_actors.txt index 41bd4fb2f8d38..cd0495eac0113 100644 --- a/release/release_logs/1.1.0/stress_tests/test_dead_actors.txt +++ b/release/release_logs/1.1.0/stress_tests/test_dead_actors.txt @@ -1,4 +1,4 @@ Finished in: 101.62001442909241s Average iteration time: 1.016197898387909s Max iteration time: 3.120382308959961s -Min iteration time: 0.14046645164489746s \ No newline at end of file +Min iteration time: 0.14046645164489746s diff --git a/release/release_logs/1.1.0/stress_tests/test_many_tasks.txt b/release/release_logs/1.1.0/stress_tests/test_many_tasks.txt index 97490f84c440d..2d69c2107d2b0 100644 --- a/release/release_logs/1.1.0/stress_tests/test_many_tasks.txt +++ 
b/release/release_logs/1.1.0/stress_tests/test_many_tasks.txt @@ -14,4 +14,4 @@ Stage 3 results: Actor creation time: 0.4038546085357666 Total time: 2379.7233917713165 Stage 4 results: - Scheduling spread: 91.36037818586006. \ No newline at end of file + Scheduling spread: 91.36037818586006. diff --git a/release/release_logs/1.10.0/stress_tests/stress_test_dead_actors.json b/release/release_logs/1.10.0/stress_tests/stress_test_dead_actors.json index a8ef26a4fea12..79db367c75108 100644 --- a/release/release_logs/1.10.0/stress_tests/stress_test_dead_actors.json +++ b/release/release_logs/1.10.0/stress_tests/stress_test_dead_actors.json @@ -8,4 +8,4 @@ "_session_url": "https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_nKYi2QS7g2rFe9rK9mD7K6Ay", "_commit_url": "https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.10.0/5ea565317a8104c04ae7892bb9bb41c6d72f12df/ray-1.10.0-cp37-cp37m-manylinux2014_x86_64.whl", "_stable": true -} \ No newline at end of file +} diff --git a/release/release_logs/1.2.0/microbenchmark.txt b/release/release_logs/1.2.0/microbenchmark.txt index 064e8b4411d4f..e7e6f6048daeb 100644 --- a/release/release_logs/1.2.0/microbenchmark.txt +++ b/release/release_logs/1.2.0/microbenchmark.txt @@ -25,4 +25,4 @@ client: put calls per second 1346.13 +- 8.2 client: remote put calls per second 58855.54 +- 849.21 client: 1:1 actor calls sync per second 730.58 +- 11.66 client: 1:1 actor calls async per second 774.79 +- 14.1 -client: 1:1 actor calls concurrent per second 805.73 +- 11.46 \ No newline at end of file +client: 1:1 actor calls concurrent per second 805.73 +- 11.46 diff --git a/release/release_logs/1.2.0/scalability/distributed.txt b/release/release_logs/1.2.0/scalability/distributed.txt index 860875201cea8..6dc5649e21565 100644 --- a/release/release_logs/1.2.0/scalability/distributed.txt +++ b/release/release_logs/1.2.0/scalability/distributed.txt @@ -1,4 +1,4 @@ Actor time: 34.21903751100001 (10000 actors) │ Task time: 386.82114117900005 (10000 tasks) │ PG time: 31.368525181999985 (1000 placement groups) │ -Node launch time: 756.3447095859999 (250 nodes) \ No newline at end of file +Node launch time: 756.3447095859999 (250 nodes) diff --git a/release/release_logs/1.2.0/scalability/single_node.txt b/release/release_logs/1.2.0/scalability/single_node.txt index 7a100e3eae987..f9516e8cb6364 100644 --- a/release/release_logs/1.2.0/scalability/single_node.txt +++ b/release/release_logs/1.2.0/scalability/single_node.txt @@ -2,4 +2,4 @@ Many args time: 11.433474627000002 (10000 args) Many returns time: 4.487700554 (3000 returns) Ray.get time: 21.957432587999996 (10000 args) Queued task time: 124.148238013 (1000000 tasks) -Ray.get large object time: 35.118229127000006 (107374182400 bytes) \ No newline at end of file +Ray.get large object time: 35.118229127000006 (107374182400 bytes) diff --git a/release/release_logs/1.2.0/stress_tests/test_many_tasks.txt b/release/release_logs/1.2.0/stress_tests/test_many_tasks.txt index ffc9bc3cd483a..905f3b718a633 100644 --- a/release/release_logs/1.2.0/stress_tests/test_many_tasks.txt +++ b/release/release_logs/1.2.0/stress_tests/test_many_tasks.txt @@ -14,4 +14,4 @@ Stage 3 results: Actor creation time: 0.3304018974304199 Total time: 2303.117142677307 Stage 4 results: - Scheduling spread: 66.90121385927009. \ No newline at end of file + Scheduling spread: 66.90121385927009. 
diff --git a/release/release_logs/1.2.0/stress_tests/test_placement_group.txt b/release/release_logs/1.2.0/stress_tests/test_placement_group.txt index 62f8a7b747867..5fd7d15001b53 100644 --- a/release/release_logs/1.2.0/stress_tests/test_placement_group.txt +++ b/release/release_logs/1.2.0/stress_tests/test_placement_group.txt @@ -1,3 +1,3 @@ Avg placement group creating time: 0.2691924729741867 ms Avg placement group removing time: 0.8786630945927776 ms -Stress Test succeed. \ No newline at end of file +Stress Test succeed. diff --git a/release/release_logs/1.3.0/microbenchmark.txt b/release/release_logs/1.3.0/microbenchmark.txt index 52d9b907bb5c3..ff97e2faf9f76 100644 --- a/release/release_logs/1.3.0/microbenchmark.txt +++ b/release/release_logs/1.3.0/microbenchmark.txt @@ -25,4 +25,4 @@ client: put calls per second 1275.03 +- 4.84 client: remote put calls per second 62818.07 +- 1390.38 client: 1:1 actor calls sync per second 701.76 +- 15.77 client: 1:1 actor calls async per second 760.75 +- 23.3 -client: 1:1 actor calls concurrent per second 758.73 +- 28.31 \ No newline at end of file +client: 1:1 actor calls concurrent per second 758.73 +- 28.31 diff --git a/release/release_logs/1.3.0/rllib_tests/regression_tests_tf.txt b/release/release_logs/1.3.0/rllib_tests/regression_tests_tf.txt index faccf0827e6e3..0a7331273b18e 100644 --- a/release/release_logs/1.3.0/rllib_tests/regression_tests_tf.txt +++ b/release/release_logs/1.3.0/rllib_tests/regression_tests_tf.txt @@ -14,4 +14,4 @@ Number of trials: 12/12 (12 TERMINATED) | PPO_BreakoutNoFrameskip-v4_14938_00009 | TERMINATED | | 1019 | 3600.79 | 5095000 | 32.37 | 371 | 9 | 2706.1 | | SAC_HalfCheetahBulletEnv-v0_14938_00010 | TERMINATED | | 47 | 3632.69 | 56000 | 477.215 | 714.128 | -93.7317 | 1000 | | SAC_HalfCheetahBulletEnv-v0_14938_00011 | TERMINATED | | 25 | 3614.29 | 34000 | 573.477 | 617.906 | 467.872 | 1000 | -+-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--- \ No newline at end of file ++-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--- diff --git a/release/release_logs/1.3.0/rllib_tests/regression_tests_torch.txt b/release/release_logs/1.3.0/rllib_tests/regression_tests_torch.txt index 9643038b19666..9652132890d8a 100644 --- a/release/release_logs/1.3.0/rllib_tests/regression_tests_torch.txt +++ b/release/release_logs/1.3.0/rllib_tests/regression_tests_torch.txt @@ -27,4 +27,4 @@ Number of trials: 12/12 (12 TERMINATED) +-------------------------------------------+------------+-------+--------+------------------+---------+----------+----------------------+----------------------+--------------------+ -2021-04-09 06:43:05,537 INFO tune.py:549 -- Total run time: 7289.99 seconds (7289.48 seconds for the tuning loop). \ No newline at end of file +2021-04-09 06:43:05,537 INFO tune.py:549 -- Total run time: 7289.99 seconds (7289.48 seconds for the tuning loop). 
diff --git a/release/release_logs/1.3.0/scalability/distributed.txt b/release/release_logs/1.3.0/scalability/distributed.txt index bdcd6ce5f52ed..bccd4d61483f7 100644 --- a/release/release_logs/1.3.0/scalability/distributed.txt +++ b/release/release_logs/1.3.0/scalability/distributed.txt @@ -1,4 +1,4 @@ Actor time: 37.22384708500087 (10000 actors) Task time: 403.85309777899885 (10000 tasks) PG time: 30.63472470399972 (1000 placement groups) -Node launch time: 651.0706797980001 (250 nodes) \ No newline at end of file +Node launch time: 651.0706797980001 (250 nodes) diff --git a/release/release_logs/1.3.0/scalability/single_node.txt b/release/release_logs/1.3.0/scalability/single_node.txt index b5d378699d983..63c6341abb22d 100644 --- a/release/release_logs/1.3.0/scalability/single_node.txt +++ b/release/release_logs/1.3.0/scalability/single_node.txt @@ -2,4 +2,4 @@ Many args time: 13.942214363000005 (10000 args) Many returns time: 5.577208736999978 (3000 returns) Ray.get time: 24.999562051999987 (10000 args) Queued task time: 159.31476044200002 (1000000 tasks) -Ray.get large object time: 34.683988763 (107374182400 bytes) \ No newline at end of file +Ray.get large object time: 34.683988763 (107374182400 bytes) diff --git a/release/release_logs/1.3.0/stress_tests/dead_actors.txt b/release/release_logs/1.3.0/stress_tests/dead_actors.txt index fc8d232aba2ca..d263d986009a8 100644 --- a/release/release_logs/1.3.0/stress_tests/dead_actors.txt +++ b/release/release_logs/1.3.0/stress_tests/dead_actors.txt @@ -2,4 +2,4 @@ Finished in: 62.34827518463135s Average iteration time: 0.6234804892539978s Max iteration time: 2.0002951622009277s Min iteration time: 0.09386014938354492s - \ No newline at end of file + diff --git a/release/release_logs/1.3.0/stress_tests/placement_groups.txt b/release/release_logs/1.3.0/stress_tests/placement_groups.txt index 46a314972c956..8eea87097662f 100644 --- a/release/release_logs/1.3.0/stress_tests/placement_groups.txt +++ b/release/release_logs/1.3.0/stress_tests/placement_groups.txt @@ -1,3 +1,3 @@ Avg placement group creating time: 0.768609554054404 ms Avg placement group removing time: 0.7591186996999492 ms -Stress Test succeed. \ No newline at end of file +Stress Test succeed. diff --git a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_bookkeeping_overhead.txt b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_bookkeeping_overhead.txt index 8d6cbee95d093..0a8d0302177a4 100644 --- a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_bookkeeping_overhead.txt +++ b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_bookkeeping_overhead.txt @@ -2,4 +2,4 @@ AssertionError: The bookkeeping overhead test took 866.63 seconds, but should no --- FAILED: BOOKKEEPING OVERHEAD ::: 866.63 > 800.00 --- -(signed off by Tune team) \ No newline at end of file +(signed off by Tune team) diff --git a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_durable_trainable.txt b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_durable_trainable.txt index 80c0581ae5d72..1f784a1d8f683 100644 --- a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_durable_trainable.txt +++ b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_durable_trainable.txt @@ -1,4 +1,4 @@ 2021-04-06 03:27:21,932 INFO tune.py:549 -- Total run time: 393.52 seconds (391.87 seconds for the tuning loop). The durable trainable test took 393.67 seconds, which is below the budget of 500.00 seconds. Test successful. 
---- PASSED: DURABLE TRAINABLE ::: 393.67 <= 500.00 --- \ No newline at end of file +--- PASSED: DURABLE TRAINABLE ::: 393.67 <= 500.00 --- diff --git a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_result_throughput_single_node.txt b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_result_throughput_single_node.txt index a97413498fc44..5d6ad517b6773 100644 --- a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_result_throughput_single_node.txt +++ b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_result_throughput_single_node.txt @@ -1,4 +1,4 @@ 2021-04-06 00:50:08,299 INFO tune.py:549 -- Total run time: 110.36 seconds (109.79 seconds for the tuning loop). The result throughput single node test took 110.42 seconds, which is below the budget of 120.00 seconds. Test successful. ---- PASSED: RESULT THROUGHPUT SINGLE NODE ::: 110.42 <= 120.00 --- \ No newline at end of file +--- PASSED: RESULT THROUGHPUT SINGLE NODE ::: 110.42 <= 120.00 --- diff --git a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_xgboost_sweep.txt b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_xgboost_sweep.txt index 623777e41bfd5..64fab2279c116 100644 --- a/release/release_logs/1.3.0/tune_tests/scalability_tests/test_xgboost_sweep.txt +++ b/release/release_logs/1.3.0/tune_tests/scalability_tests/test_xgboost_sweep.txt @@ -8,4 +8,4 @@ AssertionError: The large xgboost sweep test took 3187.07 seconds, but should no --- FAILED: LARGE XGBOOST SWEEP ::: 3187.07 <= 2600.00 --- -Note: Target retrospecitvely adjusted to 3500, so test passed. \ No newline at end of file +Note: Target retrospectively adjusted to 3500, so test passed. diff --git a/release/release_logs/1.5.0/microbenchmark.txt b/release/release_logs/1.5.0/microbenchmark.txt index 30889cd2f5498..704655260b37e 100644 --- a/release/release_logs/1.5.0/microbenchmark.txt +++ b/release/release_logs/1.5.0/microbenchmark.txt @@ -25,4 +25,4 @@ client: put calls per second 803.36 +- 9.76 client: remote put calls per second 49220.37 +- 331.69 client: 1:1 actor calls sync per second 478.94 +- 6.98 client: 1:1 actor calls async per second 507.42 +- 6.53 -client: 1:1 actor calls concurrent per second 510.95 +- 9.80 \ No newline at end of file +client: 1:1 actor calls concurrent per second 510.95 +- 9.80 diff --git a/release/release_logs/1.5.0/scalability/single_node.txt b/release/release_logs/1.5.0/scalability/single_node.txt index 0d2613ca214de..8a7cbb8429b2d 100644 --- a/release/release_logs/1.5.0/scalability/single_node.txt +++ b/release/release_logs/1.5.0/scalability/single_node.txt @@ -2,4 +2,4 @@ Many args time: 13.32423606399999 (10000 args) Many returns time: 5.455830246000005 (3000 returns) Ray.get time: 30.079316559000006 (10000 args) Queued task time: 154.73578189300002 (1000000 tasks) -Ray.get large object time: 257.75048660999994 (107374182400 bytes) \ No newline at end of file +Ray.get large object time: 257.75048660999994 (107374182400 bytes) diff --git a/release/release_logs/1.5.0/stress_tests/dead_actors.txt b/release/release_logs/1.5.0/stress_tests/dead_actors.txt index 9568e861a235c..4ba47b6cfe2fd 100644 --- a/release/release_logs/1.5.0/stress_tests/dead_actors.txt +++ b/release/release_logs/1.5.0/stress_tests/dead_actors.txt @@ -1,4 +1,4 @@ Finished in: 131.18541312217712s Average iteration time: 1.311851441860199s Max iteration time: 3.8255529403686523s -Min iteration time: 0.023152828216552734s \ No newline at end of file +Min iteration time: 0.023152828216552734s diff --git
a/release/release_logs/compare_perf_metrics b/release/release_logs/compare_perf_metrics index 360b1743570bb..5b1dc02d79304 100755 --- a/release/release_logs/compare_perf_metrics +++ b/release/release_logs/compare_perf_metrics @@ -131,7 +131,6 @@ def get_regressions(old_path, new_path): new_values.keys(), ) - regressions = [] throughput_regressions = [] latency_regression = [] for perf_metric_name in to_compare: diff --git a/release/requirements.txt b/release/requirements.txt index 87d281051731e..b78f157c77f9a 100644 --- a/release/requirements.txt +++ b/release/requirements.txt @@ -19,4 +19,4 @@ requests protobuf >= 3.15.3, != 3.19.5 pytz retry -kubernetes \ No newline at end of file +kubernetes diff --git a/release/requirements_buildkite.in b/release/requirements_buildkite.in index 0c20af4d90886..4773bb698d13b 100644 --- a/release/requirements_buildkite.in +++ b/release/requirements_buildkite.in @@ -22,3 +22,7 @@ tzdata aws_requests_auth requests >= 2.31.0 -r requirements-doc.txt + +# Upgrades +typing-extensions>=4.10 +jsonschema>=4.23.0 diff --git a/release/requirements_buildkite.txt b/release/requirements_buildkite.txt index 659db4c5e8db3..f93f6761a438c 100644 --- a/release/requirements_buildkite.txt +++ b/release/requirements_buildkite.txt @@ -810,10 +810,11 @@ jsonpointer==2.4 \ # via # jsonpatch # sphinx-jsonschema -jsonschema==4.22.0 \ - --hash=sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7 \ - --hash=sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802 +jsonschema==4.23.0 \ + --hash=sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4 \ + --hash=sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566 # via + # -r release/requirements_buildkite.in # anyscale # nbformat # sphinxcontrib-redoc @@ -1890,6 +1891,7 @@ typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a # via + # -r release/requirements_buildkite.in # aioitertools # anyio # anyscale diff --git a/release/serve_tests/compute_tpl_32_cpu_autoscaling_gce.yaml b/release/serve_tests/compute_tpl_32_cpu_autoscaling_gce.yaml index 18bf778e2cfc5..fedbe966e49e7 100644 --- a/release/serve_tests/compute_tpl_32_cpu_autoscaling_gce.yaml +++ b/release/serve_tests/compute_tpl_32_cpu_autoscaling_gce.yaml @@ -14,4 +14,4 @@ worker_node_types: instance_type: n2-standard-32 # m5.8xlarge min_workers: 5 max_workers: 35 - use_spot: false \ No newline at end of file + use_spot: false diff --git a/release/serve_tests/compute_tpl_32_cpu_gce.yaml b/release/serve_tests/compute_tpl_32_cpu_gce.yaml index 26a2474f3af33..31e3e8e438bdf 100644 --- a/release/serve_tests/compute_tpl_32_cpu_gce.yaml +++ b/release/serve_tests/compute_tpl_32_cpu_gce.yaml @@ -14,4 +14,4 @@ worker_node_types: instance_type: n2-standard-32 # m5.8xlarge min_workers: 32 max_workers: 32 - use_spot: false \ No newline at end of file + use_spot: false diff --git a/release/serve_tests/compute_tpl_8_cpu_autoscaling_gce.yaml b/release/serve_tests/compute_tpl_8_cpu_autoscaling_gce.yaml index a34796ff13908..550e15071fe0e 100644 --- a/release/serve_tests/compute_tpl_8_cpu_autoscaling_gce.yaml +++ b/release/serve_tests/compute_tpl_8_cpu_autoscaling_gce.yaml @@ -20,4 +20,4 @@ worker_node_types: use_spot: false resources: custom_resources: - proxy: 1 \ No newline at end of file + proxy: 1 diff --git a/release/serve_tests/compute_tpl_gpu_node_gce.yaml 
b/release/serve_tests/compute_tpl_gpu_node_gce.yaml index 9341c01b88992..98602cba49b50 100644 --- a/release/serve_tests/compute_tpl_gpu_node_gce.yaml +++ b/release/serve_tests/compute_tpl_gpu_node_gce.yaml @@ -14,4 +14,4 @@ worker_node_types: instance_type: n2-standard-16 # m5.4xlarge min_workers: 0 max_workers: 1 - use_spot: false \ No newline at end of file + use_spot: false diff --git a/release/serve_tests/compute_tpl_single_node_32_cpu_gce.yaml b/release/serve_tests/compute_tpl_single_node_32_cpu_gce.yaml index c8efe5ca98413..c36e290c19194 100644 --- a/release/serve_tests/compute_tpl_single_node_32_cpu_gce.yaml +++ b/release/serve_tests/compute_tpl_single_node_32_cpu_gce.yaml @@ -9,4 +9,4 @@ head_node_type: name: head_node instance_type: n2-standard-32 # m5.8xlarge -worker_node_types: [] \ No newline at end of file +worker_node_types: [] diff --git a/release/serve_tests/workloads/imagenet_classes.txt b/release/serve_tests/workloads/imagenet_classes.txt index 888d6f51dd77b..f40829ed0fc31 100644 --- a/release/serve_tests/workloads/imagenet_classes.txt +++ b/release/serve_tests/workloads/imagenet_classes.txt @@ -997,4 +997,4 @@ earthstar hen-of-the-woods bolete ear -toilet tissue \ No newline at end of file +toilet tissue diff --git a/release/tune_tests/fault_tolerance_tests/workloads/terminate_node_aws.py b/release/tune_tests/fault_tolerance_tests/workloads/terminate_node_aws.py index 464bfb0602656..e87253c7ba6a0 100644 --- a/release/tune_tests/fault_tolerance_tests/workloads/terminate_node_aws.py +++ b/release/tune_tests/fault_tolerance_tests/workloads/terminate_node_aws.py @@ -6,7 +6,8 @@ import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -from ray._private.test_utils import safe_write_to_results_json + +# from ray._private.test_utils import safe_write_to_results_json logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -127,7 +128,7 @@ def kill(self): "terminated_succesfully": terminated_succesfully, } ) - safe_write_to_results_json(self.history) + # safe_write_to_results_json(self.history) def create_instance_killer( diff --git a/rllib/BUILD b/rllib/BUILD index fb98e1fe0d323..556ca6a0de513 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -325,6 +325,18 @@ py_test( ], args = ["--as-test", "--enable-new-api-stack"] ) +py_test( + name = "learning_tests_cartpole_bc_gpu", + main = "tuned_examples/bc/cartpole_bc.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], + size = "medium", + srcs = ["tuned_examples/bc/cartpole_bc.py"], + # Include the offline data files. + data = [ + "tests/data/cartpole/cartpole-v1_large", + ], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus-per-learner=1"] +) # CQL # Pendulum @@ -340,6 +352,19 @@ py_test( ], args = ["--as-test", "--enable-new-api-stack"] ) +# GPU training. +py_test( + name = "learning_tests_pendulum_cql_gpu", + main = "tuned_examples/cql/pendulum_cql.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], + size = "large", + srcs = ["tuned_examples/cql/pendulum_cql.py"], + # Include the zipped json data file as well. 
+ data = [ + "tests/data/pendulum/pendulum-v1_enormous", + ], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus-per-learner=1"] +) # DQN # CartPole @@ -525,6 +550,19 @@ py_test( ], args = ["--as-test", "--enable-new-api-stack"] ) +# GPU-training. +py_test( + name = "learning_tests_cartpole_marwil_gpu", + main = "tuned_examples/marwil/cartpole_marwil.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], + size = "large", + srcs = ["tuned_examples/marwil/cartpole_marwil.py"], + # Include the offline data files. + data = [ + "tests/data/cartpole/cartpole-v1_large", + ], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus-per-learner=1"] +) # PPO # CartPole @@ -923,15 +961,6 @@ py_test( data = ["tests/data/cartpole/cartpole-v1_large"], srcs = ["algorithms/bc/tests/test_bc.py"] ) -# @OldAPIStack -py_test( - name = "test_bc_old_api_stack", - tags = ["team:rllib", "algorithms_dir"], - size = "medium", - # Include the json data file. - data = ["tests/data/cartpole/large.json"], - srcs = ["algorithms/bc/tests/test_bc_old_api_stack.py"] -) # CQL # @OldAPIStack @@ -1002,19 +1031,6 @@ py_test( ], srcs = ["algorithms/marwil/tests/test_marwil_rl_module.py"] ) -# @OldAPIStack -py_test( - name = "test_marwil_old_api_stack", - tags = ["team:rllib", "algorithms_dir"], - size = "large", - # Include the json data file. - data = [ - "tests/data/cartpole/large.json", - "tests/data/pendulum/large.json", - "tests/data/cartpole/small.json", - ], - srcs = ["algorithms/marwil/tests/test_marwil_old_api_stack.py"] -) # PPO py_test( @@ -1035,13 +1051,6 @@ py_test( size = "large", srcs = ["algorithms/ppo/tests/test_ppo_learner.py"] ) -# @OldAPIStack -py_test( - name = "test_ppo_old_api_stack", - tags = ["team:rllib", "algorithms_dir"], - size = "large", - srcs = ["algorithms/ppo/tests/test_ppo_old_api_stack.py"] -) # SAC py_test( @@ -1051,34 +1060,6 @@ py_test( srcs = ["algorithms/sac/tests/test_sac.py"] ) -# -------------------------------------------------------------------- -# Connector(V1) tests -# rllib/connector/ -# -# Tag: connector -# -------------------------------------------------------------------- - -py_test( - name = "connectors/tests/test_connector", - tags = ["team:rllib", "connector"], - size = "small", - srcs = ["connectors/tests/test_connector.py"] -) - -py_test( - name = "connectors/tests/test_action", - tags = ["team:rllib", "connector"], - size = "small", - srcs = ["connectors/tests/test_action.py"] -) - -py_test( - name = "connectors/tests/test_agent", - tags = ["team:rllib", "connector"], - size = "medium", - srcs = ["connectors/tests/test_agent.py"] -) - # -------------------------------------------------------------------- # ConnectorV2 tests # rllib/connector/ @@ -1143,13 +1124,6 @@ py_test( srcs = ["env/tests/test_single_agent_episode.py"] ) -py_test( - name = "env/wrappers/tests/test_exception_wrapper", - tags = ["team:rllib", "env"], - size = "small", - srcs = ["env/wrappers/tests/test_exception_wrapper.py"] -) - py_test( name = "env/wrappers/tests/test_group_agents_wrapper", tags = ["team:rllib", "env"], @@ -1374,28 +1348,12 @@ py_test( # Tag: models # -------------------------------------------------------------------- -py_test( - name = "test_conv2d_default_stacks", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["models/tests/test_conv2d_default_stacks.py"] -) - -py_test( - name = "test_convtranspose2d_stack", - tags = ["team:rllib", "models"], - size = 
"medium", - data = glob(["tests/data/images/obstacle_tower.png"]), - srcs = ["models/tests/test_convtranspose2d_stack.py"] -) - py_test( name = "test_action_distributions", tags = ["team:rllib", "models"], size = "medium", srcs = ["models/tests/test_action_distributions.py"] ) - py_test( name = "test_distributions", tags = ["team:rllib", "models"], @@ -1403,21 +1361,6 @@ py_test( srcs = ["models/tests/test_distributions.py"] ) -py_test( - name = "test_lstms", - tags = ["team:rllib", "models"], - size = "large", - srcs = ["models/tests/test_lstms.py"] -) - -py_test( - name = "test_preprocessors", - tags = ["team:rllib", "models"], - size = "medium", - srcs = ["models/tests/test_preprocessors.py"] -) - - # -------------------------------------------------------------------- # Offline # rllib/offline/ @@ -1824,45 +1767,6 @@ py_test( srcs = ["tests/test_dependency_torch.py"] ) -py_test( - name = "tests/test_eager_support_policy_gradient", - main = "tests/test_eager_support.py", - tags = ["team:rllib", "tests_dir"], - size = "small", - srcs = ["tests/test_eager_support.py"], - args = ["TestEagerSupportPolicyGradient"] -) - -py_test( - name = "tests/test_eager_support_off_policy", - main = "tests/test_eager_support.py", - tags = ["team:rllib", "tests_dir"], - size = "small", - srcs = ["tests/test_eager_support.py"], - args = ["TestEagerSupportOffPolicy"] -) - -py_test( - name = "tests/test_filters", - tags = ["team:rllib", "tests_dir"], - size = "small", - srcs = ["tests/test_filters.py"] -) - -py_test( - name = "tests/test_gpus", - tags = ["team:rllib", "tests_dir"], - size = "large", - srcs = ["tests/test_gpus.py"] -) - -py_test( - name = "tests/test_io", - tags = ["team:rllib", "tests_dir"], - size = "large", - srcs = ["tests/test_io.py"] -) - py_test( name = "tests/test_local", tags = ["team:rllib", "tests_dir"], @@ -1898,13 +1802,6 @@ py_test( srcs = ["tests/test_placement_groups.py"] ) -py_test( - name = "tests/test_reproducibility", - tags = ["team:rllib", "tests_dir"], - size = "medium", - srcs = ["tests/test_reproducibility.py"] -) - py_test( name = "tests/test_timesteps", tags = ["team:rllib", "tests_dir"], @@ -2939,26 +2836,6 @@ py_test( args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--stop-reward-pretraining=250.0", "--stop-reward=250.0", "--stop-iters=3"], ) -#@OldAPIStack -py_test( - name = "examples/centralized_critic_tf", - main = "examples/centralized_critic.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/centralized_critic.py"], - args = ["--as-test", "--framework=tf", "--stop-reward=7.2"] -) - -#@OldAPIStack -py_test( - name = "examples/centralized_critic_torch", - main = "examples/centralized_critic.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/centralized_critic.py"], - args = ["--as-test", "--framework=torch", "--stop-reward=7.2"] -) - py_test( name = "examples/replay_buffer_api", tags = ["team:rllib", "examples"], diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 58012e4c077bf..9bebd2d7f712e 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -192,27 +192,32 @@ class AlgorithmBase: @staticmethod def _get_learner_bundles( - cf: AlgorithmConfig, + config: AlgorithmConfig, ) -> List[Dict[str, Union[float, int]]]: """Selects the right resource bundles for learner workers based off of cf. Args: - cf: The AlgorithmConfig instance to extract bundle-information from. 
+ config: The AlgorithmConfig instance to extract bundle-information from. Returns: A list of resource bundles for the learner workers. """ - assert cf.num_learners > 0 + _num = config.num_learners + assert _num > 0 + + num_cpus_per_learner = ( + config.num_cpus_per_learner + if config.num_cpus_per_learner != "auto" + else 1 + if config.num_gpus_per_learner == 0 + else 0 + ) - _num = cf.num_learners all_learners = [ { "CPU": _num - * ( - (cf.num_cpus_per_learner if cf.num_gpus_per_learner == 0 else 0) - + cf.num_aggregator_actors_per_learner - ), - "GPU": _num * max(0, cf.num_gpus_per_learner), + * (num_cpus_per_learner + config.num_aggregator_actors_per_learner), + "GPU": _num * config.num_gpus_per_learner, } ] @@ -235,8 +240,10 @@ class Algorithm(Checkpointable, Trainable, AlgorithmBase): their respective [algo name].py files, for example: `ray.rllib.algorithms.dqn.dqn.py` or `ray.rllib.algorithms.impala.impala.py`. - The most important API methods a Algorithm exposes are `train()`, - `evaluate()`, `save_to_path()` and `restore_from_path()`. + The most important API methods an Algorithm exposes are `train()` for running a + single training iteration, `evaluate()` for running a single round of evaluation, + `save_to_path()` for creating a checkpoint, and `restore_from_path()` for loading a + state from an existing checkpoint. """ #: The AlgorithmConfig instance of the Algorithm. @@ -314,7 +321,7 @@ class Algorithm(Checkpointable, Trainable, AlgorithmBase): @override(Checkpointable) def from_checkpoint( cls, - path: Optional[Union[str, Checkpoint]] = None, + path: Union[str, Checkpoint], filesystem: Optional["pyarrow.fs.FileSystem"] = None, *, # @OldAPIStack @@ -333,20 +340,19 @@ def from_checkpoint( """Creates a new algorithm instance from a given checkpoint. Args: - path: The path (str) to the checkpoint directory to use - or an AIR Checkpoint instance to restore from. + path: The path (str) to the checkpoint directory to use or a Ray Train + Checkpoint instance to restore from. filesystem: PyArrow FileSystem to use to access data at the `path`. If not specified, this is inferred from the URI scheme of `path`. policy_ids: Optional list of PolicyIDs to recover. This allows users to restore an Algorithm with only a subset of the originally present Policies. - policy_mapping_fn: An optional (updated) policy mapping function - to use from here on. - policies_to_train: An optional list of policy IDs to be trained - or a callable taking PolicyID and SampleBatchType and - returning a bool (trainable or not?). - If None, will keep the existing setup in place. Policies, - whose IDs are not in the list (or for which the callable + policy_mapping_fn: An optional (updated) policy mapping function to use from + here on. + policies_to_train: An optional list of policy IDs to be trained or a + callable taking PolicyID and SampleBatchType and returning a bool + (trainable or not?). If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the callable returns False) will not be updated. Returns: @@ -356,19 +362,20 @@ def from_checkpoint( deprecation_warning( old="Algorithm.from_checkpoint(checkpoint=...)", new="Algorithm.from_checkpoint(path=...)", - error=False, - ) - path = checkpoint - if path is None: - raise ValueError( - "`path` not provided in call to Algorithm.from_checkpoint()!" + error=True, ) - checkpoint_info = get_checkpoint_info(path) + # New API stack -> Use Checkpointable's default implementation. 
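A quick note on the `_get_learner_bundles()` rewrite above: the new "auto" value only changes how the per-Learner CPU share is resolved; the bundle arithmetic is essentially unchanged. A minimal, self-contained sketch of that arithmetic (function name and flat-argument signature are illustrative; the real code reads these values off an AlgorithmConfig):

from typing import Dict, List, Union


def learner_bundles(
    num_learners: int,
    num_cpus_per_learner: Union[str, int, float],
    num_gpus_per_learner: Union[int, float],
    num_aggregator_actors_per_learner: int = 0,
) -> List[Dict[str, Union[int, float]]]:
    # "auto" resolves to 1 CPU for CPU-only Learners and to 0 CPUs once
    # each Learner owns a GPU (the GPU bundle then carries the worker).
    cpus = (
        num_cpus_per_learner
        if num_cpus_per_learner != "auto"
        else (1 if num_gpus_per_learner == 0 else 0)
    )
    return [
        {
            "CPU": num_learners * (cpus + num_aggregator_actors_per_learner),
            "GPU": num_learners * num_gpus_per_learner,
        }
    ]


# Two GPU Learners with one aggregator actor each -> [{"CPU": 2, "GPU": 2}].
print(learner_bundles(2, "auto", 1, num_aggregator_actors_per_learner=1))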
+ if checkpoint_info["checkpoint_version"] >= version.Version("2.0"): + # `path` is a Checkpoint instance: Translate to directory and continue. + if isinstance(path, Checkpoint): + path = path.to_directory() + return super().from_checkpoint(path, filesystem=filesystem, **kwargs) + # Not possible for (v0.1) (algo class and config information missing # or very hard to retrieve). - if checkpoint_info["checkpoint_version"] == version.Version("0.1"): + elif checkpoint_info["checkpoint_version"] == version.Version("0.1"): raise ValueError( "Cannot restore a v0 checkpoint using `Algorithm.from_checkpoint()`!" "In this case, do the following:\n" @@ -382,9 +389,6 @@ def from_checkpoint( "()` must be 1.0 or later! You are using a checkpoint with " f"version v{checkpoint_info['checkpoint_version']}." ) - # New API stack -> Use Checkpointable's default implementation. - elif checkpoint_info["checkpoint_version"] >= version.Version("2.0"): - return super().from_checkpoint(path, filesystem=filesystem, **kwargs) # This is a msgpack checkpoint. if checkpoint_info["format"] == "msgpack": @@ -418,40 +422,7 @@ def from_checkpoint( return Algorithm.from_state(state) - @OldAPIStack - @staticmethod - def from_state(state: Dict) -> "Algorithm": - """Recovers an Algorithm from a state object. - - The `state` of an instantiated Algorithm can be retrieved by calling its - `get_state` method. It contains all information necessary - to create the Algorithm from scratch. No access to the original code (e.g. - configs, knowledge of the Algorithm's class, etc..) is needed. - - Args: - state: The state to recover a new Algorithm instance from. - - Returns: - A new Algorithm instance. - """ - algorithm_class: Type[Algorithm] = state.get("algorithm_class") - if algorithm_class is None: - raise ValueError( - "No `algorithm_class` key was found in given `state`! " - "Cannot create new Algorithm." - ) - # algo_class = get_trainable_cls(algo_class_name) - # Create the new algo. - config = state.get("config") - if not config: - raise ValueError("No `config` found in given Algorithm state!") - new_algo = algorithm_class(config=config) - # Set the new algo's state. - new_algo.__setstate__(state) - - # Return the new algo. - return new_algo - + @PublicAPI def __init__( self, config: Optional[AlgorithmConfig] = None, @@ -865,13 +836,15 @@ def setup(self, config: AlgorithmConfig) -> None: # Provide the actor handles for the learners for module # updating during preprocessing. self.offline_data.learner_handles = self.learner_group._workers - # Provide the module_spec. Note, in the remote case this is needed - # because the learner module cannot be copied, but must be built. - self.offline_data.module_spec = module_spec # Otherwise we can simply pass in the local learner. else: self.offline_data.learner_handles = [self.learner_group._learner] - + # TODO (simon, sven): Replace these set-some-object's-attributes- + # directly? We should find some solution for this in the future, an API, + # or setting these in the OfflineData constructor? + # Provide the module_spec. Note, in the remote case this is needed + # because the learner module cannot be copied, but must be built. + self.offline_data.module_spec = module_spec # Provide the `OfflineData` instance with space information. It might # need it for reading recorded experiences. 
self.offline_data.spaces = spaces @@ -887,7 +860,8 @@ def setup(self, config: AlgorithmConfig) -> None: ) agg_cls = ray.remote( num_cpus=1, - num_gpus=0.01 if self.config.num_gpus_per_learner > 0 else 0, + # TODO (sven): Activate this when Ray has figured out GPU pre-loading. + # num_gpus=0.01 if self.config.num_gpus_per_learner > 0 else 0, max_restarts=-1, )(AggregatorActor) self._aggregator_actor_manager = FaultTolerantActorManager( @@ -903,29 +877,29 @@ def setup(self, config: AlgorithmConfig) -> None: ), ) # Get the devices of each learner. - learner_locations = [ - (i, loc) - for i, loc in enumerate( + learner_locations = list( + enumerate( self.learner_group.foreach_learner( func=lambda _learner: (_learner.node, _learner.device), ) ) - ] + ) # Get the devices of each AggregatorActor. - aggregator_locations = [ - (i, loc) - for i, loc in enumerate( + aggregator_locations = list( + enumerate( self._aggregator_actor_manager.foreach_actor( func=lambda actor: (actor._node, actor._device) ) ) - ] + ) self._aggregator_actor_to_learner = {} for agg_idx, aggregator_location in aggregator_locations: + aggregator_location = aggregator_location.get() for learner_idx, learner_location in learner_locations: - if learner_location.get() == aggregator_location.get(): - # Round-robin, in case all Learners are on same device (e.g. for - # CPU learners). + # TODO (sven): Activate full comparison (including device) when Ray + # has figured out GPU pre-loading. + if learner_location.get()[0] == aggregator_location[0]: + # Round-robin, in case all Learners are on same device/node. learner_locations = learner_locations[1:] + [ learner_locations[0] ] @@ -2168,321 +2142,6 @@ def set_weights(self, weights: Dict[PolicyID, dict]): ) self.env_runner_group.local_env_runner.set_weights(weights) - @OldAPIStack - def compute_single_action( - self, - observation: Optional[TensorStructType] = None, - state: Optional[List[TensorStructType]] = None, - *, - prev_action: Optional[TensorStructType] = None, - prev_reward: Optional[float] = None, - info: Optional[EnvInfoDict] = None, - input_dict: Optional[SampleBatch] = None, - policy_id: PolicyID = DEFAULT_POLICY_ID, - full_fetch: bool = False, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - episode=None, - unsquash_action: Optional[bool] = None, - clip_action: Optional[bool] = None, - # Kwargs placeholder for future compatibility. - **kwargs, - ) -> Union[ - TensorStructType, - Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]], - ]: - """Computes an action for the specified policy on the local worker. - - Note that you can also access the policy object through - self.get_policy(policy_id) and call compute_single_action() on it - directly. - - Args: - observation: Single (unbatched) observation from the - environment. - state: List of all RNN hidden (single, unbatched) state tensors. - prev_action: Single (unbatched) previous action value. - prev_reward: Single (unbatched) previous reward value. - info: Env info dict, if any. - input_dict: An optional SampleBatch that holds all the values - for: obs, state, prev_action, and prev_reward, plus maybe - custom defined views of the current env trajectory. Note - that only one of `obs` or `input_dict` must be non-None. - policy_id: Policy to query (only applies to multi-agent). - Default: "default_policy". - full_fetch: Whether to return extra action fetch results. - This is always set to True if `state` is specified. - explore: Whether to apply exploration to the action. 
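A note on the aggregator/Learner co-location matching in the `setup()` hunk above: for now it compares node IDs only; the full (node, device) comparison is TODO-gated on GPU pre-loading support in Ray. A rough pure-Python sketch of the round-robin, same-node assignment, with hypothetical plain node-ID lists standing in for the actor handles and their remote-call results:

def assign_aggregators(aggregator_nodes, learner_nodes):
    learners = list(enumerate(learner_nodes))
    mapping = {}
    for agg_idx, agg_node in enumerate(aggregator_nodes):
        for learner_idx, learner_node in learners:
            if learner_node == agg_node:
                # Rotate, so several aggregators on one node spread
                # round-robin across that node's Learners.
                learners = learners[1:] + [learners[0]]
                mapping[agg_idx] = learner_idx
                break
    return mapping


# Four aggregators and two Learners, all on node "n0" -> {0: 0, 1: 1, 2: 0, 3: 1}.
print(assign_aggregators(["n0"] * 4, ["n0", "n0"]))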
- Default: None -> use self.config.explore. - timestep: The current (sampling) time step. - episode: This provides access to all of the internal episodes' - state, which may be useful for model-based or multi-agent - algorithms. - unsquash_action: Should actions be unsquashed according to the - env's/Policy's action space? If None, use the value of - self.config.normalize_actions. - clip_action: Should actions be clipped according to the - env's/Policy's action space? If None, use the value of - self.config.clip_actions. - - Keyword Args: - kwargs: forward compatibility placeholder - - Returns: - The computed action if full_fetch=False, or a tuple of a) the - full output of policy.compute_actions() if full_fetch=True - or we have an RNN-based Policy. - - Raises: - KeyError: If the `policy_id` cannot be found in this Algorithm's local - worker. - """ - # `unsquash_action` is None: Use value of config['normalize_actions']. - if unsquash_action is None: - unsquash_action = self.config.normalize_actions - # `clip_action` is None: Use value of config['clip_actions']. - elif clip_action is None: - clip_action = self.config.clip_actions - - # User provided an input-dict: Assert that `obs`, `prev_a|r`, `state` - # are all None. - err_msg = ( - "Provide either `input_dict` OR [`observation`, ...] as " - "args to `Algorithm.compute_single_action()`!" - ) - if input_dict is not None: - assert ( - observation is None - and prev_action is None - and prev_reward is None - and state is None - ), err_msg - observation = input_dict[Columns.OBS] - else: - assert observation is not None, err_msg - - # Get the policy to compute the action for (in the multi-agent case, - # Algorithm may hold >1 policies). - policy = self.get_policy(policy_id) - if policy is None: - raise KeyError( - f"PolicyID '{policy_id}' not found in PolicyMap of the " - f"Algorithm's local worker!" - ) - # Just preprocess observations, similar to how it used to be done before. - pp = policy.agent_connectors[ObsPreprocessorConnector] - - # convert the observation to array if possible - if not isinstance(observation, (np.ndarray, dict, tuple)): - try: - observation = np.asarray(observation) - except Exception: - raise ValueError( - f"Observation type {type(observation)} cannot be converted to " - f"np.ndarray." - ) - if pp: - assert len(pp) == 1, "Only one preprocessor should be in the pipeline" - pp = pp[0] - - if not pp.is_identity(): - # Note(Kourosh): This call will leave the policy's connector - # in eval mode. would that be a problem? - pp.in_eval() - if observation is not None: - _input_dict = {Columns.OBS: observation} - elif input_dict is not None: - _input_dict = {Columns.OBS: input_dict[Columns.OBS]} - else: - raise ValueError( - "Either observation or input_dict must be provided." - ) - - # TODO (Kourosh): Create a new util method for algorithm that - # computes actions based on raw inputs from env and can keep track - # of its own internal state. - acd = AgentConnectorDataType("0", "0", _input_dict) - # make sure the state is reset since we are only applying the - # preprocessor - pp.reset(env_id="0") - ac_o = pp([acd])[0] - observation = ac_o.data[Columns.OBS] - - # Input-dict. - if input_dict is not None: - input_dict[Columns.OBS] = observation - action, state, extra = policy.compute_single_action( - input_dict=input_dict, - explore=explore, - timestep=timestep, - episode=episode, - ) - # Individual args. 
- else: - action, state, extra = policy.compute_single_action( - obs=observation, - state=state, - prev_action=prev_action, - prev_reward=prev_reward, - info=info, - explore=explore, - timestep=timestep, - episode=episode, - ) - - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_action: - action = space_utils.unsquash_action(action, policy.action_space_struct) - # Clip, according to env's action space. - elif clip_action: - action = space_utils.clip_action(action, policy.action_space_struct) - - # Return 3-Tuple: Action, states, and extra-action fetches. - if state or full_fetch: - return action, state, extra - # Ensure backward compatibility. - else: - return action - - @OldAPIStack - def compute_actions( - self, - observations: TensorStructType, - state: Optional[List[TensorStructType]] = None, - *, - prev_action: Optional[TensorStructType] = None, - prev_reward: Optional[TensorStructType] = None, - info: Optional[EnvInfoDict] = None, - policy_id: PolicyID = DEFAULT_POLICY_ID, - full_fetch: bool = False, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - episodes=None, - unsquash_actions: Optional[bool] = None, - clip_actions: Optional[bool] = None, - **kwargs, - ): - """Computes an action for the specified policy on the local Worker. - - Note that you can also access the policy object through - self.get_policy(policy_id) and call compute_actions() on it directly. - - Args: - observation: Observation from the environment. - state: RNN hidden state, if any. If state is not None, - then all of compute_single_action(...) is returned - (computed action, rnn state(s), logits dictionary). - Otherwise compute_single_action(...)[0] is returned - (computed action). - prev_action: Previous action value, if any. - prev_reward: Previous reward, if any. - info: Env info dict, if any. - policy_id: Policy to query (only applies to multi-agent). - full_fetch: Whether to return extra action fetch results. - This is always set to True if RNN state is specified. - explore: Whether to pick an exploitation or exploration - action (default: None -> use self.config.explore). - timestep: The current (sampling) time step. - episodes: This provides access to all of the internal episodes' - state, which may be useful for model-based or multi-agent - algorithms. - unsquash_actions: Should actions be unsquashed according - to the env's/Policy's action space? If None, use - self.config.normalize_actions. - clip_actions: Should actions be clipped according to the - env's/Policy's action space? If None, use - self.config.clip_actions. - - Keyword Args: - kwargs: forward compatibility placeholder - - Returns: - The computed action if full_fetch=False, or a tuple consisting of - the full output of policy.compute_actions_from_input_dict() if - full_fetch=True or we have an RNN-based Policy. - """ - # `unsquash_actions` is None: Use value of config['normalize_actions']. - if unsquash_actions is None: - unsquash_actions = self.config.normalize_actions - # `clip_actions` is None: Use value of config['clip_actions']. - elif clip_actions is None: - clip_actions = self.config.clip_actions - - # Preprocess obs and states. 
- state_defined = state is not None - policy = self.get_policy(policy_id) - filtered_obs, filtered_state = [], [] - for agent_id, ob in observations.items(): - worker = self.env_runner_group.local_env_runner - if worker.preprocessors.get(policy_id) is not None: - preprocessed = worker.preprocessors[policy_id].transform(ob) - else: - preprocessed = ob - filtered = worker.filters[policy_id](preprocessed, update=False) - filtered_obs.append(filtered) - if state is None: - continue - elif agent_id in state: - filtered_state.append(state[agent_id]) - else: - filtered_state.append(policy.get_initial_state()) - - # Batch obs and states - obs_batch = np.stack(filtered_obs) - if state is None: - state = [] - else: - state = list(zip(*filtered_state)) - state = [np.stack(s) for s in state] - - input_dict = {Columns.OBS: obs_batch} - - # prev_action and prev_reward can be None, np.ndarray, or tensor-like structure. - # Explicitly check for None here to avoid the error message "The truth value of - # an array with more than one element is ambiguous.", when np arrays are passed - # as arguments. - if prev_action is not None: - input_dict[SampleBatch.PREV_ACTIONS] = prev_action - if prev_reward is not None: - input_dict[SampleBatch.PREV_REWARDS] = prev_reward - if info: - input_dict[Columns.INFOS] = info - for i, s in enumerate(state): - input_dict[f"state_in_{i}"] = s - - # Batch compute actions - actions, states, infos = policy.compute_actions_from_input_dict( - input_dict=input_dict, - explore=explore, - timestep=timestep, - episodes=episodes, - ) - - # Unbatch actions for the environment into a multi-agent dict. - single_actions = space_utils.unbatch(actions) - actions = {} - for key, a in zip(observations, single_actions): - # If we work in normalized action space (normalize_actions=True), - # we re-translate here into the env's action space. - if unsquash_actions: - a = space_utils.unsquash_action(a, policy.action_space_struct) - # Clip, according to env's action space. - elif clip_actions: - a = space_utils.clip_action(a, policy.action_space_struct) - actions[key] = a - - # Unbatch states into a multi-agent dict. - unbatched_states = {} - for idx, agent_id in enumerate(observations): - unbatched_states[agent_id] = [s[idx] for s in states] - - # Return only actions or full tuple - if state_defined or full_fetch: - return actions, unbatched_states, infos - else: - return actions - @OldAPIStack def add_policy( self, @@ -2675,6 +2334,40 @@ def fn(worker): if remove_from_eval_env_runners and self.eval_env_runner_group is not None: self.eval_env_runner_group.foreach_env_runner(fn, local_env_runner=True) + @OldAPIStack + @staticmethod + def from_state(state: Dict) -> "Algorithm": + """Recovers an Algorithm from a state object. + + The `state` of an instantiated Algorithm can be retrieved by calling its + `get_state` method. It contains all information necessary + to create the Algorithm from scratch. No access to the original code (e.g. + configs, knowledge of the Algorithm's class, etc..) is needed. + + Args: + state: The state to recover a new Algorithm instance from. + + Returns: + A new Algorithm instance. + """ + algorithm_class: Type[Algorithm] = state.get("algorithm_class") + if algorithm_class is None: + raise ValueError( + "No `algorithm_class` key was found in given `state`! " + "Cannot create new Algorithm." + ) + # algo_class = get_trainable_cls(algo_class_name) + # Create the new algo. 
+ config = state.get("config") + if not config: + raise ValueError("No `config` found in given Algorithm state!") + new_algo = algorithm_class(config=config) + # Set the new algo's state. + new_algo.__setstate__(state) + + # Return the new algo. + return new_algo + @OldAPIStack def export_policy_model( self, @@ -3048,10 +2741,17 @@ def default_resource_request( if cf.enable_rl_module_and_learner: # Training is done on local Learner. if cf.num_learners == 0: + num_cpus_per_learner = ( + cf.num_cpus_per_learner + if cf.num_cpus_per_learner != "auto" + else 1 + if cf.num_gpus_per_learner == 0 + else 0 + ) driver = { # Sampling and training is not done concurrently when local is # used, so pick the max. - "CPU": max(cf.num_cpus_per_learner, cf.num_cpus_for_main_process), + "CPU": max(num_cpus_per_learner, cf.num_cpus_for_main_process), "GPU": cf.num_gpus_per_learner, } # Training is done on n remote Learners. @@ -4134,6 +3834,208 @@ def _compile_iteration_results_old_api_stack( return results + @OldAPIStack + @Deprecated( + help="`Algorithm.compute_single_action` should no longer be used. Get the " + "RLModule instance through `Algorithm.get_module([module ID])`, then compute " + "actions through `RLModule.forward_inference({'obs': [obs batch]})`.", + error=False, + ) + def compute_single_action( + self, + observation: Optional[TensorStructType] = None, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[float] = None, + info: Optional[EnvInfoDict] = None, + input_dict: Optional[SampleBatch] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episode=None, + unsquash_action: Optional[bool] = None, + clip_action: Optional[bool] = None, + ) -> Union[ + TensorStructType, + Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]], + ]: + if unsquash_action is None: + unsquash_action = self.config.normalize_actions + elif clip_action is None: + clip_action = self.config.clip_actions + + err_msg = ( + "Provide either `input_dict` OR [`observation`, ...] as " + "args to `Algorithm.compute_single_action()`!" + ) + if input_dict is not None: + assert ( + observation is None + and prev_action is None + and prev_reward is None + and state is None + ), err_msg + observation = input_dict[Columns.OBS] + else: + assert observation is not None, err_msg + + policy = self.get_policy(policy_id) + if policy is None: + raise KeyError( + f"PolicyID '{policy_id}' not found in PolicyMap of the " + f"Algorithm's local worker!" + ) + pp = policy.agent_connectors[ObsPreprocessorConnector] + + if not isinstance(observation, (np.ndarray, dict, tuple)): + try: + observation = np.asarray(observation) + except Exception: + raise ValueError( + f"Observation type {type(observation)} cannot be converted to " + f"np.ndarray." + ) + if pp: + assert len(pp) == 1, "Only one preprocessor should be in the pipeline" + pp = pp[0] + + if not pp.is_identity(): + pp.in_eval() + if observation is not None: + _input_dict = {Columns.OBS: observation} + elif input_dict is not None: + _input_dict = {Columns.OBS: input_dict[Columns.OBS]} + else: + raise ValueError( + "Either observation or input_dict must be provided." 
+ ) + + acd = AgentConnectorDataType("0", "0", _input_dict) + pp.reset(env_id="0") + ac_o = pp([acd])[0] + observation = ac_o.data[Columns.OBS] + + if input_dict is not None: + input_dict[Columns.OBS] = observation + action, state, extra = policy.compute_single_action( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episode=episode, + ) + else: + action, state, extra = policy.compute_single_action( + obs=observation, + state=state, + prev_action=prev_action, + prev_reward=prev_reward, + info=info, + explore=explore, + timestep=timestep, + episode=episode, + ) + + if unsquash_action: + action = space_utils.unsquash_action(action, policy.action_space_struct) + elif clip_action: + action = space_utils.clip_action(action, policy.action_space_struct) + + if state or full_fetch: + return action, state, extra + else: + return action + + @OldAPIStack + @Deprecated( + help="`Algorithm.compute_actions` should no longer be used. Get the RLModule " + "instance through `Algorithm.get_module([module ID])`, then compute actions " + "through `RLModule.forward_inference({'obs': [obs batch]})`.", + error=False, + ) + def compute_actions( + self, + observations: TensorStructType, + state: Optional[List[TensorStructType]] = None, + *, + prev_action: Optional[TensorStructType] = None, + prev_reward: Optional[TensorStructType] = None, + info: Optional[EnvInfoDict] = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: Optional[bool] = None, + timestep: Optional[int] = None, + episodes=None, + unsquash_actions: Optional[bool] = None, + clip_actions: Optional[bool] = None, + ): + if unsquash_actions is None: + unsquash_actions = self.config.normalize_actions + elif clip_actions is None: + clip_actions = self.config.clip_actions + + state_defined = state is not None + policy = self.get_policy(policy_id) + filtered_obs, filtered_state = [], [] + for agent_id, ob in observations.items(): + worker = self.env_runner_group.local_env_runner + if worker.preprocessors.get(policy_id) is not None: + preprocessed = worker.preprocessors[policy_id].transform(ob) + else: + preprocessed = ob + filtered = worker.filters[policy_id](preprocessed, update=False) + filtered_obs.append(filtered) + if state is None: + continue + elif agent_id in state: + filtered_state.append(state[agent_id]) + else: + filtered_state.append(policy.get_initial_state()) + + obs_batch = np.stack(filtered_obs) + if state is None: + state = [] + else: + state = list(zip(*filtered_state)) + state = [np.stack(s) for s in state] + + input_dict = {Columns.OBS: obs_batch} + + if prev_action is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action + if prev_reward is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward + if info: + input_dict[Columns.INFOS] = info + for i, s in enumerate(state): + input_dict[f"state_in_{i}"] = s + + actions, states, infos = policy.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + ) + + single_actions = space_utils.unbatch(actions) + actions = {} + for key, a in zip(observations, single_actions): + if unsquash_actions: + a = space_utils.unsquash_action(a, policy.action_space_struct) + elif clip_actions: + a = space_utils.clip_action(a, policy.action_space_struct) + actions[key] = a + + unbatched_states = {} + for idx, agent_id in enumerate(observations): + unbatched_states[agent_id] = [s[idx] for s in states] + + if state_defined or full_fetch: + return actions, unbatched_states, infos + 
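For reference, the migration path named in the two @Deprecated wrappers above, as a runnable sketch (assumes PPO on CartPole with a torch RLModule; "default_policy" is the default single-agent module ID):

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").build()
module = algo.get_module("default_policy")

env = gym.make("CartPole-v1")
obs, _ = env.reset()
# forward_inference expects batched inputs -> add a batch dim of 1.
fwd_out = module.forward_inference({"obs": torch.from_numpy(obs[None])})
# `fwd_out` carries action-distribution inputs (exact keys depend on the
# RLModule); actions are sampled or argmax'ed from these rather than being
# returned directly, as the old `compute_single_action()` did.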
else: + return actions + @Deprecated(new="Algorithm.restore_env_runners", error=False) def restore_workers(self, *args, **kwargs): return self.restore_env_runners(*args, **kwargs) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index fb25f7ad3d6bd..33d0e28258458 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -355,9 +355,9 @@ def __init__(self, algo_class: Optional[type] = None): # `self.learners()` self.num_learners = 0 self.num_gpus_per_learner = 0 - self.num_cpus_per_learner = 1 + self.num_cpus_per_learner = "auto" self.num_aggregator_actors_per_learner = 0 - self.max_requests_in_flight_per_aggregator_actor = 100 + self.max_requests_in_flight_per_aggregator_actor = 3 self.local_gpu_idx = 0 # TODO (sven): This probably works even without any restriction # (allowing for any arbitrary number of requests in-flight). Test with @@ -2135,7 +2135,7 @@ def learners( self, *, num_learners: Optional[int] = NotProvided, - num_cpus_per_learner: Optional[Union[float, int]] = NotProvided, + num_cpus_per_learner: Optional[Union[str, float, int]] = NotProvided, num_gpus_per_learner: Optional[Union[float, int]] = NotProvided, num_aggregator_actors_per_learner: Optional[int] = NotProvided, max_requests_in_flight_per_aggregator_actor: Optional[float] = NotProvided, @@ -2153,14 +2153,16 @@ def learners( 1 GPU: `num_learners=4; num_gpus_per_learner=1` OR 4 GPUs total and model requires 2 GPUs: `num_learners=2; num_gpus_per_learner=2`). num_cpus_per_learner: Number of CPUs allocated per Learner worker. + If "auto" (default), use 1 if `num_gpus_per_learner=0`, otherwise 0. Only necessary for custom processing pipeline inside each Learner - requiring multiple CPU cores. Ignored if `num_learners=0`. + requiring multiple CPU cores. + If `num_learners=0`, RLlib creates only one local Learner instance and + the number of CPUs on the main process is + `max(num_cpus_per_learner, num_cpus_for_main_process)`. num_gpus_per_learner: Number of GPUs allocated per Learner worker. If `num_learners=0`, any value greater than 0 runs the training on a single GPU on the main process, while a value of 0 runs - the training on main process CPUs. If `num_gpus_per_learner` is > 0, - then you shouldn't change `num_cpus_per_learner` (from its default - value of 1). + the training on main process CPUs. num_aggregator_actors_per_learner: The number of aggregator actors per Learner (if num_learners=0, one local learner is created). Must be at least 1. Aggregator actors perform the task of a) converting episodes @@ -4555,20 +4557,7 @@ def _validate_framework_settings(self) -> None: def _validate_resources_settings(self): """Checks, whether resources related settings make sense.""" - - # TODO @Avnishn: This is a short-term work around due to - # https://github.com/ray-project/ray/issues/35409 - # Remove this once we are able to specify placement group bundle index in RLlib - if self.num_cpus_per_learner > 1 and self.num_gpus_per_learner > 0: - self._value_error( - "Can't set both `num_cpus_per_learner` > 1 and " - " `num_gpus_per_learner` > 0! Either set " - "`num_cpus_per_learner` > 1 (and `num_gpus_per_learner`" - "=0) OR set `num_gpus_per_learner` > 0 (and leave " - "`num_cpus_per_learner` at its default value of 1). " - "This is due to issues with placement group fragmentation. See " - "https://github.com/ray-project/ray/issues/35409 for more details." 
- ) + pass def _validate_multi_agent_settings(self): """Checks, whether multi-agent related settings make sense.""" diff --git a/rllib/algorithms/bc/tests/test_bc_old_api_stack.py b/rllib/algorithms/bc/tests/test_bc_old_api_stack.py deleted file mode 100644 index 335a751376ade..0000000000000 --- a/rllib/algorithms/bc/tests/test_bc_old_api_stack.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -from pathlib import Path -import unittest - -import ray -import ray.rllib.algorithms.bc as bc -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, -) -from ray.rllib.utils.test_utils import ( - check_compute_single_action, - check_train_results, -) - - -class TestBC(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init() - - @classmethod - def tearDownClass(cls): - ray.shutdown() - - def test_bc_compilation_and_learning_from_offline_file(self): - """Test whether BC can be built with all frameworks. - - And learns from a historic-data file (while being evaluated on an - actual env using evaluation_num_env_runners > 0). - """ - rllib_dir = Path(__file__).parents[3] - print("rllib_dir={}".format(rllib_dir)) - # This has still to be done until `pathlib` will be used in the readers. - data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json") - print(f"data_file={data_file} exists={os.path.isfile(data_file)}") - - config = ( - bc.BCConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .evaluation( - evaluation_interval=3, - evaluation_num_env_runners=1, - evaluation_duration=5, - evaluation_parallel_to_training=True, - evaluation_config=bc.BCConfig.overrides(input_="sampler"), - ) - .offline_data(input_=[data_file]) - ) - num_iterations = 350 - min_return_to_reach = 75.0 - - for recurrent in [True, False]: - # We only test recurrent networks with RLModules. - if recurrent: - # TODO (Artur): We read input data without a time-dimensions. - # In order for a recurrent offline learning RL Module to - # work, the input data needs to be transformed do add a - # time-dimension. - continue - - config.training(model={"use_lstm": recurrent}) - algo = config.build(env="CartPole-v1") - learnt = False - for i in range(num_iterations): - results = algo.train() - check_train_results(results) - print(results) - - eval_results = results.get("evaluation") - if eval_results: - mean_return = eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] - print("iter={} R={}".format(i, mean_return)) - # Learn until good reward is reached in the actual env. - if mean_return > min_return_to_reach: - print("learnt!") - learnt = True - break - - if not learnt: - raise ValueError( - "`BC` did not reach {} reward from expert offline " - "data!".format(min_return_to_reach) - ) - - check_compute_single_action(algo, include_prev_action_reward=True) - - algo.stop() - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 3a29db72a7e10..876424fe35494 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -31,7 +31,6 @@ ) from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.metrics import ( - ALL_MODULES, LEARNER_RESULTS, LEARNER_UPDATE_TIMER, LAST_TARGET_UPDATE_TS, @@ -213,6 +212,10 @@ def build_learner_connector( AddNextObservationsFromEpisodesToTrainBatch(), ) + # If training on GPU, do not convert batches to tensors. 
+ if self.num_gpus_per_learner > 0: + pipeline.remove("NumpyToTensor") + return pipeline @override(SACConfig) @@ -307,7 +310,11 @@ def training_step(self) -> None: batch_or_iterator = self.offline_data.sample( num_samples=self.config.train_batch_size_per_learner, num_shards=self.config.num_learners, - return_iterator=self.config.num_learners > 1, + # Return an iterator, if a `Learner` should update + # multiple times per RLlib iteration. + return_iterator=self.config.dataset_num_iters_per_learner > 1 + if self.config.dataset_num_iters_per_learner + else True, ) # Updating the policy. @@ -323,24 +330,6 @@ def training_step(self) -> None: # Log training results. self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) - # Synchronize weights. - # As the results contain for each policy the loss and in addition the - # total loss over all policies is returned, this total loss has to be - # removed. - modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} - - if self.eval_env_runner_group: - # Update weights - after learning on the local worker - - # on all remote workers. Note, we only have the local `EnvRunner`, - # but from this `EnvRunner` the evaulation `EnvRunner`s get updated. - with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): - self.eval_env_runner_group.sync_weights( - # Sync weights from learner_group to all EnvRunners. - from_worker_or_learner_group=self.learner_group, - policies=modules_to_update, - inference_only=True, - ) - @OldAPIStack def _training_step_old_api_stack(self) -> ResultDict: # Collect SampleBatches from sample workers. diff --git a/rllib/algorithms/dqn/learner_thread.py b/rllib/algorithms/dqn/learner_thread.py deleted file mode 100644 index a3cb59916afcb..0000000000000 --- a/rllib/algorithms/dqn/learner_thread.py +++ /dev/null @@ -1,82 +0,0 @@ -import queue -import threading - -from ray.util.timer import _Timer -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder -from ray.rllib.utils.metrics.window_stat import WindowStat - -LEARNER_QUEUE_MAX_SIZE = 16 - -tf1, tf, tfv = try_import_tf() - - -class LearnerThread(threading.Thread): - """Background thread that updates the local model from replay data. - - The learner thread communicates with the main thread through Queues. This - is needed since Ray operations can only be run on the main thread. In - addition, moving heavyweight gradient ops session runs off the main thread - improves overall throughput. - """ - - def __init__(self, local_worker): - threading.Thread.__init__(self) - self.learner_queue_size = WindowStat("size", 50) - self.local_worker = local_worker - self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) - self.outqueue = queue.Queue() - self.queue_timer = _Timer() - self.grad_timer = _Timer() - self.overall_timer = _Timer() - self.daemon = True - self.policy_ids_updated = [] - self.stopped = False - self.learner_info = {} - - def run(self): - # Switch on eager mode if configured. - if self.local_worker.config.framework_str == "tf2": - tf1.enable_eager_execution() - while not self.stopped: - self.step() - - def step(self): - with self.overall_timer: - with self.queue_timer: - replay_actor, ma_batch = self.inqueue.get() - if ma_batch is not None: - prio_dict = {} - with self.grad_timer: - # Use LearnerInfoBuilder as a unified way to build the - # final results dict from `learn_on_loaded_batch` call(s). 
- # This makes sure results dicts always have the same - # structure no matter the setup (multi-GPU, multi-agent, - # minibatch SGD, tf vs torch). - learner_info_builder = LearnerInfoBuilder(num_devices=1) - multi_agent_results = self.local_worker.learn_on_batch(ma_batch) - self.policy_ids_updated.extend(list(multi_agent_results.keys())) - for pid, results in multi_agent_results.items(): - learner_info_builder.add_learn_on_batch_results(results, pid) - td_error = results["td_error"] - # Switch off auto-conversion from numpy to torch/tf - # tensors for the indices. This may lead to errors - # when sent to the buffer for processing - # (may get manipulated if they are part of a tensor). - ma_batch.policy_batches[pid].set_get_interceptor(None) - prio_dict[pid] = ( - ma_batch.policy_batches[pid].get("batch_indexes"), - td_error, - ) - self.learner_info = learner_info_builder.finalize() - self.grad_timer.push_units_processed(ma_batch.count) - # Put tuple: replay_actor, prio-dict, env-steps, and agent-steps into - # the queue. - self.outqueue.put( - (replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps()) - ) - self.learner_queue_size.push(self.inqueue.qsize()) - self.overall_timer.push_units_processed( - ma_batch and ma_batch.count or 0 - ) - del ma_batch diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index ad1436ac81ee7..fb55d154af3c9 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -27,13 +27,14 @@ from ray.rllib.utils.annotations import OldAPIStack, override from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.metrics import ( + AGGREGATOR_ACTOR_RESULTS, ALL_MODULES, ENV_RUNNER_RESULTS, LEARNER_GROUP, LEARNER_RESULTS, LEARNER_UPDATE_TIMER, MEAN_NUM_EPISODE_LISTS_RECEIVED, - MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED, + MEAN_NUM_LEARNER_RESULTS_RECEIVED, MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED, NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, @@ -54,12 +55,10 @@ from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ( LearningRateOrSchedule, - PartialAlgorithmConfigDict, PolicyID, ResultDict, SampleBatchType, ) -from ray.tune.execution.placement_groups import PlacementGroupFactory logger = logging.getLogger(__name__) @@ -609,7 +608,10 @@ def training_step(self): ) # Log the average number of sample results (list of episodes) received. 
- self.metrics.log_value(MEAN_NUM_EPISODE_LISTS_RECEIVED, len(episode_refs)) + self.metrics.log_value( + (ENV_RUNNER_RESULTS, MEAN_NUM_EPISODE_LISTS_RECEIVED), + len(episode_refs), + ) time.sleep(0.01) @@ -620,6 +622,11 @@ def training_step(self): data_packages_for_aggregators = self._pre_queue_episode_refs( episode_refs, package_size=self.config.train_batch_size_per_learner ) + self.metrics.log_value( + (AGGREGATOR_ACTOR_RESULTS, "mean_num_input_packages"), + len(episode_refs), + ) + ma_batches_refs_remote_results = ( self._aggregator_actor_manager.fetch_ready_async_reqs( timeout_seconds=0.0, @@ -630,6 +637,10 @@ def training_step(self): ma_batches_refs = [] for call_result in ma_batches_refs_remote_results: ma_batches_refs.append((call_result.actor_id, call_result.get())) + self.metrics.log_value( + (AGGREGATOR_ACTOR_RESULTS, "mean_num_output_batches"), + len(ma_batches_refs), + ) while data_packages_for_aggregators: @@ -639,18 +650,31 @@ def _func(actor, p): num_agg = self.config.num_aggregator_actors_per_learner * ( self.config.num_learners or 1 ) - packs = data_packages_for_aggregators[:num_agg] - self._aggregator_actor_manager.foreach_actor_async( + packs, data_packages_for_aggregators = ( + data_packages_for_aggregators[:num_agg], + data_packages_for_aggregators[num_agg:], + ) + sent = self._aggregator_actor_manager.foreach_actor_async( func=[functools.partial(_func, p=p) for p in packs], tag="batches", ) - data_packages_for_aggregators = data_packages_for_aggregators[num_agg:] + self.metrics.log_value( + (AGGREGATOR_ACTOR_RESULTS, "num_env_steps_dropped_lifetime"), + self.config.train_batch_size_per_learner * (len(packs) - sent), + reduce="sum", + ) # Get n lists of m ObjRef[MABatch] (m=num_learners) to perform n calls to # all learner workers with the already GPU-located batches. 
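The new `num_env_steps_dropped_lifetime` metric above relies on `foreach_actor_async()` returning the number of calls it actually dispatched: packages beyond an aggregator's in-flight limit (now capped at 3 via `max_requests_in_flight_per_aggregator_actor`) are not sent and count as dropped env steps. The accounting as a tiny standalone helper (hypothetical name and flat arguments):

def dropped_env_steps(num_packages, num_sent, train_batch_size_per_learner):
    # Each undelivered package holds one train batch worth of env steps.
    return train_batch_size_per_learner * (num_packages - num_sent)


# 4 packages prepared, only 3 async calls accepted -> 500 env steps dropped.
print(dropped_env_steps(4, 3, 500))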
data_packages_for_learner_group = self._pre_queue_batch_refs( ma_batches_refs ) + self.metrics.log_value( + (AGGREGATOR_ACTOR_RESULTS, "num_env_steps_aggregated_lifetime"), + self.config.train_batch_size_per_learner + * len(data_packages_for_learner_group), + reduce="sum", + ) else: data_packages_for_learner_group = self._pre_queue_episode_refs( @@ -686,6 +710,9 @@ def _func(actor, p): ), } if self.config.num_aggregator_actors_per_learner > 0: + assert len(batch_ref_or_episode_list_ref) == ( + self.config.num_learners or 1 + ) learner_results = self.learner_group.update_from_batch( batch=batch_ref_or_episode_list_ref, async_update=do_async_updates, @@ -712,23 +739,18 @@ def _func(actor, p): 1, reduce="sum", ) - if not do_async_updates: - learner_results = [learner_results] - - for results_from_n_learners in learner_results: - if not results_from_n_learners[0]: - continue - num_learner_group_results_received += 1 - for r in results_from_n_learners: - rl_module_state = r.pop( - "_rl_module_state_after_update", rl_module_state - ) - self.metrics.merge_and_log_n_dicts( - stats_dicts=results_from_n_learners, - key=LEARNER_RESULTS, + + num_learner_group_results_received += len(learner_results) + for result_from_1_learner in learner_results: + rl_module_state = result_from_1_learner.pop( + "_rl_module_state_after_update", rl_module_state ) + self.metrics.merge_and_log_n_dicts( + stats_dicts=learner_results, + key=LEARNER_RESULTS, + ) self.metrics.log_value( - key=MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED, + key=(LEARNER_GROUP, MEAN_NUM_LEARNER_RESULTS_RECEIVED), value=num_learner_group_results_received, ) @@ -865,92 +887,6 @@ def _pre_queue_batch_refs( return batch_refs_for_learner_group - @classmethod - @override(Algorithm) - def default_resource_request( - cls, - config: Union[AlgorithmConfig, PartialAlgorithmConfigDict], - ): - if isinstance(config, AlgorithmConfig): - cf: IMPALAConfig = config - else: - cf: IMPALAConfig = cls.get_default_config().update_from_dict(config) - - eval_config = cf.get_evaluation_config_object() - - bundles = [] - - # Main process (old API stack). - if not cf.enable_rl_module_and_learner: - bundles.append( - { - "CPU": cf.num_cpus_for_main_process, - "GPU": 0 if cf._fake_gpus else cf.num_gpus, - } - ) - # Main process (no local learner). - elif cf.num_learners > 0: - bundles.append({"CPU": cf.num_cpus_for_main_process}) - # Main process (local learner). - else: - bundles.append( - { - "CPU": max( - cf.num_cpus_for_main_process, - cf.num_cpus_per_learner if cf.num_gpus_per_learner == 0 else 0, - ), - "GPU": max( - 0, - cf.num_gpus_per_learner - - 0.01 * cf.num_aggregator_actors_per_learner, - ), - } - ) - # Aggregation actors (for the local learner). - bundles += [ - {"CPU": 1, "GPU": 0.01 if cf.num_gpus_per_learner > 0 else 0} - for _ in range(cf.num_aggregator_actors_per_learner) - ] - - # EnvRunners. - bundles += [ - { - "CPU": cf.num_cpus_per_env_runner, - "GPU": cf.num_gpus_per_env_runner, - **cf.custom_resources_per_env_runner, - } - for _ in range(cf.num_env_runners) - ] - - # Evaluation (remote) workers. - bundles += ( - [ - { - # Note: The local eval worker is located on the driver - # CPU or not even created iff >0 eval workers. - "CPU": eval_config.num_cpus_per_env_runner, - "GPU": eval_config.num_gpus_per_env_runner, - **eval_config.custom_resources_per_env_runner, - } - for _ in range(cf.evaluation_num_env_runners) - ] - if cf.evaluation_interval - else [] - ) - # TODO (avnishn): Remove this once we have a way to extend placement group - # factories. 
-        # Only if we have actual (remote) learner workers. In case of a local learner,
-        # the resource has already been taken care of above.
-        if cf.enable_rl_module_and_learner and cf.num_learners > 0:
-            bundles += cls._get_learner_bundles(cf)
-
-        # Return PlacementGroupFactory containing all needed resources
-        # (already properly defined as device bundles).
-        return PlacementGroupFactory(
-            bundles=bundles,
-            strategy=cf.placement_strategy,
-        )
-
     @OldAPIStack
     def _training_step_old_api_stack(self):
         # First, check, whether our learner thread is still healthy.
diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py
index 48e04636d003c..c9c8d700241da 100644
--- a/rllib/algorithms/impala/impala_learner.py
+++ b/rllib/algorithms/impala/impala_learner.py
@@ -1,4 +1,5 @@
 from collections import deque
+import queue
 import threading
 import time
 from typing import Any, Dict, Union
@@ -64,10 +65,30 @@ def build(self) -> None:
             )
         )
 
+        # Create the in-queue for the GPU-loader threads. These threads pick up
+        # train-ready batches from this "GPU-loader queue", load them onto the GPU,
+        # and then place the GPU batches on the "update queue" for the actual
+        # RLModule forward pass and loss computations.
+        self._gpu_loader_in_queue = queue.Queue()
+
         # Default is to have a learner thread.
         if not hasattr(self, "_learner_thread_in_queue"):
             self._learner_thread_in_queue = deque(maxlen=self.config.learner_queue_size)
 
+        # Create and start the GPU loader thread(s).
+        if self.config.num_gpus_per_learner > 0:
+            self._gpu_loader_threads = [
+                _GPULoaderThread(
+                    in_queue=self._gpu_loader_in_queue,
+                    out_queue=self._learner_thread_in_queue,
+                    device=self._device,
+                    metrics_logger=self.metrics,
+                )
+                for _ in range(self.config.num_gpu_loader_threads)
+            ]
+            for t in self._gpu_loader_threads:
+                t.start()
+
         # Create and start the Learner thread.
         self._learner_thread = _LearnerThread(
             update_method=self._update_from_batch_or_episodes,
@@ -92,16 +113,11 @@ def update_from_batch(
 
         self.before_gradient_based_update(timesteps=timesteps or {})
 
-        if isinstance(self._learner_thread_in_queue, CircularBuffer):
-            ts_dropped = self._learner_thread_in_queue.add(batch)
-            self.metrics.log_value(
-                (ALL_MODULES, LEARNER_THREAD_ENV_STEPS_DROPPED),
-                ts_dropped,
-                reduce="sum",
-            )
-        # Enqueue to Learner thread's in-queue.
-        else:
-            _LearnerThread.enqueue(self._learner_thread_in_queue, batch, self.metrics)
+        self._gpu_loader_in_queue.put(batch)
+        self.metrics.log_value(
+            (ALL_MODULES, QUEUE_SIZE_GPU_LOADER_QUEUE),
+            self._gpu_loader_in_queue.qsize(),
+        )
 
         return self.metrics.reduce()
 
@@ -136,6 +152,49 @@ def rl_module_required_apis(cls) -> list[type]:
 ImpalaLearner = IMPALALearner
+
+class _GPULoaderThread(threading.Thread):
+    def __init__(
+        self,
+        *,
+        in_queue: queue.Queue,
+        out_queue: deque,
+        device: torch.device,
+        metrics_logger: MetricsLogger,
+    ):
+        super().__init__(name="_GPULoaderThread")
+        self.daemon = True
+
+        self._in_queue = in_queue
+        self._out_queue = out_queue
+        self._ts_dropped = 0
+        self._device = device
+        self.metrics = metrics_logger
+
+    def run(self) -> None:
+        while True:
+            self._step()
+
+    def _step(self) -> None:
+        # Get a new batch from the data in-queue.
+        with self.metrics.log_time((ALL_MODULES, GPU_LOADER_QUEUE_WAIT_TIMER)):
+            ma_batch_on_cpu = self._in_queue.get()
+
+        # Load the batch onto the GPU device.
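+        # Note: `pin_memory=True` below presumably stages the CPU batch in
+        # page-locked (pinned) host memory first, so that the host-to-GPU copy
+        # can run as an asynchronous DMA transfer instead of blocking this
+        # loader thread.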
+ with self.metrics.log_time((ALL_MODULES, GPU_LOADER_LOAD_TO_GPU_TIMER)): + ma_batch_on_gpu = ma_batch_on_cpu.to_device(self._device, pin_memory=True) + + if isinstance(self._out_queue, CircularBuffer): + ts_dropped = self._out_queue.add(ma_batch_on_gpu) + self.metrics.log_value( + (ALL_MODULES, LEARNER_THREAD_ENV_STEPS_DROPPED), + ts_dropped, + reduce="sum", + ) + else: + # Enqueue to Learner thread's in-queue. + _LearnerThread.enqueue(self._out_queue, ma_batch_on_gpu, self.metrics) + + class _LearnerThread(threading.Thread): def __init__( self, @@ -144,7 +203,7 @@ def __init__( in_queue: deque, metrics_logger, ): - super().__init__() + super().__init__(name="_LearnerThread") self.daemon = True self.metrics: MetricsLogger = metrics_logger self.stopped = False @@ -166,9 +225,8 @@ def step(self): ma_batch_on_gpu = self._in_queue.sample() else: # Queue is empty: Sleep a tiny bit to avoid CPU-thrashing. - if not self._in_queue: + while not self._in_queue: time.sleep(0.001) - return # Consume from the left (oldest batches first). # If we consumed from the right, we would run into the danger of # learning from newer batches (left side) most times, BUT sometimes diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py index 8e98ed80e69fb..2ff4801be81d6 100644 --- a/rllib/algorithms/marwil/marwil.py +++ b/rllib/algorithms/marwil/marwil.py @@ -2,6 +2,7 @@ from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.connectors.common import TensorToNumpy from ray.rllib.connectors.learner import ( AddObservationsFromEpisodesToBatch, AddOneTsToEpisodesAndTruncate, @@ -21,7 +22,6 @@ from ray.rllib.utils.annotations import OldAPIStack, override from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.metrics import ( - ALL_MODULES, LEARNER_RESULTS, LEARNER_UPDATE_TIMER, NUM_AGENT_STEPS_SAMPLED, @@ -374,6 +374,11 @@ def build_learner_connector( GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_) ) + # If training on GPU, convert batches to `numpy` arrays to load them + # on GPU in the `Learner`. + if self.num_gpus_per_learner > 0: + pipeline.insert_after(GeneralAdvantageEstimation, TensorToNumpy()) + return pipeline @override(AlgorithmConfig) @@ -461,7 +466,11 @@ class (multi-/single-learner setup) and evaluation on batch_or_iterator = self.offline_data.sample( num_samples=self.config.train_batch_size_per_learner, num_shards=self.config.num_learners, - return_iterator=self.config.num_learners > 1, + # Return an iterator, if a `Learner` should update + # multiple times per RLlib iteration. + return_iterator=self.config.dataset_num_iters_per_learner > 1 + if self.config.dataset_num_iters_per_learner + else True, ) with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): @@ -478,23 +487,6 @@ class (multi-/single-learner setup) and evaluation on # Log training results. self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) - # Synchronize weights. - # As the results contain for each policy the loss and in addition the - # total loss over all policies is returned, this total loss has to be - # removed. - modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} - - if self.eval_env_runner_group: - # Update weights - after learning on the local worker - - # on all remote workers. 
- with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): - self.eval_env_runner_group.sync_weights( - # Sync weights from learner_group to all EnvRunners. - from_worker_or_learner_group=self.learner_group, - policies=list(modules_to_update), - inference_only=True, - ) - @OldAPIStack def _training_step_old_api_stack(self) -> ResultDict: """Implements training step for the old stack. diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index a43eacf46ba36..2b67fe9f67527 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -5,9 +5,10 @@ import ray import ray.rllib.algorithms.marwil as marwil -from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core import DEFAULT_MODULE_ID, COMPONENT_RL_MODULE from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY +from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.offline.offline_prelearner import OfflinePreLearner from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import check @@ -157,9 +158,17 @@ def test_marwil_loss_function(self): # Sample a batch from the offline data. batch = algo.offline_data.data.take_batch(2000) + # Get the module state. + module_state = algo.offline_data.learner_handles[0].get_state( + component=COMPONENT_RL_MODULE, + )[COMPONENT_RL_MODULE] + # Create the prelearner and compute advantages and values. offline_prelearner = OfflinePreLearner( - config=config, learner=algo.learner_group._learner + config=config, + module_spec=algo.offline_data.module_spec, + module_state=module_state, + spaces=algo.offline_data.spaces[INPUT_ENV_SPACES], ) # Note, for `ray.data`'s pipeline everything has to be a dictionary # therefore the batch is embedded into another dictionary. @@ -179,9 +188,7 @@ def possibly_masked_mean(data_): # Calculate our own expected values (to then compare against the # agent's loss output). module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped() - fwd_out = module.forward_train( - {k: v for k, v in batch[DEFAULT_MODULE_ID].items()} - ) + fwd_out = module.forward_train(dict(batch[DEFAULT_MODULE_ID])) advantages = ( batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy() - module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy() @@ -210,7 +217,7 @@ def possibly_masked_mean(data_): # calculation above). 
total_loss = algo.learner_group._learner.compute_loss_for_module( module_id=DEFAULT_MODULE_ID, - batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()}, + batch=dict(batch[DEFAULT_MODULE_ID]), fwd_out=fwd_out, config=config, ) diff --git a/rllib/algorithms/marwil/tests/test_marwil_old_api_stack.py b/rllib/algorithms/marwil/tests/test_marwil_old_api_stack.py deleted file mode 100644 index bb1fabfed7ee0..0000000000000 --- a/rllib/algorithms/marwil/tests/test_marwil_old_api_stack.py +++ /dev/null @@ -1,215 +0,0 @@ -import numpy as np -import os -from pathlib import Path -import unittest - -import ray -import ray.rllib.algorithms.marwil as marwil -from ray.rllib.algorithms.marwil.marwil_torch_policy import MARWILTorchPolicy -from ray.rllib.evaluation.postprocessing import compute_advantages -from ray.rllib.offline import JsonReader -from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - EVALUATION_RESULTS, -) -from ray.rllib.utils.test_utils import ( - check, - check_compute_single_action, - check_train_results, -) - -tf1, tf, tfv = try_import_tf() -torch, _ = try_import_torch() - - -class TestMARWILOld(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init() - - @classmethod - def tearDownClass(cls): - ray.shutdown() - - def test_marwil_compilation_and_learning_from_offline_file(self): - """Test whether a MARWILAlgorithm can be built with all frameworks. - - Learns from a historic-data file. - To generate this data, first run: - $ ./train.py --run=PPO --env=CartPole-v1 \ - --stop='{"timesteps_total": 50000}' \ - --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}' - """ - rllib_dir = Path(__file__).parent.parent.parent.parent - print("rllib dir={}".format(rllib_dir)) - data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json") - print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) - - config = ( - marwil.MARWILConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .env_runners(num_env_runners=2) - .environment(env="CartPole-v1") - .evaluation( - evaluation_interval=3, - evaluation_num_env_runners=1, - evaluation_duration=5, - evaluation_parallel_to_training=True, - evaluation_config=marwil.MARWILConfig.overrides(input_="sampler"), - off_policy_estimation_methods={}, - ) - .offline_data(input_=[data_file]) - ) - - num_iterations = 350 - min_reward = 100.0 - - algo = config.build() - learnt = False - for i in range(num_iterations): - results = algo.train() - check_train_results(results) - print(results) - - eval_results = results.get(EVALUATION_RESULTS) - if eval_results: - print( - "iter={} R={} ".format( - i, eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] - ) - ) - # Learn until some reward is reached on an actual live env. - if eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] > min_reward: - print("learnt!") - learnt = True - break - - if not learnt: - raise ValueError( - "MARWILAlgorithm did not reach {} reward from expert " - "offline data!".format(min_reward) - ) - - check_compute_single_action(algo, include_prev_action_reward=True) - - algo.stop() - - def test_marwil_cont_actions_from_offline_file(self): - """Test whether MARWIL runs with cont. actions. - - Learns from a historic-data file. 
- To generate this data, first run: - $ ./train.py --run=PPO --env=Pendulum-v1 \ - --stop='{"timesteps_total": 50000}' \ - --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}' - """ - rllib_dir = Path(__file__).parent.parent.parent.parent - print("rllib dir={}".format(rllib_dir)) - data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json") - print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) - - config = ( - marwil.MARWILConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .env_runners(num_env_runners=1) - .evaluation( - evaluation_num_env_runners=1, - evaluation_interval=3, - evaluation_duration=5, - evaluation_parallel_to_training=True, - # Evaluate on actual environment. - evaluation_config=marwil.MARWILConfig.overrides(input_="sampler"), - off_policy_estimation_methods={}, - ) - .offline_data( - # Learn from offline data. - input_=[data_file], - ) - ) - - num_iterations = 3 - - algo = config.build(env="Pendulum-v1") - for i in range(num_iterations): - print(algo.train()) - algo.stop() - - def test_marwil_loss_function(self): - """ - To generate the historic data used in this test case, first run: - $ ./train.py --run=PPO --env=CartPole-v1 \ - --stop='{"timesteps_total": 50000}' \ - --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}' - """ - rllib_dir = Path(__file__).parent.parent.parent.parent - print("rllib dir={}".format(rllib_dir)) - data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json") - print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) - - config = ( - marwil.MARWILConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .env_runners(num_env_runners=0) - .offline_data(input_=[data_file]) - ) # Learn from offline data. - - reader = JsonReader(inputs=[data_file]) - batch = reader.next() - - algo = config.build(env="CartPole-v1") - policy = algo.get_policy() - model = policy.model - - # Calculate our own expected values (to then compare against the - # agent's loss output). - cummulative_rewards = compute_advantages( - batch, 0.0, config.gamma, 1.0, False, False - )["advantages"] - cummulative_rewards = torch.tensor(cummulative_rewards) - batch = policy._lazy_tensor_dict(batch) - model_out, _ = model(batch) - vf_estimates = model.value_function() - adv = cummulative_rewards - vf_estimates - adv = adv.detach().cpu().numpy() - adv_squared = np.mean(np.square(adv)) - c_2 = 100.0 + 1e-8 * (adv_squared - 100.0) - c = np.sqrt(c_2) - exp_advs = np.exp(config.beta * (adv / c)) - dist = policy.dist_class(model_out, model) - logp = dist.logp(batch["actions"]) - logp = logp.detach().cpu().numpy() - # Calculate all expected loss components. - expected_vf_loss = 0.5 * adv_squared - expected_pol_loss = -1.0 * np.mean(exp_advs * logp) - expected_loss = expected_pol_loss + config.vf_coeff * expected_vf_loss - - # Calculate the algorithm's loss (to check against our own - # calculation above). - batch.set_get_interceptor(None) - postprocessed_batch = policy.postprocess_trajectory(batch) - loss_func = MARWILTorchPolicy.loss - policy._lazy_tensor_dict(postprocessed_batch) - loss_out = loss_func(policy, model, policy.dist_class, postprocessed_batch) - - # Check all components. 
- check(policy.v_loss, expected_vf_loss, decimals=4) - check(policy.p_loss, expected_pol_loss, decimals=4) - check(loss_out, expected_loss, decimals=3) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/ppo/tests/test_ppo_old_api_stack.py b/rllib/algorithms/ppo/tests/test_ppo_old_api_stack.py deleted file mode 100644 index 5d5deacd2f316..0000000000000 --- a/rllib/algorithms/ppo/tests/test_ppo_old_api_stack.py +++ /dev/null @@ -1,524 +0,0 @@ -import unittest - -import numpy as np - -import ray -from ray.rllib.callbacks.callbacks import RLlibCallback -import ray.rllib.algorithms.ppo as ppo -from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy -from ray.rllib.core.columns import Columns -from ray.rllib.evaluation.postprocessing import ( - compute_gae_for_sample_batch, - Postprocessing, -) -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.torch.torch_action_dist import TorchCategorical -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY -from ray.rllib.utils.numpy import fc -from ray.rllib.utils.test_utils import ( - check, - check_compute_single_action, - check_off_policyness, - check_train_results, - check_inference_w_connectors, -) - - -# Fake CartPole episode of n time steps. -CARTPOLE_FAKE_BATCH = SampleBatch( - { - Columns.OBS: np.array( - [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], - dtype=np.float32, - ), - Columns.ACTIONS: np.array([0, 1, 1]), - SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]), - Columns.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32), - SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32), - Columns.TERMINATEDS: np.array([False, False, True]), - Columns.TRUNCATEDS: np.array([False, False, False]), - Columns.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32), - Columns.ACTION_DIST_INPUTS: np.array( - [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32 - ), - Columns.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32), - Columns.EPS_ID: np.array([0, 0, 0]), - SampleBatch.AGENT_INDEX: np.array([0, 0, 0]), - } -) - -# Fake Pendulum episode of n time steps. -PENDULUM_FAKE_BATCH = SampleBatch( - { - Columns.OBS: np.array( - [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], - dtype=np.float32, - ), - Columns.ACTIONS: np.array([0.1, 0.2, 0.3], dtype=np.float32), - SampleBatch.PREV_ACTIONS: np.array([0.3, 0.4], dtype=np.float32), - Columns.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32), - SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32), - Columns.TERMINATEDS: np.array([False, False, True]), - Columns.TRUNCATEDS: np.array([False, False, False]), - Columns.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32), - Columns.ACTION_DIST_INPUTS: np.array( - [ - [0.1, 0.0, 0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8, 0.9, 1.0], - [1.1, 1.2, 1.3, 1.4, 1.5, 1.6], - ], - dtype=np.float32, - ), - Columns.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32), - Columns.EPS_ID: np.array([0, 0, 0]), - SampleBatch.AGENT_INDEX: np.array([0, 0, 0]), - } -) - - -class MyCallbacks(RLlibCallback): - @staticmethod - def _check_lr_torch(policy, policy_id): - for j, opt in enumerate(policy._optimizers): - for p in opt.param_groups: - assert p["lr"] == policy.cur_lr, "LR scheduling error!" 
- - @staticmethod - def _check_lr_tf(policy, policy_id): - lr = policy.cur_lr - sess = policy.get_session() - if sess: - lr = sess.run(lr) - optim_lr = sess.run(policy._optimizer._lr) - else: - lr = lr.numpy() - optim_lr = policy._optimizer.lr.numpy() - assert lr == optim_lr, "LR scheduling error!" - - def on_train_result(self, *, algorithm, result: dict, **kwargs): - stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] - # Learning rate should go to 0 after 1 iter. - check(stats["cur_lr"], 5e-5 if algorithm.iteration == 1 else 0.0) - # Entropy coeff goes to 0.05, then 0.0 (per iter). - check(stats["entropy_coeff"], 0.1 if algorithm.iteration == 1 else 0.05) - - algorithm.env_runner_group.foreach_policy( - self._check_lr_torch - if algorithm.config["framework"] == "torch" - else self._check_lr_tf - ) - - -class TestPPO(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init() - - @classmethod - def tearDownClass(cls): - ray.shutdown() - - def test_ppo_compilation_w_connectors(self): - """Test whether PPO can be built with all frameworks w/ connectors.""" - - # Build a PPOConfig object. - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .training( - num_epochs=2, - # Setup lr schedule for testing. - lr_schedule=[[0, 5e-5], [128, 0.0]], - # Set entropy_coeff to a faulty value to proof that it'll get - # overridden by the schedule below (which is expected). - entropy_coeff=100.0, - entropy_coeff_schedule=[[0, 0.1], [256, 0.0]], - train_batch_size=128, - model=dict( - # Settings in case we use an LSTM. - lstm_cell_size=10, - max_seq_len=20, - ), - ) - .env_runners( - num_env_runners=1, - # Test with compression. - compress_observations=True, - ) - .callbacks(MyCallbacks) - .evaluation( - evaluation_duration=2, - evaluation_duration_unit="episodes", - evaluation_num_env_runners=1, - ) - ) # For checking lr-schedule correctness. - - num_iterations = 2 - - for env in ["FrozenLake-v1", "ale_py:ALE/MsPacman-v5"]: - print("Env={}".format(env)) - for lstm in [False, True]: - print("LSTM={}".format(lstm)) - config.training( - model=dict( - use_lstm=lstm, - lstm_use_prev_action=lstm, - lstm_use_prev_reward=lstm, - ) - ) - - algo = config.build(env=env) - policy = algo.get_policy() - entropy_coeff = algo.get_policy().entropy_coeff - lr = policy.cur_lr - check(entropy_coeff, 0.1) - check(lr, config.lr) - - for i in range(num_iterations): - results = algo.train() - check_train_results(results) - print(results) - - algo.evaluate() - - check_inference_w_connectors(policy, env_name=env) - algo.stop() - - def test_ppo_compilation_and_schedule_mixins(self): - """Test whether PPO can be built with all frameworks.""" - - # Build a PPOConfig object. - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .training( - # Setup lr schedule for testing. - lr_schedule=[[0, 5e-5], [256, 0.0]], - # Set entropy_coeff to a faulty value to proof that it'll get - # overridden by the schedule below (which is expected). - entropy_coeff=100.0, - entropy_coeff_schedule=[[0, 0.1], [512, 0.0]], - train_batch_size=256, - minibatch_size=128, - num_epochs=2, - model=dict( - # Settings in case we use an LSTM. - lstm_cell_size=10, - max_seq_len=20, - ), - ) - .env_runners( - num_env_runners=1, - # Test with compression. - compress_observations=True, - ) - .callbacks(MyCallbacks) - ) # For checking lr-schedule correctness. 
- - num_iterations = 2 - - for env in ["FrozenLake-v1", "ale_py:ALE/MsPacman-v5"]: - print("Env={}".format(env)) - for lstm in [False, True]: - print("LSTM={}".format(lstm)) - config.training( - model=dict( - use_lstm=lstm, - lstm_use_prev_action=lstm, - lstm_use_prev_reward=lstm, - ) - ) - - algo = config.build(env=env) - policy = algo.get_policy() - entropy_coeff = algo.get_policy().entropy_coeff - lr = policy.cur_lr - check(entropy_coeff, 0.1) - check(lr, config.lr) - - for i in range(num_iterations): - results = algo.train() - print(results) - check_train_results(results) - # 2 sgd iters per update, 2 minibatches per trainbatch -> 4x - # avg(0.0, 1.0, 2.0, 3.0) -> 1.5 - off_policy_ness = check_off_policyness( - results, lower_limit=1.5, upper_limit=1.5 - ) - print(f"off-policy'ness={off_policy_ness}") - - check_compute_single_action( - algo, include_prev_action_reward=True, include_state=lstm - ) - algo.stop() - - def test_ppo_exploration_setup(self): - """Tests, whether PPO runs with different exploration setups.""" - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .environment( - "FrozenLake-v1", - env_config={"is_slippery": False, "map_name": "4x4"}, - ) - .env_runners( - # Run locally. - num_env_runners=0, - ) - ) - obs = np.array(0) - - # Default Agent should be setup with StochasticSampling. - algo = config.build() - # explore=False, always expect the same (deterministic) action. - a_ = algo.compute_single_action( - obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0) - ) - - for _ in range(50): - a = algo.compute_single_action( - obs, - explore=False, - prev_action=np.array(2), - prev_reward=np.array(1.0), - ) - check(a, a_) - - # With explore=True (default), expect stochastic actions. - actions = [] - for _ in range(300): - actions.append( - algo.compute_single_action( - obs, prev_action=np.array(2), prev_reward=np.array(1.0) - ) - ) - check(np.mean(actions), 1.5, atol=0.2) - algo.stop() - - def test_ppo_free_log_std(self): - """Tests the free log std option works. - - This test is overfitted to the old ModelV2 stack (e.g. - policy.model.trainable_variables is not callable in the new stack) - # TODO (Kourosh) we should create a new test for the new RLModule stack. - """ - - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .environment("CartPole-v1") - .env_runners( - num_env_runners=0, - ) - .training( - gamma=0.99, - model=dict( - fcnet_hiddens=[10], - fcnet_activation="linear", - free_log_std=True, - vf_share_layers=True, - ), - ) - ) - - algo = config.build() - policy = algo.get_policy() - - # Check the free log std var is created. - matching = [v for (n, v) in policy.model.named_parameters() if "log_std" in n] - assert len(matching) == 1, matching - log_std_var = matching[0] - - # linter yells at you if you don't pass in the parameters. - # reason: https://docs.python-guide.org/writing/gotchas/ - # #late-binding-closures - def get_value(fw="torch", policy=policy, log_std_var=log_std_var): - return log_std_var.detach().cpu().numpy()[0] - - # Check the variable is initially zero. - init_std = get_value() - assert init_std == 0.0, init_std - batch = compute_gae_for_sample_batch(policy, CARTPOLE_FAKE_BATCH.copy()) - batch = policy._lazy_tensor_dict(batch) - policy.learn_on_batch(batch) - - # Check the variable is updated. 
- post_std = get_value() - assert post_std != 0.0, post_std - algo.stop() - - def test_ppo_loss_function(self): - """Tests the PPO loss function math. - - This test is overfitted to the old ModelV2 stack (e.g. - policy.model.trainable_variables is not callable in the new stack) - # TODO (Kourosh) we should create a new test for the new RLModule stack. - """ - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .environment("CartPole-v1") - .env_runners( - num_env_runners=0, - ) - .training( - gamma=0.99, - model=dict( - fcnet_hiddens=[10], - fcnet_activation="linear", - vf_share_layers=True, - ), - ) - ) - - algo = config.build() - policy = algo.get_policy() - - # Check no free log std var by default. - matching = [v for (n, v) in policy.model.named_parameters() if "log_std" in n] - assert len(matching) == 0, matching - - # Post-process (calculate simple (non-GAE) advantages) and attach - # to train_batch dict. - # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] = - # [0.50005, -0.505, 0.5] - train_batch = compute_gae_for_sample_batch(policy, CARTPOLE_FAKE_BATCH.copy()) - train_batch = policy._lazy_tensor_dict(train_batch) - - # Check Advantage values. - check(train_batch[Postprocessing.VALUE_TARGETS], [0.50005, -0.505, 0.5]) - - # Calculate actual PPO loss. - PPOTorchPolicy.loss(policy, policy.model, policy.dist_class, train_batch) - - vars = list(policy.model.parameters()) - expected_shared_out = fc( - train_batch[Columns.OBS], - vars[2], - vars[3], - framework="torch", - ) - expected_logits = fc( - expected_shared_out, - vars[0], - vars[1], - framework="torch", - ) - expected_value_outs = fc( - expected_shared_out, vars[4], vars[5], framework="torch" - ) - - kl, entropy, pg_loss, vf_loss, overall_loss = self._ppo_loss_helper( - policy, - policy.model, - TorchCategorical, - train_batch, - expected_logits, - expected_value_outs, - sess=None, - ) - check(policy.model.tower_stats["mean_kl_loss"], kl) - check(policy.model.tower_stats["mean_entropy"], entropy) - check(policy.model.tower_stats["mean_policy_loss"], np.mean(-pg_loss)) - check( - policy.model.tower_stats["mean_vf_loss"], - np.mean(vf_loss), - decimals=4, - ) - check(policy.model.tower_stats["total_loss"], overall_loss, decimals=4) - algo.stop() - - def _ppo_loss_helper( - self, policy, model, dist_class, train_batch, logits, vf_outs, sess=None - ): - """ - Calculates the expected PPO loss (components) given Policy, - Model, distribution, some batch, logits & vf outputs, using numpy. - """ - # Calculate expected PPO loss results. - dist = dist_class(logits, policy.model) - dist_prev = dist_class(train_batch[Columns.ACTION_DIST_INPUTS], policy.model) - expected_logp = dist.logp(train_batch[Columns.ACTIONS]) - if isinstance(model, TorchModelV2): - train_batch.set_get_interceptor(None) - expected_rho = np.exp( - expected_logp.detach().cpu().numpy() - train_batch[Columns.ACTION_LOGP] - ) - # KL(prev vs current action dist)-loss component. - kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy()) - # Entropy-loss component. - entropy = np.mean(dist.entropy().detach().cpu().numpy()) - else: - if sess: - expected_logp = sess.run(expected_logp) - expected_rho = np.exp(expected_logp - train_batch[Columns.ACTION_LOGP]) - # KL(prev vs current action dist)-loss component. - kl = dist_prev.kl(dist) - if sess: - kl = sess.run(kl) - kl = np.mean(kl) - # Entropy-loss component. 
- entropy = dist.entropy() - if sess: - entropy = sess.run(entropy) - entropy = np.mean(entropy) - - # Policy loss component. - pg_loss = np.minimum( - train_batch[Postprocessing.ADVANTAGES] * expected_rho, - train_batch[Postprocessing.ADVANTAGES] - * np.clip( - expected_rho, - 1 - policy.config["clip_param"], - 1 + policy.config["clip_param"], - ), - ) - - # Value function loss component. - vf_loss1 = np.power(vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0) - vf_clipped = train_batch[Columns.VF_PREDS] + np.clip( - vf_outs - train_batch[Columns.VF_PREDS], - -policy.config["vf_clip_param"], - policy.config["vf_clip_param"], - ) - vf_loss2 = np.power(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0) - vf_loss = np.maximum(vf_loss1, vf_loss2) - - # Overall loss. - if sess: - policy_sess = policy.get_session() - kl_coeff, entropy_coeff = policy_sess.run( - [policy.kl_coeff, policy.entropy_coeff] - ) - else: - kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff - overall_loss = np.mean( - -pg_loss - + kl_coeff * kl - + policy.config["vf_loss_coeff"] * vf_loss - - entropy_coeff * entropy - ) - return kl, entropy, pg_loss, vf_loss, overall_loss - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/sac/README.md b/rllib/algorithms/sac/README.md index 2d60f94548475..c0a01e485e4ef 100644 --- a/rllib/algorithms/sac/README.md +++ b/rllib/algorithms/sac/README.md @@ -17,6 +17,3 @@ coeffcient. **[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#sac)** **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac.py)** - - - diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index 53c5749f7966e..51e6dbbd61cd3 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -130,9 +130,7 @@ def test_sac_dict_obs_order(self): # Dict space .sample() returns an ordered dict. # Make sure the keys in samples are ordered differently. - dict_samples = [ - {k: v for k, v in reversed(dict_space.sample().items())} for _ in range(10) - ] + dict_samples = [dict(reversed(dict_space.sample().items())) for _ in range(10)] class NestedDictEnv(gym.Env): def __init__(self): diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 53c643bc4ae8d..1b47c02eba620 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -316,7 +316,7 @@ def new_mapping_fn(agent_id, episode, worker, i=i, **kwargs): # Test restoring from the checkpoint (which has more policies # than what's defined in the config dict). - test = ppo.PPO.from_checkpoint(checkpoint=checkpoint) + test = ppo.PPO.from_checkpoint(checkpoint) # Make sure evaluation worker also got the restored, added policy. def _has_policies(w, pid=pid): diff --git a/rllib/algorithms/utils.py b/rllib/algorithms/utils.py index b99cc51d144f6..6f0a5bdf87d9b 100644 --- a/rllib/algorithms/utils.py +++ b/rllib/algorithms/utils.py @@ -48,11 +48,13 @@ def __init__(self, config: AlgorithmConfig, rl_module_spec): # Set device and node. self._node = platform.node() - self._device = torch.device( - f"cuda:{ray.get_gpu_ids()[0]}" - if self.config.num_gpus_per_learner > 0 - else "cpu" - ) + self._device = torch.device("cpu") + # TODO (sven): Activate this when Ray has figured out GPU pre-loading. 
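+        # Until then, the aggregator builds its batches on the CPU; moving them
+        # to the GPU happens later, in the Learner-side `_GPULoaderThread` (see
+        # `impala_learner.py` in this change).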
+        # self._device = torch.device(
+        #     f"cuda:{ray.get_gpu_ids()[0]}"
+        #     if self.config.num_gpus_per_learner > 0
+        #     else "cpu"
+        # )
         self.metrics = MetricsLogger()
 
         # Create the RLModule.
@@ -89,7 +91,7 @@ def get_batch(self, episode_refs: List[ray.ObjectRef]):
 
         # If we have enough episodes collected to create a single train batch, pass
         # them at once through the connector to receive a single train batch.
-        batch_on_gpu = self._learner_connector(
+        batch = self._learner_connector(
             episodes=episodes,
             rl_module=self._module,
             metrics=self.metrics,
@@ -97,13 +99,13 @@
         # Convert the dict into a `MultiAgentBatch`.
         # TODO (sven): Try to get rid of dependency on MultiAgentBatch (once our mini-
         # batch iterators support splitting over a dict).
-        ma_batch_on_gpu = MultiAgentBatch(
+        ma_batch = MultiAgentBatch(
             policy_batches={
-                pid: SampleBatch(batch) for pid, batch in batch_on_gpu.items()
+                pid: SampleBatch(pol_batch) for pid, pol_batch in batch.items()
             },
             env_steps=env_steps,
         )
-        return ma_batch_on_gpu
+        return ma_batch
 
     def get_metrics(self):
         return self.metrics.reduce()
diff --git a/rllib/callbacks/callbacks.py b/rllib/callbacks/callbacks.py
index 06c24d4b42c4e..ad8eecaabdadf 100644
--- a/rllib/callbacks/callbacks.py
+++ b/rllib/callbacks/callbacks.py
@@ -257,39 +257,28 @@ def on_episode_created(
     ) -> None:
         """Callback run when a new episode is created (but has not started yet!).
 
-        This method gets called after a new Episode(V2) (old stack) or
-        MultiAgentEpisode instance has been created.
-        This happens before the respective sub-environment's (usually a gym.Env)
+        This method gets called after a new SingleAgentEpisode or MultiAgentEpisode
+        instance has been created. This happens before the respective sub-environment's
         `reset()` is called by RLlib.
 
-        Note, at the moment this callback does not get called in the new API stack
-        and single-agent mode.
-
-        1) Episode(V2)/MultiAgentEpisode created: This callback is called.
+        1) SingleAgentEpisode/MultiAgentEpisode created: This callback is called.
         2) Respective sub-environment (gym.Env) is `reset()`.
         3) Callback `on_episode_start` is called.
         4) Stepping through sub-environment/episode commences.
 
         Args:
-            episode: The newly created episode. On the new API stack, this will be a
-                MultiAgentEpisode object. On the old API stack, this will be a
-                Episode or EpisodeV2 object.
+            episode: The newly created SingleAgentEpisode or MultiAgentEpisode.
                 This is the episode that is about to be started with an upcoming
                 `env.reset()`. Only after this reset call, the `on_episode_start`
                 callback will be called.
-            env_runner: Replaces `worker` arg. Reference to the current EnvRunner.
+            env_runner: Reference to the current EnvRunner.
             metrics_logger: The MetricsLogger object inside the `env_runner`. Can be
                 used to log custom metrics after Episode creation.
-            env: Replaces `base_env` arg. The gym.Env (new API stack) or RLlib
-                BaseEnv (old API stack) running the episode. On the old stack, the
-                underlying sub environment objects can be retrieved by calling
-                `base_env.get_sub_environments()`.
-            rl_module: Replaces `policies` arg. Either the RLModule (new API stack) or a
-                dict mapping policy IDs to policy objects (old stack). In single agent
-                mode there will only be a single policy/RLModule under the
-                `rl_module["default_policy"]` key.
-            env_index: The index of the sub-environment that is about to be reset
-                (within the vector of sub-environments of the BaseEnv).
+            env: The gym.Env running the episode.
+            rl_module: The RLModule used to compute actions for stepping the env. In
+                single-agent mode, this is a simple RLModule; in multi-agent mode, this
+                is a MultiRLModule.
+            env_index: The index of the sub-environment that is about to be reset.
             kwargs: Forward compatibility placeholder.
         """
         pass
@@ -329,9 +318,9 @@ def on_episode_start(
             env: The gym.Env or gym.vector.Env object running the started episode.
             env_index: The index of the sub-environment that is about to be reset
                 (within the vector of sub-environments of the BaseEnv).
-            rl_module: The RLModule used to compute actions for stepping the env.
-                In a single-agent setup, this is a (single-agent) RLModule, in a multi-
-                agent setup, this will be a MultiRLModule.
+            rl_module: The RLModule used to compute actions for stepping the env. In
+                single-agent mode, this is a simple RLModule; in multi-agent mode, this
+                is a MultiRLModule.
             kwargs: Forward compatibility placeholder.
         """
         pass
@@ -372,9 +361,9 @@ def on_episode_step(
                 used to log custom metrics during env/episode stepping.
             env: The gym.Env or gym.vector.Env object running the started episode.
             env_index: The index of the sub-environment that has just been stepped.
-            rl_module: The RLModule used to compute actions for stepping the env.
-                In a single-agent setup, this is a (single-agent) RLModule, in a multi-
-                agent setup, this will be a MultiRLModule.
+            rl_module: The RLModule used to compute actions for stepping the env. In
+                single-agent mode, this is a simple RLModule; in multi-agent mode, this
+                is a MultiRLModule.
             kwargs: Forward compatibility placeholder.
         """
         pass
@@ -432,9 +421,9 @@ def on_episode_end(
             env: The gym.Env or gym.vector.Env object running the started episode.
             env_index: The index of the sub-environment that has just been terminated
                 or truncated.
-            rl_module: The RLModule used to compute actions for stepping the env.
-                In a single-agent setup, this is a (single-agent) RLModule, in a multi-
-                agent setup, this will be a MultiRLModule.
+            rl_module: The RLModule used to compute actions for stepping the env. In
+                single-agent mode, this is a simple RLModule; in multi-agent mode, this
+                is a MultiRLModule.
             kwargs: Forward compatibility placeholder.
         """
         pass
@@ -456,8 +445,9 @@ def on_sample_end(
             env_runner: Reference to the current EnvRunner object.
             metrics_logger: The MetricsLogger object inside the `env_runner`. Can be
                 used to log custom metrics during env/episode stepping.
-            samples: Batch to be returned. You can mutate this
-                object to modify the samples generated.
+            samples: Lists of SingleAgentEpisode or MultiAgentEpisode instances to be
+                returned. You can mutate the episodes to modify the returned training
+                data.
             kwargs: Forward compatibility placeholder.
         """
         pass
@@ -480,13 +470,13 @@ def on_sub_environment_created(
         `Algorithm.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
 
         Args:
-            worker: Reference to the current rollout worker.
+            worker: Reference to the current EnvRunner.
             sub_environment: The sub-environment instance that has been created.
                 This is usually a gym.Env object.
             env_context: The `EnvContext` object that has been passed to the env's
                 constructor.
             env_index: The index of the sub-environment that has been created
-                (within the vector of sub-environments of the BaseEnv).
+                (within the vector of sub-environments of the gym.vector.Env).
             kwargs: Forward compatibility placeholder.
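+
+        Example (a minimal, hypothetical sketch of hooking into this event; the
+        callback class name is made up):
+
+        .. code-block:: python
+
+            from ray.rllib.callbacks.callbacks import RLlibCallback
+
+            class PrintSubEnvs(RLlibCallback):
+                def on_sub_environment_created(
+                    self, *, worker, sub_environment, env_context,
+                    env_index=None, **kwargs,
+                ):
+                    # Print every sub-environment created on this EnvRunner.
+                    print(f"created sub-env #{env_index}: {sub_environment}")
+
+            # Register the callback class on an AlgorithmConfig, e.g.:
+            # config.callbacks(PrintSubEnvs)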
""" pass diff --git a/rllib/connectors/common/__init__.py b/rllib/connectors/common/__init__.py index cf28ba9ae9fbc..6ef51fb150732 100644 --- a/rllib/connectors/common/__init__.py +++ b/rllib/connectors/common/__init__.py @@ -10,6 +10,7 @@ from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy __all__ = [ @@ -19,4 +20,5 @@ "AgentToModuleMapping", "BatchIndividualItems", "NumpyToTensor", + "TensorToNumpy", ] diff --git a/rllib/connectors/tests/test_action.py b/rllib/connectors/tests/test_action.py deleted file mode 100644 index ff6b8c47f513f..0000000000000 --- a/rllib/connectors/tests/test_action.py +++ /dev/null @@ -1,113 +0,0 @@ -# @OldAPIStack - -import unittest - -import gymnasium as gym -import numpy as np - -from ray.rllib.connectors.action.clip import ClipActionsConnector -from ray.rllib.connectors.action.immutable import ImmutableActionsConnector -from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector -from ray.rllib.connectors.action.normalize import NormalizeActionsConnector -from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline -from ray.rllib.connectors.connector import ConnectorContext -from ray.rllib.connectors.registry import get_connector -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import ActionConnectorDataType - -torch, _ = try_import_torch() - - -class TestActionConnector(unittest.TestCase): - def test_connector_pipeline(self): - ctx = ConnectorContext() - connectors = [ConvertToNumpyConnector(ctx)] - pipeline = ActionConnectorPipeline(ctx, connectors) - name, params = pipeline.to_state() - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ActionConnectorPipeline)) - self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector)) - # There should not be any timer yet - self.assertFalse(bool(pipeline.timers.values())) - pipeline(ActionConnectorDataType(0, 0, {}, ([1], [], None))) - # After a first input, there should be one timer - self.assertEqual(len(pipeline.timers.values()), 1) - - def test_convert_to_numpy_connector(self): - ctx = ConnectorContext() - c = ConvertToNumpyConnector(ctx) - - name, params = c.to_state() - - self.assertEqual(name, "ConvertToNumpyConnector") - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ConvertToNumpyConnector)) - - action = torch.Tensor([8, 9]) - states = torch.Tensor([[1, 1, 1], [2, 2, 2]]) - ac_data = ActionConnectorDataType(0, 1, {}, (action, states, {})) - - converted = c(ac_data) - self.assertTrue(isinstance(converted.output[0], np.ndarray)) - self.assertTrue(isinstance(converted.output[1], np.ndarray)) - - def test_normalize_action_connector(self): - ctx = ConnectorContext( - action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]) - ) - c = NormalizeActionsConnector(ctx) - - name, params = c.to_state() - self.assertEqual(name, "NormalizeActionsConnector") - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, NormalizeActionsConnector)) - - ac_data = ActionConnectorDataType(0, 1, {}, (0.5, [], {})) - - normalized = c(ac_data) - self.assertEqual(normalized.output[0], 4.5) - - def test_clip_action_connector(self): - ctx = ConnectorContext( - action_space=gym.spaces.Box(low=0.0, high=6.0, 
shape=[1]) - ) - c = ClipActionsConnector(ctx) - - name, params = c.to_state() - self.assertEqual(name, "ClipActionsConnector") - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ClipActionsConnector)) - - ac_data = ActionConnectorDataType(0, 1, {}, (8.8, [], {})) - - clipped = c(ac_data) - self.assertEqual(clipped.output[0], 6.0) - - def test_immutable_action_connector(self): - ctx = ConnectorContext( - action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]) - ) - c = ImmutableActionsConnector(ctx) - - name, params = c.to_state() - self.assertEqual(name, "ImmutableActionsConnector") - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ImmutableActionsConnector)) - - ac_data = ActionConnectorDataType(0, 1, {}, (np.array([8.8]), [], {})) - - immutable = c(ac_data) - - with self.assertRaises(ValueError): - immutable.output[0][0] = 5 - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/tests/test_agent.py b/rllib/connectors/tests/test_agent.py deleted file mode 100644 index cc1acab22588d..0000000000000 --- a/rllib/connectors/tests/test_agent.py +++ /dev/null @@ -1,646 +0,0 @@ -# @OldAPIStack - -import gymnasium as gym -from gymnasium.spaces import Box -import numpy as np -import tree # pip install dm_tree -import unittest - -from ray.rllib.algorithms.ppo.ppo import PPOConfig -from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector -from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector -from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector -from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline -from ray.rllib.connectors.agent.state_buffer import StateBufferConnector -from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector -from ray.rllib.connectors.connector import ConnectorContext -from ray.rllib.connectors.registry import get_connector -from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.test_utils import check -from ray.rllib.utils.typing import ( - ActionConnectorDataType, - AgentConnectorDataType, - AgentConnectorsOutput, -) -from ray.rllib.connectors.agent.mean_std_filter import ( - MeanStdObservationFilterAgentConnector, -) - - -class TestAgentConnector(unittest.TestCase): - def test_connector_pipeline(self): - ctx = ConnectorContext() - connectors = [ClipRewardAgentConnector(ctx, False, 1.0)] - pipeline = AgentConnectorPipeline(ctx, connectors) - name, params = pipeline.to_state() - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, AgentConnectorPipeline)) - self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector)) - - def test_obs_preprocessor_connector(self): - obs_space = gym.spaces.Dict( - { - "a": gym.spaces.Box(low=0, high=1, shape=(1,)), - "b": gym.spaces.Tuple( - [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])] - ), - } - ) - ctx = ConnectorContext(config={}, observation_space=obs_space) - - c = ObsPreprocessorConnector(ctx) - name, params = c.to_state() - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ObsPreprocessorConnector)) - - obs = obs_space.sample() - # Fake deterministic data. 
- obs["a"][0] = 0.5 - obs["b"] = (1, np.array([0, 2])) - - d = AgentConnectorDataType( - 0, - 1, - { - SampleBatch.OBS: obs, - }, - ) - preprocessed = c([d]) - - # obs is completely flattened. - self.assertTrue( - (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all() - ) - - def test_clip_reward_connector(self): - ctx = ConnectorContext() - - c = ClipRewardAgentConnector(ctx, limit=2.0) - name, params = c.to_state() - - self.assertEqual(name, "ClipRewardAgentConnector") - self.assertAlmostEqual(params["limit"], 2.0) - - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, ClipRewardAgentConnector)) - - d = AgentConnectorDataType( - 0, - 1, - { - SampleBatch.REWARDS: 5.8, - }, - ) - clipped = restored([d]) - - self.assertEqual(len(clipped), 1) - self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0) - - def test_flatten_data_connector(self): - ctx = ConnectorContext() - - c = FlattenDataAgentConnector(ctx) - - name, params = c.to_state() - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, FlattenDataAgentConnector)) - - sample_batch = { - SampleBatch.NEXT_OBS: { - "sensor1": [[1, 1], [2, 2]], - "sensor2": 8.8, - }, - SampleBatch.REWARDS: 5.8, - SampleBatch.ACTIONS: [[1, 1], [2, 2]], - SampleBatch.INFOS: {"random": "info"}, - } - - d = AgentConnectorDataType( - 0, - 1, - # FlattenDataAgentConnector does NOT touch raw_dict, - # so simply pass None here. - AgentConnectorsOutput(None, sample_batch), - ) - - flattened = c([d]) - self.assertEqual(len(flattened), 1) - - batch = flattened[0].data.sample_batch - self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all()) - self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8) - # Not flattened. - self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2) - self.assertEqual(batch[SampleBatch.INFOS]["random"], "info") - - def test_state_buffer_connector(self): - ctx = ConnectorContext( - action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)), - ) - c = StateBufferConnector(ctx) - - # Reset without any buffered data should do nothing. 
- c.reset(env_id=0) - - d = AgentConnectorDataType( - 0, - 1, - { - SampleBatch.NEXT_OBS: { - "sensor1": [[1, 1], [2, 2]], - "sensor2": 8.8, - }, - }, - ) - - with_buffered = c([d]) - self.assertEqual(len(with_buffered), 1) - self.assertTrue((with_buffered[0].data[SampleBatch.ACTIONS] == [0, 0, 0]).all()) - - c.on_policy_output(ActionConnectorDataType(0, 1, {}, ([1, 2, 3], [], {}))) - - with_buffered = c([d]) - self.assertEqual(len(with_buffered), 1) - self.assertEqual(with_buffered[0].data[SampleBatch.ACTIONS], [1, 2, 3]) - - def test_mean_std_observation_filter_connector(self): - for bounds in [ - (-1, 1), # normalized - (-2, 2), # scaled - (0, 2), # shifted - (0, 4), # scaled and shifted - ]: - print("Testing uniform sampling with bounds: {}".format(bounds)) - - observation_space = Box(bounds[0], bounds[1], (3, 64, 64)) - ctx = ConnectorContext(observation_space=observation_space) - filter_connector = MeanStdObservationFilterAgentConnector(ctx) - - # Warm up Mean-Std filter - for i in range(1000): - obs = observation_space.sample() - sample_batch = { - SampleBatch.NEXT_OBS: obs, - } - ac = AgentConnectorDataType(0, 0, sample_batch) - filter_connector.transform(ac) - - # Create another connector to set state to - _, state = filter_connector.to_state() - another_filter_connector = ( - MeanStdObservationFilterAgentConnector.from_state(ctx, state) - ) - - another_filter_connector.in_eval() - - # Collect transformed observations - transformed_observations = [] - for i in range(1000): - obs = observation_space.sample() - sample_batch = { - SampleBatch.NEXT_OBS: obs, - } - ac = AgentConnectorDataType(0, 0, sample_batch) - connector_output = another_filter_connector.transform(ac) - transformed_observations.append( - connector_output.data[SampleBatch.NEXT_OBS] - ) - - # Check if transformed observations are actually mean-std filtered - self.assertTrue(np.isclose(np.mean(transformed_observations), 0, atol=0.1)) - self.assertTrue(np.isclose(np.var(transformed_observations), 1, atol=0.1)) - - # Check if filter parameters where frozen because we are not training - self.assertTrue( - filter_connector.filter.running_stats.num_pushes - == another_filter_connector.filter.running_stats.num_pushes, - ) - self.assertTrue( - np.all( - filter_connector.filter.running_stats.mean_array - == another_filter_connector.filter.running_stats.mean_array, - ) - ) - self.assertTrue( - np.all( - filter_connector.filter.running_stats.std_array - == another_filter_connector.filter.running_stats.std_array, - ) - ) - self.assertTrue( - filter_connector.filter.buffer.num_pushes - == another_filter_connector.filter.buffer.num_pushes, - ) - self.assertTrue( - np.all( - filter_connector.filter.buffer.mean_array - == another_filter_connector.filter.buffer.mean_array, - ) - ) - self.assertTrue( - np.all( - filter_connector.filter.buffer.std_array - == another_filter_connector.filter.buffer.std_array, - ) - ) - - -class TestViewRequirementAgentConnector(unittest.TestCase): - def test_vr_connector_respects_training_or_inference_vr_flags(self): - """Tests that the connector respects the flags within view_requirements (i.e. - used_for_training, used_for_compute_actions). - - the returned data is the input dict itself, which the policy collector in - env_runner will use to construct the episode, and a SampleBatch that can be - used to run corresponding policy. 
- """ - view_rq_dict = { - "both": ViewRequirement( - "obs", used_for_training=True, used_for_compute_actions=True - ), - "only_inference": ViewRequirement( - "obs", used_for_training=False, used_for_compute_actions=True - ), - "none": ViewRequirement( - "obs", used_for_training=False, used_for_compute_actions=False - ), - "only_training": ViewRequirement( - "obs", used_for_training=True, used_for_compute_actions=False - ), - } - - obs_arr = np.array([0, 1, 2, 3]) - agent_data = {SampleBatch.NEXT_OBS: obs_arr} - data = AgentConnectorDataType(0, 1, agent_data) - - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rq_dict, - config=config, - is_policy_recurrent=True, - ) - - sample_batch_expected = SampleBatch( - { - "both": obs_arr[None], - # Output in training model as well. - "only_inference": obs_arr[None], - "seq_lens": np.array([1]), - } - ) - - c = ViewRequirementAgentConnector(ctx) - c.in_training() - processed = c([data]) - - raw_dict = processed[0].data.raw_dict - sample_batch = processed[0].data.sample_batch - - check(raw_dict, agent_data) - check(sample_batch, sample_batch_expected) - - def test_vr_connector_shift_by_one(self): - view_rq_dict = { - "state": ViewRequirement("obs"), - "next_state": ViewRequirement( - "obs", shift=1, used_for_compute_actions=False - ), - "prev_state": ViewRequirement("obs", shift=-1), - } - - obs_arrs = np.arange(10)[:, None] + 1 - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rq_dict, config=config, is_policy_recurrent=True - ) - c = ViewRequirementAgentConnector(ctx) - - # keep a running list of observations - obs_list = [] - for t, obs in enumerate(obs_arrs): - # t=0 is the next state of t=-1 - data = AgentConnectorDataType(0, 1, {SampleBatch.NEXT_OBS: obs}) - processed = c([data]) # env.reset() for t == -1 else env.step() - sample_batch = processed[0].data.sample_batch - # add cur obs to the list - obs_list.append(obs) - - if t == 0: - check(sample_batch["prev_state"], sample_batch["state"]) - else: - # prev state should be equal to the prev time step obs - check(sample_batch["prev_state"], obs_list[-2][None]) - - def test_vr_connector_causal_slice(self): - """Test that the ViewRequirementAgentConnector can handle slice shifts.""" - view_rq_dict = { - "state": ViewRequirement("obs"), - # shift array should be [-2, -1, 0] - "prev_states": ViewRequirement("obs", shift="-2:0"), - # shift array should be [-4, -2, 0] - "prev_strided_states_even": ViewRequirement("obs", shift="-4:0:2"), - # shift array should be [-3, -1] - "prev_strided_states_odd": ViewRequirement("obs", shift="-3:0:2"), - } - - obs_arrs = np.arange(10)[:, None] + 1 - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rq_dict, config=config, is_policy_recurrent=True - ) - c = ViewRequirementAgentConnector(ctx) - - # keep a queue of observations - obs_list = [] - for t, obs in enumerate(obs_arrs): - # t=0 is the next state of t=-1 - data = AgentConnectorDataType(0, 1, {SampleBatch.NEXT_OBS: obs}) - processed = c([data]) - sample_batch = processed[0].data.sample_batch - - if t == 0: - obs_list.extend([obs for _ in range(5)]) - else: - # remove the first obs and add the current obs to the end - obs_list.pop(0) - obs_list.append(obs) - - # check state - check(sample_batch["state"], obs[None]) - - # check prev_states - check( - sample_batch["prev_states"], - 
np.stack(obs_list)[np.array([-3, -2, -1])][None], - ) - - # check prev_strided_states_even - check( - sample_batch["prev_strided_states_even"], - np.stack(obs_list)[np.array([-5, -3, -1])][None], - ) - - check( - sample_batch["prev_strided_states_odd"], - np.stack(obs_list)[np.array([-4, -2])][None], - ) - - def test_vr_connector_with_multiple_buffers(self): - """Test that the ViewRequirementAgentConnector can handle slice shifts correctly - when it has multiple buffers to shift.""" - context_len = 5 - # This view requirement simulates the use-case of a decision transformer - # without reward-to-go. - view_rq_dict = { - # obs[t-context_len+1:t] - "context_obs": ViewRequirement( - "obs", - shift=f"-{context_len-1}:0", - space=Box(-np.inf, np.inf, shape=(1,), dtype=np.float64), - ), - # next_obs[t-context_len+1:t] - "context_next_obs": ViewRequirement( - "obs", - shift=f"-{context_len}:1", - used_for_compute_actions=False, - space=Box(-np.inf, np.inf, shape=(1,), dtype=np.float64), - ), - # act[t-context_len+1:t] - "context_act": ViewRequirement( - SampleBatch.ACTIONS, - shift=f"-{context_len-1}:-1", - space=Box(-np.inf, np.inf, shape=(1,)), - ), - } - - obs_arrs = np.arange(10)[:, None] + 1 - act_arrs = (np.arange(10)[:, None] + 1) * 100 - n_steps = obs_arrs.shape[0] - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rq_dict, config=config, is_policy_recurrent=True - ) - c = ViewRequirementAgentConnector(ctx) - - # keep a queue of length ctx_len of observations - obs_list, act_list = [], [] - for t in range(n_steps): - # next state and action at time t-1 are the following - timestep_data = { - SampleBatch.NEXT_OBS: obs_arrs[t], - } - if t > 0: - timestep_data[SampleBatch.ACTIONS] = act_arrs[t - 1] - data = AgentConnectorDataType(0, 1, timestep_data) - processed = c([data]) - sample_batch = processed[0].data.sample_batch - - if t == 0: - obs_list.extend([obs_arrs[0] for _ in range(context_len)]) - act_list.extend( - [np.zeros_like(act_arrs[0]) for _ in range(context_len)] - ) - else: - obs_list.pop(0) - act_list.pop(0) - obs_list.append(obs_arrs[t]) - act_list.append(act_arrs[t - 1]) - - self.assertTrue("context_next_obs" not in sample_batch) - # We should have the 5 (context_len) most recent observations here - check(sample_batch["context_obs"], np.stack(obs_list)[None]) - # The context for actions is [t-context_len+1:t]. Since we build sample - # batch for inference in ViewRequirementAgentConnector, it always - # includes everything up until the last action (at t-1), but not the - # action current action (at t). 
- check(sample_batch["context_act"], np.stack(act_list[1:])[None]) - - def test_connector_pipline_with_view_requirement(self): - """A very minimal test that checks wheter pipeline connectors work in a - simulation rollout.""" - config = ( - PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .framework("torch") - .environment(env="CartPole-v1") - .env_runners(create_env_on_local_worker=True) - ) - - env = gym.make("CartPole-v1") - policy = config.build().get_policy() - - REQUIRED_KEYS = { - SampleBatch.OBS, - SampleBatch.NEXT_OBS, - SampleBatch.REWARDS, - SampleBatch.TERMINATEDS, - SampleBatch.TRUNCATEDS, - SampleBatch.INFOS, - SampleBatch.ACTIONS, - } - policy.view_requirements = { - k: v for k, v in policy.view_requirements.items() if k in REQUIRED_KEYS - } - - # create a connector context - ctx = ConnectorContext( - view_requirements=policy.view_requirements, - config=policy.config, - initial_states=policy.get_initial_state(), - is_policy_recurrent=policy.is_recurrent(), - observation_space=policy.observation_space, - action_space=policy.action_space, - ) - - # build chain of connectors - connectors = [ - ObsPreprocessorConnector(ctx), - StateBufferConnector(ctx), - ViewRequirementAgentConnector(ctx), - ] - agent_connector = AgentConnectorPipeline(ctx, connectors) - - name, params = agent_connector.to_state() - restored = get_connector(name, ctx, params) - self.assertTrue(isinstance(restored, AgentConnectorPipeline)) - for cidx, c in enumerate(connectors): - check(restored.connectors[cidx].to_state(), c.to_state()) - - # simulate a rollout - n_steps = 10 - obs, info = env.reset() - env_out = AgentConnectorDataType( - 0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: -1} - ) - agent_obs = agent_connector([env_out])[0] - t = 0 - total_rewards = 0 - while t < n_steps: - policy_output = policy.compute_actions_from_input_dict( - agent_obs.data.sample_batch - ) - # Removes batch dimension - policy_output = tree.map_structure(lambda x: x[0], policy_output) - - agent_connector.on_policy_output( - ActionConnectorDataType(0, 1, {}, policy_output) - ) - action = policy_output[0] - - next_obs, rewards, terminateds, truncateds, info = env.step(action) - env_out_dict = { - SampleBatch.NEXT_OBS: next_obs, - SampleBatch.REWARDS: rewards, - SampleBatch.TERMINATEDS: terminateds, - SampleBatch.TRUNCATEDS: truncateds, - SampleBatch.INFOS: info, - SampleBatch.ACTIONS: action, - # state_out - } - env_out = AgentConnectorDataType(0, 1, env_out_dict) - agent_obs = agent_connector([env_out])[0] - total_rewards += rewards - t += 1 - print(total_rewards) - - def test_vr_connector_only_keeps_useful_timesteps(self): - """Tests that the connector respects the flags within view_requirements (i.e. - used_for_training, used_for_compute_actions). - - the returned data is the input dict itself, which the policy collector in - env_runner will use to construct the episode, and a SampleBatch that can be - used to run corresponding policy. 
- """ - view_rqs = { - "obs": ViewRequirement( - None, used_for_training=True, used_for_compute_actions=True - ), - } - - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rqs, - config=config, - is_policy_recurrent=False, - ) - - c = ViewRequirementAgentConnector(ctx) - c.in_training() - - for i in range(5): - obs_arr = np.array([0, 1, 2, 3]) + i - agent_data = {SampleBatch.NEXT_OBS: obs_arr} - data = AgentConnectorDataType(0, 1, agent_data) - - # Feed ViewRequirementAgentConnector 5 samples. - c([data]) - - obs_data = c.agent_collectors[0][1].buffers["obs"][0] - # Only keep data for the last timestep. - self.assertEqual(len(obs_data), 1) - # Data matches the latest timestep. - self.assertTrue(np.array_equal(obs_data[0], np.array([4, 5, 6, 7]))) - - def test_vr_connector_default_agent_collector_is_empty(self): - """Tests that after reset() the view_requirement connector will - create a fresh new agent collector. - """ - view_rqs = { - "obs": ViewRequirement( - None, used_for_training=True, used_for_compute_actions=True - ), - } - - config = PPOConfig().to_dict() - config["_enable_new_api_stack"] = False - ctx = ConnectorContext( - view_requirements=view_rqs, - config=config, - is_policy_recurrent=False, - ) - - c = ViewRequirementAgentConnector(ctx) - c.in_training() - - for i in range(5): - obs_arr = np.array([0, 1, 2, 3]) + i - agent_data = {SampleBatch.NEXT_OBS: obs_arr} - data = AgentConnectorDataType(0, 1, agent_data) - - # Feed ViewRequirementAgentConnector 5 samples. - c([data]) - - # 1 init_obs, plus 4 agent steps. - self.assertEqual(c.agent_collectors[0][1].agent_steps, 4) - - # Reset. - c.reset(0) # env_id = 0 - - # Process a new timestep. - obs_arr = np.array([0, 1, 2, 3]) + i - agent_data = {SampleBatch.NEXT_OBS: obs_arr} - data = AgentConnectorDataType(0, 1, agent_data) - - # Feed ViewRequirementAgentConnector 5 samples. - c([data]) - - # Start fresh with 0 agent step. - self.assertEqual(c.agent_collectors[0][1].agent_steps, 0) - - -if __name__ == "__main__": - import sys - - import pytest - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/tests/test_connector.py b/rllib/connectors/tests/test_connector.py deleted file mode 100644 index 2d1e5a18855c4..0000000000000 --- a/rllib/connectors/tests/test_connector.py +++ /dev/null @@ -1,100 +0,0 @@ -# @OldAPIStack - -import unittest - -import gymnasium as gym - -from ray.rllib.connectors.connector import Connector, ConnectorPipeline -from ray.rllib.connectors.connector import ConnectorContext -from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector -from ray.rllib.connectors.agent.mean_std_filter import ( - MeanStdObservationFilterAgentConnector, -) -from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector -from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector - - -class TestConnectorPipeline(unittest.TestCase): - class Tom(Connector): - def to_state(): - return "tom" - - class Bob(Connector): - def to_state(): - return "bob" - - class Mary(Connector): - def to_state(): - return "mary" - - class MockConnectorPipeline(ConnectorPipeline): - def __init__(self, ctx, connectors): - # Real connector pipelines should keep a list of - # Connectors. - # Use strings here for simple unit tests. 
- self.connectors = connectors - - def test_sanity_check(self): - ctx = {} - - m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) - m.insert_before("Bob", self.Mary(ctx)) - self.assertEqual(len(m.connectors), 3) - self.assertEqual(m.connectors[1].__class__.__name__, "Mary") - - m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) - m.insert_after("Tom", self.Mary(ctx)) - self.assertEqual(len(m.connectors), 3) - self.assertEqual(m.connectors[1].__class__.__name__, "Mary") - - m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) - m.prepend(self.Mary(ctx)) - self.assertEqual(len(m.connectors), 3) - self.assertEqual(m.connectors[0].__class__.__name__, "Mary") - - m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) - m.append(self.Mary(ctx)) - self.assertEqual(len(m.connectors), 3) - self.assertEqual(m.connectors[2].__class__.__name__, "Mary") - - m.remove("Bob") - self.assertEqual(len(m.connectors), 2) - self.assertEqual(m.connectors[0].__class__.__name__, "Tom") - self.assertEqual(m.connectors[1].__class__.__name__, "Mary") - - m.remove("Bob") - # Bob does not exist anymore, still 2. - self.assertEqual(len(m.connectors), 2) - self.assertEqual(m.connectors[0].__class__.__name__, "Tom") - self.assertEqual(m.connectors[1].__class__.__name__, "Mary") - - self.assertEqual(m["Tom"], [m.connectors[0]]) - self.assertEqual(m[0], [m.connectors[0]]) - self.assertEqual(m[m.connectors[1].__class__], [m.connectors[1]]) - - def test_pipeline_indexing(self): - """Tests if ConnectorPipeline.__getitem__() works as intended.""" - context = ConnectorContext({}, observation_space=gym.spaces.Box(-1, 1, (1,))) - some_connector = MeanStdObservationFilterAgentConnector(context) - some_other_connector = ObsPreprocessorConnector(context) - # Create a dummy pipeline just for indexing purposes - pipeline = ConnectorPipeline(context, [some_connector, some_other_connector]) - - for key, expected_value in [ - (MeanStdObservationFilterAgentConnector, [some_connector]), - ("MeanStdObservationFilterAgentConnector", [some_connector]), - (SyncedFilterAgentConnector, [some_connector]), - ("SyncedFilterAgentConnector", []), - (ClipRewardAgentConnector, []), - ("can i get something?", []), - (0, [some_connector]), - (1, [some_other_connector]), - ]: - self.assertEqual(pipeline[key], expected_value) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 497a71fa29234..2f96d102df637 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -55,6 +55,8 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.metrics import ( ALL_MODULES, + DATASET_NUM_ITERS_TRAINED, + DATASET_NUM_ITERS_TRAINED_LIFETIME, NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED_LIFETIME, @@ -895,7 +897,7 @@ def compute_losses( use the `forward_train()` outputs of the RLModule(s) to compute the required loss tensors. 
See here for a custom loss function example script: - https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa + https://github.com/ray-project/ray/blob/master/rllib/examples/learners/ppo_with_custom_loss_fn.py # noqa Args: fwd_out: Output from a call to the `forward_train()` method of the @@ -1125,48 +1127,50 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: i = 0 logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ") - for batch in self.iterator.iter_batches( - # Note, this needs to be one b/c data is already mapped to - # `MultiAgentBatch`es of `minibatch_size`. - batch_size=1, - _finalize_fn=_finalize_fn, - **kwargs, - ): - # Update the iteration counter. - i += 1 - - # Note, `_finalize_fn` must return a dictionary. - batch = batch["batch"] - logger.debug( - f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows." - ) - # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs - # found in this batch. If not, throw an error. - unknown_module_ids = set(batch.policy_batches.keys()) - set( - self.module.keys() - ) - if len(unknown_module_ids) > 0: - raise ValueError( - "Batch contains one or more ModuleIDs that are not in this " - f"Learner! Found IDs: {unknown_module_ids}" + while num_iters is None or i < num_iters: + for batch in self.iterator.iter_batches( + # Note, this needs to be one b/c data is already mapped to + # `MultiAgentBatch`es of `minibatch_size`. + batch_size=1, + _finalize_fn=_finalize_fn, + **kwargs, + ): + # TODO (simon): Add metrics for the `dataset_num_iter`. + # Update the iteration counter. + i += 1 + + # Note, `_finalize_fn` must return a dictionary. + batch = batch["batch"] + logger.debug( + f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows." ) + # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs + # found in this batch. If not, throw an error. + unknown_module_ids = set(batch.policy_batches.keys()) - set( + self.module.keys() + ) + if len(unknown_module_ids) > 0: + raise ValueError( + "Batch contains one or more ModuleIDs that are not in this " + f"Learner! Found IDs: {unknown_module_ids}" + ) - # Log metrics. - self._log_steps_trained_metrics(batch) + # Log metrics. + self._log_steps_trained_metrics(batch) - # Make the actual in-graph/traced `_update` call. This should return - # all tensor values (no numpy). - fwd_out, loss_per_module, tensor_metrics = self._update( - batch.policy_batches - ) - # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) - # to actual (numpy) values. - self.metrics.tensors_to_numpy(tensor_metrics) + # Make the actual in-graph/traced `_update` call. This should return + # all tensor values (no numpy). + fwd_out, loss_per_module, tensor_metrics = self._update( + batch.policy_batches + ) + # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) + # to actual (numpy) values. + self.metrics.tensors_to_numpy(tensor_metrics) - self._set_slicing_by_batch_id(batch, value=False) - # If `num_iters` is reached break and return. - if num_iters and i == num_iters: - break + self._set_slicing_by_batch_id(batch, value=False) + # If `num_iters` is reached break and return. 
+                if num_iters and i == num_iters:
+                    break

        logger.debug(
            f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
@@ -1180,6 +1184,18 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
                value=loss,
                window=1,
            )
+        # Record the number of batches pulled from the dataset in this RLlib iteration.
+        self.metrics.log_value(
+            DATASET_NUM_ITERS_TRAINED,
+            i,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            DATASET_NUM_ITERS_TRAINED_LIFETIME,
+            i,
+            reduce="sum",
+        )
        # Call `after_gradient_based_update` to allow for non-gradient based
        # cleanups-, logging-, and update logic to happen.
        # TODO (simon): Check, if this should stay here, when running multiple
@@ -1385,6 +1401,11 @@ def _update_from_batch_or_episodes(
            batch = MultiAgentBatch(
                {next(iter(self.module.keys())): batch}, env_steps=len(batch)
            )
+        # If we already have a `MultiAgentBatch`, but with `numpy` arrays, convert it to tensors.
+        elif isinstance(batch, MultiAgentBatch) and isinstance(
+            next(iter(batch.policy_batches.values()))["obs"], numpy.ndarray
+        ):
+            batch = self._convert_batch_type(batch)

        # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
        # found in this batch. If not, throw an error.
diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py
index 9ef6abb3748d1..0826a8e3eced1 100644
--- a/rllib/core/learner/learner_group.py
+++ b/rllib/core/learner/learner_group.py
@@ -1,5 +1,5 @@
 import pathlib
-from collections import defaultdict, Counter
+from collections import defaultdict
 import copy
 from functools import partial
 import itertools
@@ -133,19 +133,18 @@ def __init__(
        else:
            backend_config = _get_backend_config(learner_class)

-        # TODO (sven): Can't set both `num_cpus_per_learner`>1 and
-        # `num_gpus_per_learner`>0! Users must set one or the other due
-        # to issues with placement group fragmentation. See
-        # https://github.com/ray-project/ray/issues/35409 for more details.
        num_cpus_per_learner = (
            self.config.num_cpus_per_learner
-            if not self.config.num_gpus_per_learner
+            if self.config.num_cpus_per_learner != "auto"
+            else 1
+            if self.config.num_gpus_per_learner == 0
            else 0
        )
        num_gpus_per_learner = max(
            0,
            self.config.num_gpus_per_learner
-            - (0.01 * self.config.num_aggregator_actors_per_learner),
+            # TODO (sven): Activate this when Ray has figured out GPU pre-loading.
+            # - (0.01 * self.config.num_aggregator_actors_per_learner),
        )
        resources_per_learner = {
            "CPU": num_cpus_per_learner,
@@ -178,19 +177,13 @@ def __init__(
                self.config.max_requests_in_flight_per_learner
            ),
        )
-        # Counters for the tags for asynchronous update requests that are
-        # in-flight. Used for keeping trakc of and grouping together the results of
-        # requests that were sent to the workers at the same time.
-        self._update_request_tags = Counter()
-        self._update_request_tag = 0
-        self._update_request_results = {}

    # TODO (sven): Replace this with call to `self.metrics.peek()`?
    #  Currently LearnerGroup does not have a metrics object.
    def get_stats(self) -> Dict[str, Any]:
        """Returns the current stats for the input queue for this learner group."""
        return {
-            "learner_group_ts_dropped": self._ts_dropped,
+            "learner_group_ts_dropped_lifetime": self._ts_dropped,
            "actor_manager_num_outstanding_async_reqs": (
                0
                if self.is_local
@@ -414,9 +407,15 @@ def _learner_update(
                    " local mode! Try setting `config.num_learners > 0`."
) - if isinstance(batch, list) and isinstance(batch[0], ray.ObjectRef): + if isinstance(batch, list): + # Ensure we are not in a multi-learner setting. assert len(batch) == 1 - batch = ray.get(batch[0]) + # If we have `ObjectRef`s, get the respective objects. + if isinstance(batch[0], ray.ObjectRef): + batch = ray.get(batch[0]) + # If we have a `DataIterator`, get the iterator. + elif isinstance(batch[0], ray.data.DataIterator): + batch = batch[0] results = [ _learner_update( @@ -537,45 +536,25 @@ def _learner_update( if async_update: # Retrieve all ready results (kicked off by prior calls to this method). - tags_to_get = [] - for tag in self._update_request_tags.keys(): - result = self._worker_manager.fetch_ready_async_reqs( - tags=[str(tag)], timeout_seconds=0.0 - ) - if tag not in self._update_request_results: - self._update_request_results[tag] = result - else: - for r in result: - self._update_request_results[tag].add_result( - r.actor_id, r.result_or_error, tag - ) - - # Still not done with this `tag` -> skip out early. - if ( - self._update_request_tags[tag] - > len(self._update_request_results[tag].result_or_errors) - > 0 - ): - break - tags_to_get.append(tag) - + results = self._worker_manager.fetch_ready_async_reqs( + timeout_seconds=0.0 + ) # Send out new request(s), if there is still capacity on the actors # (each actor is allowed only some number of max in-flight requests # at the same time). - update_tag = self._update_request_tag - self._update_request_tag += 1 - num_sent_requests = self._worker_manager.foreach_actor_async( - partials, tag=str(update_tag) - ) - if num_sent_requests: - self._update_request_tags[update_tag] = num_sent_requests + num_sent_requests = self._worker_manager.foreach_actor_async(partials) # Some requests were dropped, record lost ts/data. if num_sent_requests != len(self._workers): factor = 1 - (num_sent_requests / len(self._workers)) # Batch: Measure its length. if episodes is None: - dropped = len(batch) + if isinstance(batch, list) and isinstance(batch[0], ObjectRef): + dropped = ( + len(batch) * self.config.train_batch_size_per_learner + ) + else: + dropped = len(batch) # List of Ray ObjectRefs (each object ref is a list of episodes of # total len=`rollout_fragment_length * num_envs_per_env_runner`) elif isinstance(episodes[0], ObjectRef): @@ -595,7 +574,7 @@ def _learner_update( # a list of lists where each inner list should be the length of the # number of learner workers, if results from an non-blocking update are # ready. - results = self._get_async_results(tags_to_get) + results = self._get_async_results(results) else: results = self._get_results( @@ -615,7 +594,7 @@ def _get_results(self, results): raise result_or_error return processed_results - def _get_async_results(self, tags_to_get): + def _get_async_results(self, results): """Get results from the worker manager and group them by tag. Returns: @@ -623,32 +602,15 @@ def _get_async_results(self, tags_to_get): for same tags. """ - unprocessed_results = defaultdict(list) - for tag in tags_to_get: - results = self._update_request_results[tag] - for result in results: - result_or_error = result.get() - if result.ok: - if result.tag is None: - raise RuntimeError( - "Cannot call `LearnerGroup._get_async_results()` on " - "untagged async requests!" 
- ) - tag = int(result.tag) - unprocessed_results[tag].append(result_or_error) - - if tag in self._update_request_tags: - self._update_request_tags[tag] -= 1 - if self._update_request_tags[tag] == 0: - del self._update_request_tags[tag] - del self._update_request_results[tag] - else: - assert False - - else: - raise result_or_error + ret = [] + for result in results: + result_or_error = result.get() + if result.ok: + ret.append(result_or_error) + else: + raise result_or_error - return list(unprocessed_results.values()) + return ret def add_module( self, diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index 04a2566f6839d..1ad42f875996b 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -501,9 +501,7 @@ def tearDown(cls) -> None: def test_async_update(self): """Test that async style updates converge to the same result as sync.""" - # async_update only needs to be tested for the most complex case. - # so we'll only test it for multi-gpu-ddp. - scaling_modes = ["multi-gpu-ddp", "remote-gpu"] + scaling_modes = ["multi-gpu-ddp", "multi-cpu-ddp", "remote-gpu"] for scaling_mode in scaling_modes: print(f"Testing scaling mode: {scaling_mode}.") @@ -534,31 +532,25 @@ def test_async_update(self): ) if not result_async: continue - self.assertIsInstance(result_async[0], list) - self.assertIsInstance(result_async[0][0], dict) - # Check the latest async result AND those sub-results of the first - # Learner in the group. - loss = result_async[-1][0][DEFAULT_MODULE_ID][Learner.TOTAL_LOSS_KEY] + self.assertIsInstance(result_async, list) + self.assertIsInstance(result_async[0], dict) + # Check one async Learner result. + loss = result_async[0][DEFAULT_MODULE_ID][Learner.TOTAL_LOSS_KEY] # The loss is initially around 0.69 (ln2). When it gets to around # 0.57 the return of the policy gets to around 100. if loss < 0.57: break # Compare reported "mean_weight" with actual ones. - _check_multi_worker_weights( - learner_group, result_async, result_async=True - ) + _check_multi_worker_weights(learner_group, result_async) iter_i += 1 learner_group.shutdown() self.assertLess(loss, 0.57) -def _check_multi_worker_weights(learner_group, results, result_async=False): +def _check_multi_worker_weights(learner_group, results): # Check that module weights are updated across workers and synchronized. # for i in range(1, len(results)): - if result_async: - results = results[-1] - learner_1_results = results[0] for module_id, mod_result in learner_1_results.items(): if module_id == ALL_MODULES: diff --git a/rllib/core/models/README.rst b/rllib/core/models/README.rst index 2ef2007403e25..5aadf2abc7620 100644 --- a/rllib/core/models/README.rst +++ b/rllib/core/models/README.rst @@ -1,2 +1,2 @@ This folder holds models that are under development and to be used with RLModules in upcoming versions of RLlib. -They are not yet ready for use in the current version of RLlib. \ No newline at end of file +They are not yet ready for use in the current version of RLlib. 
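The async-update simplification above boils down to a two-step loop: first drain whatever results the worker manager already has ready, then immediately queue the next round of update requests, with no per-request tag bookkeeping. The following is a minimal, self-contained sketch of that control flow; `TinyAsyncManager` and `_fake_learner_update` are hypothetical stand-ins built on `concurrent.futures` and only mirror the pattern, not RLlib's actual actor-manager API.

import concurrent.futures as cf
import random
import time


def _fake_learner_update(learner_id: int) -> dict:
    # Stand-in for a remote Learner update call of varying duration.
    time.sleep(random.uniform(0.0, 0.02))
    return {"learner_id": learner_id, "loss": random.random()}


class TinyAsyncManager:
    """Hypothetical stand-in for the worker manager used by LearnerGroup."""

    def __init__(self, num_workers: int):
        self.num_workers = num_workers
        self.pool = cf.ThreadPoolExecutor(max_workers=num_workers)
        self.in_flight = []

    def fetch_ready_async_reqs(self) -> list:
        # Collect everything that has completed; leave the rest in flight.
        ready = [f for f in self.in_flight if f.done()]
        self.in_flight = [f for f in self.in_flight if not f.done()]
        return [f.result() for f in ready]

    def foreach_actor_async(self) -> int:
        # Kick off one new request per worker and report how many were sent.
        for i in range(self.num_workers):
            self.in_flight.append(self.pool.submit(_fake_learner_update, i))
        return self.num_workers


mgr = TinyAsyncManager(num_workers=2)
for _ in range(5):
    # 1) Retrieve all ready results (kicked off by prior iterations).
    results = mgr.fetch_ready_async_reqs()
    # 2) Send out the next round of requests right away.
    num_sent = mgr.foreach_actor_async()
    print(f"got {len(results)} results, sent {num_sent} new requests")
    time.sleep(0.03)

Because completed results arrive as one flat list, `_get_async_results` no longer needs to group partial results by request tag before returning them.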
diff --git a/rllib/core/rl_module/apis/inference_only_api.py b/rllib/core/rl_module/apis/inference_only_api.py index 34bed76781d25..723c1e3480c23 100644 --- a/rllib/core/rl_module/apis/inference_only_api.py +++ b/rllib/core/rl_module/apis/inference_only_api.py @@ -10,11 +10,11 @@ class InferenceOnlyAPI(abc.ABC): Only the `get_non_inference_attributes` method needs to get implemented for an RLModule to have the following functionality: - - On EnvRunners (or when self.inference_only=True), RLlib will remove - those parts of the model not required for action computation. - - An RLModule on a Learner (where `self.inference_only=False`) will - return only those weights from `get_state()` that are part of its inference-only - version, thus possibly saving network traffic/time. + - On EnvRunners (or when self.inference_only=True), RLlib will remove + those parts of the model not required for action computation. + - An RLModule on a Learner (where `self.inference_only=False`) will + return only those weights from `get_state()` that are part of its inference-only + version, thus possibly saving network traffic/time. """ @abc.abstractmethod @@ -30,14 +30,13 @@ def get_non_inference_attributes(self) -> List[str]: For example: - .. testcode:: - :skipif: True + .. code-block:: python from ray.rllib.core.rl_module.rl_module import RLModuleSpec spec = RLModuleSpec(module_class=..., inference_only=True) - If an RLModule has the following `setup()` implementation: + If an RLModule has the following setup() implementation: .. testcode:: :skipif: True @@ -48,11 +47,11 @@ def setup(self): self._policy_head = [some NN component] self._value_function_head = [some NN component] - self._encoder = [some NN component with attributes: `pol` and `vf` + self._encoder = [some NN component with attributes: pol and vf (policy- and value func. encoder)] Then its `get_non_inference_attributes()` should return: - `["_value_function_head", "_encoder.vf"]` + ["_value_function_head", "_encoder.vf"]. Note the "." notation to separate attributes and their sub-attributes in case you need more fine-grained control over which exact sub-attributes to exclude in diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 49b8097675f54..c1fff46f26a18 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -527,47 +527,34 @@ class MultiRLModuleSpec: share neural networks across the modules, the build method can be overridden to create the shared module first and then pass it to custom module classes that would then use it as a shared module. - - Args: - multi_rl_module_class: The class of the MultiRLModule to construct. By - default, this is the base `MultiRLModule` class. - observation_space: Optional global observation space for the MultiRLModule. - Useful for shared network components that live only inside the MultiRLModule - and don't have their own ModuleID and own RLModule within - `self._rl_modules`. - action_space: Optional global action space for the MultiRLModule. - Useful for shared network components that live only inside the MultiRLModule - and don't have their own ModuleID and own RLModule within - `self._rl_modules`. - inference_only: An optional global inference_only flag. If not set (None by - default), considers the MultiRLModule to be inference_only=True, only - if all submodules also have their own inference_only flags set to True. - model_config: An optional global model_config dict. 
Useful to configure shared
-            network components that only live inside the MultiRLModule and don't have
-            their own ModuleID and own RLModule within `self._rl_modules`.
-        rl_module_specs: The module specs for each individual module. It can be either a
-            RLModuleSpec used for all module_ids or a dictionary mapping
-            from module IDs to RLModuleSpecs for each individual module.
-        load_state_path: The path to the module state to load from. NOTE: This must be
-            an absolute path. NOTE: If the load_state_path of this spec is set, and
-            the load_state_path of one of the RLModuleSpecs' is also set,
-            the weights of that RL Module will be loaded from the path specified in
-            the RLModuleSpec. This is useful if you want to load the weights
-            of a MultiRLModule and also manually load the weights of some of the RL
-            modules within that MultiRLModule from other checkpoints.
-        modules_to_load: A set of module ids to load from the checkpoint. This is
-            only used if load_state_path is set. If this is None, all modules are
-            loaded.
    """

+    #: The class of the MultiRLModule to construct. By default,
+    #: this is the base `MultiRLModule` class.
    multi_rl_module_class: Type[MultiRLModule] = MultiRLModule
+    #: Optional global observation space for the MultiRLModule.
+    #: Useful for shared network components that live only inside the MultiRLModule
+    #: and don't have their own ModuleID and own RLModule within
+    #: `self._rl_modules`.
    observation_space: Optional[gym.Space] = None
+    #: Optional global action space for the MultiRLModule. Useful for
+    #: shared network components that live only inside the MultiRLModule and don't
+    #: have their own ModuleID and own RLModule within `self._rl_modules`.
    action_space: Optional[gym.Space] = None
+    #: An optional global inference_only flag. If not set (None by
+    #: default), the MultiRLModule is considered inference_only=True only if all
+    #: its submodules have their own inference_only flags set to True.
    inference_only: Optional[bool] = None
    # TODO (sven): Once we support MultiRLModules inside other MultiRLModules, we would
    # need this flag in here as well, but for now, we'll leave it out for simplicity.
    # learner_only: bool = False
+    #: An optional global model_config dict. Useful to configure shared
+    #: network components that only live inside the MultiRLModule and don't have
+    #: their own ModuleID and own RLModule within `self._rl_modules`.
    model_config: Optional[dict] = None
+    #: The module specs for each individual module. It can be either
+    #: an RLModuleSpec used for all module_ids or a dictionary mapping from module
+    #: IDs to RLModuleSpecs for each individual module.
    rl_module_specs: Union[RLModuleSpec, Dict[ModuleID, RLModuleSpec]] = None

    # TODO (sven): Deprecate these in favor of using the pure Checkpointable APIs for
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index 27f0861f365ba..ae43c2752cccb 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -361,7 +361,18 @@ class RLModule(Checkpointable, abc.ABC):

    Args:
-        config: The config for the RLModule.
+        observation_space: The observation space of the model. Note that in multi-agent
+            setups, this is typically the observation space of an agent that maps to
+            this RLModule.
+        action_space: The action space of the model. Note that in multi-agent
+            setups, this is typically the action space of an agent that maps to
+            this RLModule.
+        inference_only: If True, this RLModule should construct itself in an inference-
+            only fashion. This is done automatically if the user implements the
+            `InferenceOnlyAPI` with their custom RLModule subclass. False by default.
+        learner_only: If True, RLlib won't build this RLModule on EnvRunner actors.
+            False by default.
+        model_config: A config dict to specify features of this RLModule.

    Abstract Methods:
        ``~_forward_train``: Forward pass during training.
@@ -390,7 +401,6 @@ def __init__(
        # TODO (sven): Deprecate Catalog and replace with utility functions to create
        # primitive components based on obs- and action spaces.
        self.catalog = None
-        self._catalog_ctor_error = None

        # Deprecated
        self.config = config
@@ -409,22 +419,23 @@ def __init__(
        self.inference_only = inference_only
        self.learner_only = learner_only
        self.model_config = model_config
-        try:
-            self.catalog = catalog_class(
-                observation_space=self.observation_space,
-                action_space=self.action_space,
-                model_config_dict=self.model_config,
-            )
-        except Exception as e:
-            logger.warning(
-                "Could not create a Catalog object for your RLModule! If you are "
-                "not using the new API stack yet, make sure to switch it off in "
-                "your config: `config.api_stack(enable_rl_module_and_learner=False"
-                ", enable_env_runner_and_connector_v2=False)`. All algos "
-                "use the new stack by default. Ignore this message, if your "
-                "RLModule does not use a Catalog to build its sub-components."
-            )
-            self._catalog_ctor_error = e
+        if catalog_class is not None:
+            try:
+                self.catalog = catalog_class(
+                    observation_space=self.observation_space,
+                    action_space=self.action_space,
+                    model_config_dict=self.model_config,
+                )
+            except Exception as e:
+                logger.warning(
+                    "Didn't create a Catalog object for your RLModule! If you are "
+                    "not using the new API stack yet, make sure to switch it off in"
+                    " your config: `config.api_stack(enable_rl_module_and_learner="
+                    "False, enable_env_runner_and_connector_v2=False)`. All algos "
+                    "use the new stack by default. Ignore this message, if your "
+                    "RLModule does not use a Catalog to build its sub-components."
+                )
+                self._catalog_ctor_error = e

        # TODO (sven): Deprecate this. We keep it here for now in case users
        # still have custom models (or subclasses of RLlib default models)
diff --git a/rllib/env/remote_base_env.py b/rllib/env/remote_base_env.py
index b9e388d50bcfb..9ff6537a9d329 100644
--- a/rllib/env/remote_base_env.py
+++ b/rllib/env/remote_base_env.py
@@ -328,7 +328,7 @@ def stop(self) -> None:
    @override(BaseEnv)
    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
        if as_dict:
-            return {env_id: actor for env_id, actor in enumerate(self.actors)}
+            return dict(enumerate(self.actors))
        return self.actors

    @property
diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py
index cd16a05971151..13ab9cdf42d38 100644
--- a/rllib/env/tests/test_multi_agent_episode.py
+++ b/rllib/env/tests/test_multi_agent_episode.py
@@ -3498,10 +3498,7 @@ def _mock_multi_agent_records_from_env(
        # In the other case we need at least the last observations for the next
        # actions.
        else:
-            obs = {
-                agent_id: agent_obs
-                for agent_id, agent_obs in episode.get_observations(-1).items()
-            }
+            obs = dict(episode.get_observations(-1))

        # Sample `size` many records.
done_agents = {aid for aid, t in episode.get_terminateds().items() if t} diff --git a/rllib/env/tests/test_policy_client_server_setup.sh b/rllib/env/tests/test_policy_client_server_setup.sh deleted file mode 100755 index 436f3963b4c7c..0000000000000 --- a/rllib/env/tests/test_policy_client_server_setup.sh +++ /dev/null @@ -1,121 +0,0 @@ -#!/bin/bash - -# Driver script for testing RLlib's client/server setup. -# Run as follows: -# $ test_policy_client_server_setup.sh [inference-mode: local|remote] [env: cartpole|cartpole-dummy-2-episodes|unity3d] - -rm -f last_checkpoint.out - -if [ "$1" == "local" ]; then - inference_mode=local -else - inference_mode=remote -fi - -# CartPole client/server setup. -if [ "$2" == "cartpole" ]; then - server_script=cartpole_server.py - client_script=cartpole_client.py - stop_criterion="--stop-reward=150.0" - algo_cls="PPO" - use_lstm="" -elif [ "$2" == "cartpole_lstm" ]; then - server_script=cartpole_server.py - client_script=cartpole_client.py - stop_criterion="--stop-reward=150.0" - algo_cls="IMPALA" - use_lstm="--use-lstm" -# Unity3D dummy setup. -elif [ "$2" == "unity3d" ]; then - server_script=unity3d_server.py - client_script=unity3d_dummy_client.py - stop_criterion="--num-episodes=10" - algo_cls="PPO" - use_lstm="" -# CartPole dummy test using 2 simultaneous episodes on the client. -# One episode has training_enabled=False (its data should NOT arrive at server). -else - server_script=cartpole_server.py - client_script=dummy_client_with_two_episodes.py - stop_criterion="--dummy-arg=dummy" # no stop criterion: client script terminates either way - algo_cls="PPO" - use_lstm="" -fi - -port=$3 -worker_1_port=$((port)) -# This is hardcoded in the server/client scripts, that per-worker -# port is base_port + worker_idx -worker_2_port=$((port + 1)) - -pkill -f $server_script -sleep 1 - -if [ -f test_policy_client_server_setup.sh ]; then - basedir="../../examples/envs/external_envs" -else - basedir="rllib/examples/envs/external_envs" # In bazel. -fi - -# Start server with 2 workers (will listen on ports worker_1_port and worker_2_port for client -# connections). -# Do not attempt to restore from checkpoint; leads to errors on travis. -# shellcheck disable=SC2086 -(python $basedir/$server_script --run="$algo_cls" --num-workers=2 $use_lstm --no-restore --port=$worker_1_port 2>&1 | grep -v 200) & -server_pid=$! - -echo "Waiting for server to start ..." -while ! curl localhost:$worker_1_port; do - sleep 1 -done -echo "Remote worker #1 on port $worker_1_port is up!" -while ! curl localhost:$worker_2_port; do - sleep 1 -done -echo "Remote worker #2 on port $worker_2_port is up!" - -# Start client 1 (connect to port $worker_1_port). -sleep 2 -(python $basedir/$client_script --inference-mode=$inference_mode --port=$worker_1_port) & -client1_pid=$! - -# Start client 2 (connect to port $worker_2_port). -sleep 2 -(python $basedir/$client_script --inference-mode=$inference_mode --port=$worker_2_port) & -client2_pid=$! - -# Start client 3 (also connecting to port $worker_2_port) and run it until it reaches -# x reward (CartPole) or n episodes (dummy Unity3D). -# Then stop everything. -sleep 2 -python $basedir/$client_script --inference-mode=$inference_mode --port=$worker_2_port "$stop_criterion" - -exit_if_not_running() -{ - local pid=$1 - - if ps -p "$pid"> /dev/null - then - return - fi - - echo "$2 is not running" - - wait "$pid" - exit_code=$? 
- echo "$2 exited with code $exit_code" - return $exit_code -} - -exit_if_not_running $client1_pid "client 1" -client_1_exit_code=$? -exit_if_not_running $client2_pid "client 2" -client_2_exit_code=$? -exit_if_not_running $server_pid "server" -server_exit_code=$? - -if [ "$client_1_exit_code" != 0 ] || [ "$client_2_exit_code" != 0 ] || [ "$server_exit_code" != 0 ]; then - echo "Test failed!" - exit 1 -fi -kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index c3e0896ba05e7..b1da92dd0cad3 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -498,10 +498,7 @@ def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], di if not as_dict: return self.vector_env.get_sub_environments() else: - return { - _id: env - for _id, env in enumerate(self.vector_env.get_sub_environments()) - } + return dict(enumerate(self.vector_env.get_sub_environments())) @override(BaseEnv) def try_render(self, env_id: Optional[EnvID] = None) -> None: diff --git a/rllib/env/wrappers/exception_wrapper.py b/rllib/env/wrappers/exception_wrapper.py deleted file mode 100644 index 50f05fd344493..0000000000000 --- a/rllib/env/wrappers/exception_wrapper.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging -import traceback - -import gymnasium as gym - -logger = logging.getLogger(__name__) - - -class TooManyResetAttemptsException(Exception): - def __init__(self, max_attempts: int): - super().__init__( - f"Reached the maximum number of attempts ({max_attempts}) " - f"to reset an environment." - ) - - -class ResetOnExceptionWrapper(gym.Wrapper): - def __init__(self, env: gym.Env, max_reset_attempts: int = 5): - super().__init__(env) - self.max_reset_attempts = max_reset_attempts - - def reset(self, **kwargs): - attempt = 0 - while attempt < self.max_reset_attempts: - try: - return self.env.reset(**kwargs) - except Exception: - logger.error(traceback.format_exc()) - attempt += 1 - else: - raise TooManyResetAttemptsException(self.max_reset_attempts) - - def step(self, action): - try: - return self.env.step(action) - except Exception: - logger.error(traceback.format_exc()) - return self.reset(), 0.0, False, {"__terminated__": True} diff --git a/rllib/env/wrappers/open_spiel.py b/rllib/env/wrappers/open_spiel.py index c46c753009880..abc051c657700 100644 --- a/rllib/env/wrappers/open_spiel.py +++ b/rllib/env/wrappers/open_spiel.py @@ -61,7 +61,7 @@ def step(self, action): penalties[curr_player] = -0.1 # Compile rewards dict. - rewards = {ag: r for ag, r in enumerate(self.state.returns())} + rewards = dict(enumerate(self.state.returns())) # Simultaneous game. else: assert self.state.current_player() == -2 @@ -73,7 +73,7 @@ def step(self, action): # Compile rewards dict and add the accumulated penalties # (for taking invalid actions). 
- rewards = {ag: r for ag, r in enumerate(self.state.returns())} + rewards = dict(enumerate(self.state.returns())) for ag, penalty in penalties.items(): rewards[ag] += penalty diff --git a/rllib/env/wrappers/tests/test_exception_wrapper.py b/rllib/env/wrappers/tests/test_exception_wrapper.py deleted file mode 100644 index 8e4818f0502da..0000000000000 --- a/rllib/env/wrappers/tests/test_exception_wrapper.py +++ /dev/null @@ -1,61 +0,0 @@ -import random -import unittest - -import gymnasium as gym -from ray.rllib.env.wrappers.exception_wrapper import ( - ResetOnExceptionWrapper, - TooManyResetAttemptsException, -) - - -class TestResetOnExceptionWrapper(unittest.TestCase): - def test_unstable_env(self): - class UnstableEnv(gym.Env): - observation_space = gym.spaces.Discrete(2) - action_space = gym.spaces.Discrete(2) - - def step(self, action): - if random.choice([True, False]): - raise ValueError("An error from a unstable environment.") - return self.observation_space.sample(), 0.0, False, False, {} - - def reset(self, *, seed=None, options=None): - return self.observation_space.sample(), {} - - env = UnstableEnv() - env = ResetOnExceptionWrapper(env) - - try: - self._run_for_100_steps(env) - except Exception: - self.fail() - - def test_very_unstable_env(self): - class VeryUnstableEnv(gym.Env): - observation_space = gym.spaces.Discrete(2) - action_space = gym.spaces.Discrete(2) - - def step(self, action): - return self.observation_space.sample(), 0.0, False, False, {} - - def reset(self, *, seed=None, options=None): - raise ValueError("An error from a very unstable environment.") - - env = VeryUnstableEnv() - env = ResetOnExceptionWrapper(env) - self.assertRaises( - TooManyResetAttemptsException, lambda: self._run_for_100_steps(env) - ) - - @staticmethod - def _run_for_100_steps(env): - env.reset() - for _ in range(100): - env.step(env.action_space.sample()) - - -if __name__ == "__main__": - import sys - import pytest - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/examples/_docs/rllib_on_rllib_readme.py b/rllib/examples/_docs/rllib_on_rllib_readme.py index 4463eba4ce85c..be63d2da2c78b 100644 --- a/rllib/examples/_docs/rllib_on_rllib_readme.py +++ b/rllib/examples/_docs/rllib_on_rllib_readme.py @@ -46,7 +46,7 @@ def step(self, action): Returns: New observation, reward, done-flag, info-dict (empty). """ - # Set `done` and `truncated` flags after 10 steps. + # Set `terminated` and `truncated` flags to True after 10 steps. self.episode_len += 1 terminated = truncated = self.episode_len >= 10 # r = -abs(obs - action) @@ -60,9 +60,9 @@ def step(self, action): # act in the above environment. config = ( PPOConfig().environment( - # Env class to use (here: our gym.Env sub-class from above). + # Env class to use (your gym.Env subclass from above). env=ParrotEnv, - # Config dict to be passed to our custom env's constructor. + # Config dict to be passed to your custom env's constructor. env_config={"parrot_shriek_range": gym.spaces.Box(-5.0, 5.0, (1,))}, ) # Parallelize environment rollouts. diff --git a/rllib/examples/envs/classes/d4rl_env.py b/rllib/examples/envs/classes/d4rl_env.py index f77434589b92b..768c66db48bae 100644 --- a/rllib/examples/envs/classes/d4rl_env.py +++ b/rllib/examples/envs/classes/d4rl_env.py @@ -9,7 +9,7 @@ try: import d4rl - d4rl.__name__ # Fool LINTer. + _ = d4rl.__name__ # Fool LINTer. 
except ImportError: d4rl = None diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index e65783ae4a862..95293194a41f1 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -90,18 +90,16 @@ def compute_values(self, batch: Dict[str, TensorType], embeddings=None): # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) + # __sphinx_begin__ def _pi(self, obs, inference: bool): - # Prior forward pass. + # Prior forward pass and sample a1. prior_out = self._prior_net(obs) dist_a1 = TorchCategorical.from_logits(prior_out) - - # If in inference mode, we need to set the distribution to be deterministic. if inference: dist_a1 = dist_a1.to_deterministic() - # Sample a1. a1 = dist_a1.sample() - # Posterior forward pass. + # Posterior forward pass and sample a2. posterior_batch = torch.cat( [obs, one_hot(a1, self.action_space[0])], dim=-1, @@ -110,22 +108,18 @@ def _pi(self, obs, inference: bool): dist_a2 = TorchDiagGaussian.from_logits(posterior_out) if inference: dist_a2 = dist_a2.to_deterministic() - a2 = dist_a2.sample() - actions = (a1, a2) - # We need the log-probabilities for the loss. - outputs = { + # We need logp and distribution parameters for the loss. + return { Columns.ACTION_LOGP: ( TorchMultiDistribution((dist_a1, dist_a2)).logp(actions) ), Columns.ACTION_DIST_INPUTS: torch.cat([prior_out, posterior_out], dim=-1), - # Concatenate the prior and posterior actions and log probabilities. Columns.ACTIONS: actions, } - - return outputs + # __sphinx_end__ @override(TorchRLModule) def get_inference_action_dist_cls(self): diff --git a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py index 5efbead7e66f0..4cfa6d34d67d3 100644 --- a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py +++ b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py @@ -187,7 +187,7 @@ def compute_values(self, batch: Dict[str, Any], embeddings: Optional[Any] = None def get_initial_state(self): """Converts the initial state list of ModelV2 into a dict (new API stack).""" init_state_list = self._model_v2.get_initial_state() - return {i: s for i, s in enumerate(init_state_list)} + return dict(enumerate(init_state_list)) def _translate_dist_class(self, old_dist_class): map_ = { diff --git a/rllib/examples/rl_modules/classes/vpg_torch_rlm.py b/rllib/examples/rl_modules/classes/vpg_torch_rlm.py index 676598d090dc0..77ee0e70dc7a4 100644 --- a/rllib/examples/rl_modules/classes/vpg_torch_rlm.py +++ b/rllib/examples/rl_modules/classes/vpg_torch_rlm.py @@ -27,7 +27,7 @@ def setup(self): ) def _forward(self, batch, **kwargs): - # Push the observations from the batch through our pi-head. + # Push the observations from the batch through our `self._policy_net`. action_logits = self._policy_net(batch[Columns.OBS]) # Return parameters for the (default) action distribution, which is # `TorchCategorical` (due to our action space being `gym.spaces.Discrete`). 
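The refactored `_pi` above implements a two-step autoregressive policy: a prior head samples the discrete component a1, then a posterior head, conditioned on the observation concatenated with one-hot(a1), samples the continuous component a2, so the joint log-prob factorizes as logp(a1) + logp(a2 | a1). Below is a minimal sketch of the same pattern using plain `torch.distributions` instead of RLlib's distribution wrappers; the layer names and the sizes `OBS_DIM`, `NUM_A1`, and `A2_DIM` are made up for illustration (a real module derives them from its spaces).

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Normal

# Hypothetical sizes for illustration only.
OBS_DIM, NUM_A1, A2_DIM = 4, 3, 2

prior_net = torch.nn.Linear(OBS_DIM, NUM_A1)
posterior_net = torch.nn.Linear(OBS_DIM + NUM_A1, 2 * A2_DIM)


def pi(obs: torch.Tensor):
    # Prior forward pass and sample of the discrete action component a1.
    dist_a1 = Categorical(logits=prior_net(obs))
    a1 = dist_a1.sample()

    # Posterior forward pass, conditioned on obs and one-hot(a1); sample a2.
    posterior_in = torch.cat([obs, F.one_hot(a1, NUM_A1).float()], dim=-1)
    mean, log_std = posterior_net(posterior_in).chunk(2, dim=-1)
    dist_a2 = Normal(mean, log_std.exp())
    a2 = dist_a2.sample()

    # Joint log-prob factorizes: logp(a1, a2 | obs) = logp(a1) + logp(a2 | a1).
    logp = dist_a1.log_prob(a1) + dist_a2.log_prob(a2).sum(-1)
    return (a1, a2), logp


actions, logp = pi(torch.randn(5, OBS_DIM))  # batch of 5 observations
print(actions[0].shape, actions[1].shape, logp.shape)  # (5,), (5, 2), (5,)

The same conditioning chain extends to more action components: each additional head takes the observation plus all previously sampled components as input, and its log-prob term is added to the running sum.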
diff --git a/rllib/models/tests/test_action_distributions.py b/rllib/models/tests/test_action_distributions.py index 6de0c1aa62a03..254b8ba315cdc 100644 --- a/rllib/models/tests/test_action_distributions.py +++ b/rllib/models/tests/test_action_distributions.py @@ -1,22 +1,14 @@ -from functools import partial -from gymnasium.spaces import Box, Dict, Tuple +from gymnasium.spaces import Box import numpy as np -from scipy.stats import beta, norm -import tree # pip install dm_tree +from scipy.stats import norm import unittest from ray.rllib.models.torch.torch_action_dist import ( - TorchBeta, TorchCategorical, TorchDiagGaussian, - TorchMultiActionDistribution, - TorchMultiCategorical, - TorchSquashedGaussian, ) from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.numpy import ( - MIN_LOG_NN_OUTPUT, - MAX_LOG_NN_OUTPUT, softmax, SMALL_NUMBER, LARGE_INTEGER, @@ -142,166 +134,6 @@ def test_categorical(self): expected_entropy = -np.sum(probs * np.log(probs), -1) check(out, expected_entropy) - def test_multi_categorical(self): - batch_size = 100 - num_categories = 3 - num_sub_distributions = 5 - # Create 5 categorical distributions of 3 categories each. - inputs_space = Box( - -1.0, 2.0, shape=(batch_size, num_sub_distributions * num_categories) - ) - inputs_space.seed(42) - values_space = Box( - 0, - num_categories - 1, - shape=(num_sub_distributions, batch_size), - dtype=np.int32, - ) - values_space.seed(42) - - inputs = inputs_space.sample() - input_lengths = [num_categories] * num_sub_distributions - inputs_split = np.split(inputs, num_sub_distributions, axis=1) - - # Create the correct distribution object. - cls = TorchMultiCategorical - multi_categorical = cls(inputs, None, input_lengths) - - # Do a stability test using extreme NN outputs to see whether - # sampling and logp'ing result in NaN or +/-inf values. - self._stability_test( - cls, - inputs_space.shape, - fw="torch", - sess=None, - bounds=(0, num_categories - 1), - extra_kwargs={"input_lens": input_lengths}, - ) - - # Batch of size=3 and deterministic (True). - expected = np.transpose(np.argmax(inputs_split, axis=-1)) - # Sample, expect always max value - # (max likelihood for deterministic draw). - out = multi_categorical.deterministic_sample() - check(out, expected) - - # Batch of size=3 and non-deterministic -> expect roughly the mean. - out = multi_categorical.sample() - check(torch.mean(out.float()), 1.0, decimals=0) - - # Test log-likelihood outputs. - probs = softmax(inputs_split) - values = values_space.sample() - - out = multi_categorical.logp( - [torch.Tensor(values[i]) for i in range(num_sub_distributions)] - ) - expected = [] - for i in range(batch_size): - expected.append( - np.sum( - np.log( - np.array( - [ - probs[j][i][values[j][i]] - for j in range(num_sub_distributions) - ] - ) - ) - ) - ) - check(out, expected, decimals=4) - - # Test entropy outputs. - out = multi_categorical.entropy() - expected_entropy = -np.sum(np.sum(probs * np.log(probs), 0), -1) - check(out, expected_entropy) - - def test_squashed_gaussian(self): - """Tests the SquashedGaussian ActionDistribution for all frameworks.""" - input_space = Box(-2.0, 2.0, shape=(2000, 10)) - input_space.seed(42) - - low, high = -2.0, 1.0 - - cls = TorchSquashedGaussian - - # Do a stability test using extreme NN outputs to see whether - # sampling and logp'ing result in NaN or +/-inf values. - self._stability_test( - cls, input_space.shape, fw="torch", sess=None, bounds=(low, high) - ) - - # Batch of size=n and deterministic. 
- inputs = input_space.sample() - means, _ = np.split(inputs, 2, axis=-1) - squashed_distribution = cls(inputs, {}, low=low, high=high) - expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low - # Sample n times, expect always mean value (deterministic draw). - out = squashed_distribution.deterministic_sample() - check(out, expected) - - # Batch of size=n and non-deterministic -> expect roughly the mean. - inputs = input_space.sample() - means, log_stds = np.split(inputs, 2, axis=-1) - squashed_distribution = cls(inputs, {}, low=low, high=high) - expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low - values = squashed_distribution.sample() - values = values.numpy() - self.assertTrue(np.max(values) <= high) - self.assertTrue(np.min(values) >= low) - - check(np.mean(values), expected.mean(), decimals=1) - - # Test log-likelihood outputs. - sampled_action_logp = squashed_distribution.logp(torch.Tensor(values)) - sampled_action_logp = sampled_action_logp.numpy() - # Convert to parameters for distr. - stds = np.exp(np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)) - # Unsquash values, then get log-llh from regular gaussian. - # atanh_in = np.clip((values - low) / (high - low) * 2.0 - 1.0, - # -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER) - normed_values = (values - low) / (high - low) * 2.0 - 1.0 - save_normed_values = np.clip( - normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER - ) - unsquashed_values = np.arctanh(save_normed_values) - log_prob_unsquashed = np.sum( - np.log(norm.pdf(unsquashed_values, means, stds)), -1 - ) - log_prob = log_prob_unsquashed - np.sum( - np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1 - ) - check(np.sum(sampled_action_logp), np.sum(log_prob), rtol=0.05) - - # NN output. - means = np.array([[0.1, 0.2, 0.3, 0.4, 50.0], [-0.1, -0.2, -0.3, -0.4, -1.0]]) - log_stds = np.array([[0.8, -0.2, 0.3, -1.0, 2.0], [0.7, -0.3, 0.4, -0.9, 2.0]]) - squashed_distribution = cls( - inputs=np.concatenate([means, log_stds], axis=-1), - model={}, - low=low, - high=high, - ) - # Convert to parameters for distr. - stds = np.exp(log_stds) - # Values to get log-likelihoods for. - values = np.array( - [[0.9, 0.2, 0.4, -0.1, -1.05], [-0.9, -0.2, 0.4, -0.1, -1.05]] - ) - - # Unsquash values, then get log-llh from regular gaussian. 
- unsquashed_values = np.arctanh((values - low) / (high - low) * 2.0 - 1.0) - log_prob_unsquashed = np.sum( - np.log(norm.pdf(unsquashed_values, means, stds)), -1 - ) - log_prob = log_prob_unsquashed - np.sum( - np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1 - ) - - outs = squashed_distribution.logp(torch.Tensor(values)) - check(outs, log_prob, decimals=4) - def test_diag_gaussian(self): """Tests the DiagGaussian ActionDistribution for all frameworks.""" input_space = Box(-2.0, 1.0, shape=(2000, 10)) @@ -357,212 +189,6 @@ def test_diag_gaussian(self): outs = diag_distribution.logp(torch.Tensor(values)) check(outs, log_prob, decimals=4) - def test_beta(self): - input_space = Box(-2.0, 1.0, shape=(2000, 10)) - input_space.seed(42) - low, high = -1.0, 2.0 - plain_beta_value_space = Box(0.0, 1.0, shape=(2000, 5)) - plain_beta_value_space.seed(42) - - cls = TorchBeta - inputs = input_space.sample() - beta_distribution = cls(inputs, {}, low=low, high=high) - - inputs = beta_distribution.inputs - inputs = inputs.numpy() - alpha, beta_ = np.split(inputs, 2, axis=-1) - - # Mean for a Beta distribution: 1 / [1 + (beta/alpha)] - expected = (1.0 / (1.0 + beta_ / alpha)) * (high - low) + low - # Sample n times, expect always mean value (deterministic draw). - out = beta_distribution.deterministic_sample() - check(out, expected, rtol=0.01) - - # Batch of size=n and non-deterministic -> expect roughly the mean. - values = beta_distribution.sample() - values = values.numpy() - self.assertTrue(np.max(values) <= high) - self.assertTrue(np.min(values) >= low) - - check(np.mean(values), expected.mean(), decimals=1) - - # Test log-likelihood outputs (against scipy). - inputs = input_space.sample() - beta_distribution = cls(inputs, {}, low=low, high=high) - inputs = beta_distribution.inputs - inputs = inputs.numpy() - alpha, beta_ = np.split(inputs, 2, axis=-1) - - values = plain_beta_value_space.sample() - values_scaled = values * (high - low) + low - values_scaled = torch.Tensor(values_scaled) - print(values_scaled) - out = beta_distribution.logp(values_scaled) - check(out, np.sum(np.log(beta.pdf(values, alpha, beta_)), -1), rtol=0.01) - - def test_multi_action_distribution(self): - """Tests the MultiActionDistribution (across all frameworks).""" - batch_size = 1000 - input_space = Tuple( - [ - Box(-10.0, 10.0, shape=(batch_size, 4)), - Box( - -2.0, - 2.0, - shape=( - batch_size, - 6, - ), - ), - Dict({"a": Box(-1.0, 1.0, shape=(batch_size, 4))}), - ] - ) - input_space.seed(42) - std_space = Box( - -0.05, - 0.05, - shape=( - batch_size, - 3, - ), - ) - std_space.seed(42) - - low, high = -1.0, 1.0 - value_space = Tuple( - [ - Box(0, 3, shape=(batch_size,), dtype=np.int32), - Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32), - Dict({"a": Box(0.0, 1.0, shape=(batch_size, 2), dtype=np.float32)}), - ] - ) - value_space.seed(42) - - cls = TorchMultiActionDistribution - child_distr_cls = [ - TorchCategorical, - TorchDiagGaussian, - partial(TorchBeta, low=low, high=high), - ] - - inputs = list(input_space.sample()) - distr = cls( - np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1), - model={}, - action_space=value_space, - child_distributions=child_distr_cls, - input_lens=[4, 6, 4], - ) - - # Adjust inputs for the Beta distr just as Beta itself does. - inputs[2]["a"] = np.clip( - inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER) - ) - inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0 - # Sample deterministically. 
- expected_det = [ - np.argmax(inputs[0], axis=-1), - inputs[1][:, :3], # [:3]=Mean values. - # Mean for a Beta distribution: - # 1 / [1 + (beta/alpha)] * range + low - (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, 0:2])) - * (high - low) - + low, - ] - out = distr.deterministic_sample() - check(out[0], expected_det[0]) - check(out[1], expected_det[1]) - check(out[2]["a"], expected_det[2]) - - # Stochastic sampling -> expect roughly the mean. - inputs = list(input_space.sample()) - # Fix categorical inputs (not needed for distribution itself, but - # for our expectation calculations). - inputs[0] = softmax(inputs[0], -1) - # Fix std inputs (shouldn't be too large for this test). - inputs[1][:, 3:] = std_space.sample() - # Adjust inputs for the Beta distr just as Beta itself does. - inputs[2]["a"] = np.clip( - inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER) - ) - inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0 - distr = cls( - np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1), - model={}, - action_space=value_space, - child_distributions=child_distr_cls, - input_lens=[4, 6, 4], - ) - expected_mean = [ - np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)), - inputs[1][:, :3], # [:3]=Mean values. - # Mean for a Beta distribution: - # 1 / [1 + (beta/alpha)] * range + low - (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, :2])) * (high - low) - + low, - ] - out = distr.sample() - out = list(out) - out[0] = out[0].numpy() - out[1] = out[1].numpy() - out[2]["a"] = out[2]["a"].numpy() - check(np.mean(out[0]), expected_mean[0], decimals=1) - check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1) - check(np.mean(out[2]["a"], 0), np.mean(expected_mean[2], 0), decimals=1) - - # Test log-likelihood outputs. - # Make sure beta-values are within 0.0 and 1.0 for the numpy - # calculation (which doesn't have scaling). - inputs = list(input_space.sample()) - # Adjust inputs for the Beta distr just as Beta itself does. - inputs[2]["a"] = np.clip( - inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER) - ) - inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0 - distr = cls( - np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1), - model={}, - action_space=value_space, - child_distributions=child_distr_cls, - input_lens=[4, 6, 4], - ) - inputs[0] = softmax(inputs[0], -1) - values = list(value_space.sample()) - log_prob_beta = np.log( - beta.pdf(values[2]["a"], inputs[2]["a"][:, :2], inputs[2]["a"][:, 2:]) - ) - # Now do the up-scaling for [2] (beta values) to be between - # low/high. - values[2]["a"] = values[2]["a"] * (high - low) + low - inputs[1][:, 3:] = np.exp(inputs[1][:, 3:]) - expected_log_llh = np.sum( - np.concatenate( - [ - np.expand_dims( - np.log([i[values[0][j]] for j, i in enumerate(inputs[0])]), - -1, - ), - np.log(norm.pdf(values[1], inputs[1][:, :3], inputs[1][:, 3:])), - log_prob_beta, - ], - -1, - ), - -1, - ) - - values[0] = np.expand_dims(values[0], -1) - values = tree.map_structure(lambda s: torch.Tensor(s), values) - # Test all flattened input. - concat = np.concatenate(tree.flatten(values), -1).astype(np.float32) - out = distr.logp(concat) - check(out, expected_log_llh, atol=15) - # Test structured input. - out = distr.logp(values) - check(out, expected_log_llh, atol=15) - # Test flattened input. 
- out = distr.logp(tree.flatten(values)) - check(out, expected_log_llh, atol=15) - if __name__ == "__main__": import pytest diff --git a/rllib/models/tests/test_conv2d_default_stacks.py b/rllib/models/tests/test_conv2d_default_stacks.py deleted file mode 100644 index 4cbafb7adbd51..0000000000000 --- a/rllib/models/tests/test_conv2d_default_stacks.py +++ /dev/null @@ -1,42 +0,0 @@ -import gymnasium as gym -import unittest - -from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS -from ray.rllib.models.tf.visionnet import VisionNetwork -from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVision -from ray.rllib.utils.framework import try_import_torch - -torch, nn = try_import_torch() - - -class TestConv2DDefaultStacks(unittest.TestCase): - """Tests our ConvTranspose2D Stack modules/layers.""" - - def test_conv2d_default_stacks(self): - """Tests, whether conv2d defaults are available for img obs spaces.""" - action_space = gym.spaces.Discrete(2) - - shapes = [ - (96, 96, 3), - (84, 84, 3), - (42, 42, 3), - (10, 10, 3), - ] - for shape in shapes: - print(f"shape={shape}") - obs_space = gym.spaces.Box(-1.0, 1.0, shape=shape) - model = ModelCatalog.get_model_v2( - obs_space, action_space, 2, MODEL_DEFAULTS.copy(), framework="torch" - ) - self.assertTrue(isinstance(model, (VisionNetwork, TorchVision))) - output, _ = model({"obs": torch.from_numpy(obs_space.sample()[None])}) - # B x [action logits] - self.assertTrue(output.shape == (1, 2)) - print("ok") - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tests/test_convtranspose2d_stack.py b/rllib/models/tests/test_convtranspose2d_stack.py deleted file mode 100644 index 189b5a9e22e6a..0000000000000 --- a/rllib/models/tests/test_convtranspose2d_stack.py +++ /dev/null @@ -1,63 +0,0 @@ -import gymnasium as gym -import numpy as np -import os -from pathlib import Path -import unittest - -from ray.rllib.models.preprocessors import GenericPixelPreprocessor -from ray.rllib.models.torch.modules.convtranspose2d_stack import ConvTranspose2DStack -from ray.rllib.utils.framework import try_import_torch, try_import_tf -from ray.rllib.utils.images import imread - -torch, nn = try_import_torch() -tf1, tf, tfv = try_import_tf() - - -class TestConvTranspose2DStack(unittest.TestCase): - """Tests our ConvTranspose2D Stack modules/layers.""" - - def test_convtranspose2d_stack(self): - """Tests, whether the conv2d stack can be trained to predict an image.""" - batch_size = 128 - input_size = 1 - module = ConvTranspose2DStack(input_size=input_size) - preprocessor = GenericPixelPreprocessor( - gym.spaces.Box(0, 255, (64, 64, 3), np.uint8), options={"dim": 64} - ) - optim = torch.optim.Adam(module.parameters(), lr=0.0001) - - rllib_dir = Path(__file__).parent.parent.parent - img_file = os.path.join(rllib_dir, "tests/data/images/obstacle_tower.png") - img = imread(img_file) - # Preprocess. - img = preprocessor.transform(img) - # Make channels first. - img = np.transpose(img, (2, 0, 1)) - # Add batch rank and repeat. - imgs = np.reshape(img, (1,) + img.shape) - imgs = np.repeat(imgs, batch_size, axis=0) - # Move to torch. - imgs = torch.from_numpy(imgs) - init_loss = loss = None - for _ in range(10): - # Random inputs. - inputs = torch.from_numpy( - np.random.normal(0.0, 1.0, (batch_size, input_size)) - ).float() - distribution = module(inputs) - # Construct a loss. 
- loss = -torch.mean(distribution.log_prob(imgs)) - if init_loss is None: - init_loss = loss - print("loss={}".format(loss)) - # Minimize loss. - loss.backward() - optim.step() - self.assertLess(loss, init_loss) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tests/test_lstms.py b/rllib/models/tests/test_lstms.py deleted file mode 100644 index 4ef0193dbcb84..0000000000000 --- a/rllib/models/tests/test_lstms.py +++ /dev/null @@ -1,80 +0,0 @@ -from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple -import unittest - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms import ppo -from ray.rllib.examples.envs.classes.random_env import RandomEnv - - -class TestLSTMs(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init(num_cpus=5) - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_lstm_w_prev_action_and_prev_reward(self): - """Tests LSTM prev-a/r input insertions using complex actions.""" - config = ( - ppo.PPOConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .environment( - RandomEnv, - env_config={ - "action_space": Dict( - { - "a": Box(-1.0, 1.0, ()), - "b": Box(-1.0, 1.0, (2,)), - "c": Tuple( - [ - Discrete(2), - MultiDiscrete([2, 3]), - Box(-1.0, 1.0, (3,)), - ] - ), - } - ), - }, - ) - .training( - # Need to set this to True to enable complex (prev.) actions - # as inputs to the LSTM. - model={ - "fcnet_hiddens": [10], - "use_lstm": True, - "lstm_cell_size": 16, - "lstm_use_prev_action": True, - "lstm_use_prev_reward": True, - }, - num_epochs=1, - train_batch_size=200, - minibatch_size=50, - ) - .env_runners( - rollout_fragment_length=100, - num_env_runners=1, - ) - .experimental( - _disable_action_flattening=True, - ) - ) - - tune.Tuner( - "PPO", - param_space=config.to_dict(), - run_config=air.RunConfig(stop={TRAINING_ITERATION: 1}, verbose=1), - ).fit() - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py deleted file mode 100644 index f4451f15f11ab..0000000000000 --- a/rllib/models/tests/test_preprocessors.py +++ /dev/null @@ -1,250 +0,0 @@ -import gymnasium as gym -from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple, MultiBinary -import numpy as np -import unittest - -import ray -import ray.rllib.algorithms.ppo as ppo -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.preprocessors import ( - DictFlatteningPreprocessor, - get_preprocessor, - NoPreprocessor, - TupleFlatteningPreprocessor, - OneHotPreprocessor, - AtariRamPreprocessor, - GenericPixelPreprocessor, - MultiBinaryPreprocessor, -) -from ray.rllib.utils.test_utils import ( - check, - check_compute_single_action, - check_train_results, -) -from ray.rllib.utils.framework import try_import_tf - -tf1, tf, tfv = try_import_tf() - - -class TestPreprocessors(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init() - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_preprocessing_disabled_modelv2(self): - config = ( - ppo.PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment( - "ray.rllib.examples.envs.classes.random_env.RandomEnv", - env_config={ - 
"config": { - "observation_space": Dict( - { - "a": Discrete(5), - "b": Dict( - { - "ba": Discrete(4), - "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32), - } - ), - "c": Tuple((MultiDiscrete([2, 3]), Discrete(1))), - "d": Box(-1.0, 1.0, (1,), dtype=np.int32), - } - ), - }, - }, - ) - # Speed things up a little. - .env_runners(rollout_fragment_length=5) - .training(train_batch_size=100, minibatch_size=10, num_epochs=1) - .debugging(seed=42) - # Set this to True to enforce no preprocessors being used. - # Complex observations now arrive directly in the model as - # structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]} - # for obs-space=Dict(a=..., b=Tuple(..., ...)). - .experimental(_disable_preprocessor_api=True) - ) - - # (Artur): This test only works under the old ModelV2 API because we - # don't offer arbitrarily complex Models under the RLModules API without - # preprocessors. Such input spaces require custom implementations of the - # input space. - - num_iterations = 1 - algo = config.build() - for i in range(num_iterations): - results = algo.train() - check_train_results(results) - print(results) - check_compute_single_action(algo) - algo.stop() - - def test_gym_preprocessors(self): - p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v1")) - self.assertEqual(type(p1), NoPreprocessor) - - p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v1")) - self.assertEqual(type(p2), OneHotPreprocessor) - - p3 = ModelCatalog.get_preprocessor(gym.make("ale_py:ALE/MsPacman-ram-v5")) - self.assertEqual(type(p3), AtariRamPreprocessor) - - p4 = ModelCatalog.get_preprocessor( - gym.make( - "ale_py:ALE/MsPacman-v5", - frameskip=1, - ) - ) - self.assertEqual(type(p4), GenericPixelPreprocessor) - - def test_tuple_preprocessor(self): - class TupleEnv: - def __init__(self): - self.observation_space = Tuple( - [Discrete(5), Box(0, 5, shape=(3,), dtype=np.float32)] - ) - - pp = ModelCatalog.get_preprocessor(TupleEnv()) - self.assertTrue(isinstance(pp, TupleFlatteningPreprocessor)) - self.assertEqual(pp.shape, (8,)) - self.assertEqual( - list(pp.transform((0, np.array([1, 2, 3], np.float32)))), - [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]], - ) - - def test_multi_binary_preprocessor(self): - observation_space = MultiBinary(5) - # Firstly, exclude MultiBinary from the list of preprocessors. - pp = ModelCatalog.get_preprocessor_for_space( - observation_space, include_multi_binary=False - ) - # Scondly, include MultiBinary with the list of preprocessors. 
- self.assertTrue(isinstance(pp, NoPreprocessor)) - pp = ModelCatalog.get_preprocessor_for_space( - observation_space, include_multi_binary=True - ) - self.assertTrue(isinstance(pp, MultiBinaryPreprocessor)) - self.assertEqual(pp.observation_space.shape, (5,)) - check(pp.transform(np.array([0, 1, 0, 1, 1])), [0, 1, 0, 1, 1]) - - def test_dict_flattening_preprocessor(self): - space = Dict( - { - "a": Discrete(2), - "b": Tuple([Discrete(3), Box(-1.0, 1.0, (4,))]), - } - ) - pp = get_preprocessor(space)(space) - self.assertTrue(isinstance(pp, DictFlatteningPreprocessor)) - self.assertEqual(pp.shape, (9,)) - check( - pp.transform( - {"a": 1, "b": (1, np.array([0.0, -0.5, 0.1, 0.6], np.float32))} - ), - [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -0.5, 0.1, 0.6], - ) - - def test_one_hot_preprocessor(self): - space = Discrete(5) - pp = get_preprocessor(space)(space) - self.assertTrue(isinstance(pp, OneHotPreprocessor)) - self.assertTrue(pp.shape == (5,)) - check(pp.transform(3), [0.0, 0.0, 0.0, 1.0, 0.0]) - check(pp.transform(0), [1.0, 0.0, 0.0, 0.0, 0.0]) - - space = MultiDiscrete([2, 3, 4]) - pp = get_preprocessor(space)(space) - self.assertTrue(isinstance(pp, OneHotPreprocessor)) - self.assertTrue(pp.shape == (9,)) - check( - pp.transform(np.array([1, 2, 0])), - [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], - ) - check( - pp.transform(np.array([0, 1, 3])), - [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ) - - def test_nested_multidiscrete_one_hot_preprocessor(self): - space = Tuple((MultiDiscrete([2, 3, 4]),)) - pp = get_preprocessor(space)(space) - self.assertTrue(pp.shape == (9,)) - check( - pp.transform((np.array([1, 2, 0]),)), - [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], - ) - check( - pp.transform((np.array([0, 1, 3]),)), - [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], - ) - - def test_multidimensional_multidiscrete_one_hot_preprocessor(self): - space2d = MultiDiscrete([[2, 2], [3, 3]]) - space3d = MultiDiscrete([[[2, 2], [3, 4]], [[5, 6], [7, 8]]]) - pp2d = get_preprocessor(space2d)(space2d) - pp3d = get_preprocessor(space3d)(space3d) - self.assertTrue(isinstance(pp2d, OneHotPreprocessor)) - self.assertTrue(isinstance(pp3d, OneHotPreprocessor)) - self.assertTrue(pp2d.shape == (10,)) - self.assertTrue(pp3d.shape == (37,)) - check( - pp2d.transform(np.array([[1, 0], [2, 1]])), - [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0], - ) - check( - pp3d.transform(np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])), - [ - 1.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 1.0, - ], - ) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tf/noop.py b/rllib/models/tf/noop.py deleted file mode 100644 index 30d91988e3f1d..0000000000000 --- a/rllib/models/tf/noop.py +++ /dev/null @@ -1,17 +0,0 @@ -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils.annotations import OldAPIStack, override -from ray.rllib.utils.framework import try_import_tf - -_, tf, _ = try_import_tf() - - -@OldAPIStack -class NoopModel(TFModelV2): - """Trivial model that just returns the obs flattened. 
- - This is the model used if use_state_preprocessor=False.""" - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - return tf.cast(input_dict["obs_flat"], tf.float32), state diff --git a/rllib/models/torch/mingpt.py b/rllib/models/torch/mingpt.py index 4bf54aa2fe8e5..7e24cfdc730af 100644 --- a/rllib/models/torch/mingpt.py +++ b/rllib/models/torch/mingpt.py @@ -193,7 +193,7 @@ def configure_gpt_optimizer( no_decay.add(fpn) # validate that we considered every parameter - param_dict = {pn: p for pn, p in model.named_parameters()} + param_dict = dict(model.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay assert ( diff --git a/rllib/models/torch/modules/convtranspose2d_stack.py b/rllib/models/torch/modules/convtranspose2d_stack.py deleted file mode 100644 index 7740c461cd49b..0000000000000 --- a/rllib/models/torch/modules/convtranspose2d_stack.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Tuple - -from ray.rllib.models.torch.misc import Reshape -from ray.rllib.models.utils import get_activation_fn, get_initializer -from ray.rllib.utils.annotations import OldAPIStack -from ray.rllib.utils.framework import try_import_torch - -torch, nn = try_import_torch() -if torch: - import torch.distributions as td - - -@OldAPIStack -class ConvTranspose2DStack(nn.Module): - """ConvTranspose2D decoder generating an image distribution from a vector.""" - - def __init__( - self, - *, - input_size: int, - filters: Tuple[Tuple[int]] = ( - (1024, 5, 2), - (128, 5, 2), - (64, 6, 2), - (32, 6, 2), - ), - initializer="default", - bias_init=0, - activation_fn: str = "relu", - output_shape: Tuple[int] = (3, 64, 64) - ): - """Initializes a TransposedConv2DStack instance. - - Args: - input_size: The size of the 1D input vector, from which to - generate the image distribution. - filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each - ConvTranspose2D layer): [in_channels, kernel, stride]. - initializer (Union[str]): - bias_init: The initial bias values to use. - activation_fn: Activation function descriptor (str). - output_shape (Tuple[int]): Shape of the final output image. - """ - super().__init__() - self.activation = get_activation_fn(activation_fn, framework="torch") - self.output_shape = output_shape - initializer = get_initializer(initializer, framework="torch") - - in_channels = filters[0][0] - self.layers = [ - # Map from 1D-input vector to correct initial size for the - # Conv2DTransposed stack. - nn.Linear(input_size, in_channels), - # Reshape from the incoming 1D vector (input_size) to 1x1 image - # format (channels first). - Reshape([-1, in_channels, 1, 1]), - ] - for i, (_, kernel, stride) in enumerate(filters): - out_channels = ( - filters[i + 1][0] if i < len(filters) - 1 else output_shape[0] - ) - conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride) - # Apply initializer. - initializer(conv_transp.weight) - nn.init.constant_(conv_transp.bias, bias_init) - self.layers.append(conv_transp) - # Apply activation function, if provided and if not last layer. - if self.activation is not None and i < len(filters) - 1: - self.layers.append(self.activation()) - - # num-outputs == num-inputs for next layer. - in_channels = out_channels - - self._model = nn.Sequential(*self.layers) - - def forward(self, x): - # x is [batch, hor_length, input_size] - batch_dims = x.shape[:-1] - model_out = self._model(x) - - # Equivalent to making a multivariate diag. 
- reshape_size = batch_dims + self.output_shape - mean = model_out.view(*reshape_size) - return td.Independent(td.Normal(mean, 1.0), len(self.output_shape)) diff --git a/rllib/models/torch/noop.py b/rllib/models/torch/noop.py deleted file mode 100644 index 1a59b6165f7b6..0000000000000 --- a/rllib/models/torch/noop.py +++ /dev/null @@ -1,14 +0,0 @@ -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.annotations import OldAPIStack, override - - -@OldAPIStack -class TorchNoopModel(TorchModelV2): - """Trivial model that just returns the obs flattened. - - This is the model used if use_state_preprocessor=False.""" - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - return input_dict["obs_flat"].float(), state diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py index f2165f1bca65d..0f5630dcaadb6 100644 --- a/rllib/models/torch/torch_distributions.py +++ b/rllib/models/torch/torch_distributions.py @@ -528,7 +528,7 @@ def __init__( self, child_distribution_struct: Union[Tuple, List, Dict], ): - """Initializes a TorchMultiActionDistribution object. + """Initializes a TorchMultiDistribution object. Args: child_distribution_struct: A complex struct that contains the child @@ -652,7 +652,7 @@ def from_logits( **kwargs: Forward compatibility kwargs. Returns: - A TorchMultiActionDistribution object. + A TorchMultiDistribution object. """ logit_lens = tree.flatten(input_lens) child_distribution_cls_list = tree.flatten(child_distribution_cls_struct) diff --git a/rllib/models/utils.py b/rllib/models/utils.py index c57b94bbfd188..4cd70b29e554d 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -186,12 +186,12 @@ def get_filter_config(shape): [32, [4, 4], 2], [256, [11, 11], 1], ] - # Dreamer-style (S-sized model) Atari or DM Control Suite. + # Dreamer-style (XS-sized model) Atari or DM Control Suite. filters_64x64 = [ + [16, [4, 4], 2], [32, [4, 4], 2], [64, [4, 4], 2], [128, [4, 4], 2], - [256, [4, 4], 2], ] # Small (1/2) Atari. filters_42x42 = [ diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index 04d52babc877e..6cc53af6d6ef2 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -9,6 +9,7 @@ from ray.rllib.core import COMPONENT_RL_MODULE from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.offline.offline_prelearner import OfflinePreLearner +from ray.rllib.utils import force_list from ray.rllib.utils.annotations import ( OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, @@ -130,31 +131,57 @@ def sample( # (b) Rematerialize the data every couple of iterations. This is # is costly. if not self.data_is_mapped: - # Constructor `kwargs` for the `OfflinePreLearner`. - fn_constructor_kwargs = { - "config": self.config, - "learner": self.learner_handles[0], - "spaces": self.spaces[INPUT_ENV_SPACES], - } - # If we have multiple learners, add to the constructor `kwargs`. - if num_shards > 1: + + # Get the RLModule state from learners. + if num_shards >= 1: # Call here the learner to get an up-to-date module state. # TODO (simon): This is a workaround as along as learners cannot # receive any calls from another actor. module_state = ray.get( self.learner_handles[0].get_state.remote( - component=COMPONENT_RL_MODULE + component=COMPONENT_RL_MODULE, ) - ) - # Add constructor `kwargs` when using remote learners. 
-                fn_constructor_kwargs.update(
-                    {
-                        "learner": None,
-                        "module_spec": self.module_spec,
-                        "module_state": module_state,
-                    }
-                )
+                )[COMPONENT_RL_MODULE]
+                # Provide the `Learner`(s) GPU devices, if needed.
+                # if not self.map_batches_uses_gpus(self.config) and self.config._validate_config:
+                #     devices = ray.get(self.learner_handles[0].get_device.remote())
+                #     devices = [devices] if not isinstance(devices, list) else devices
+                #     device_strings = [
+                #         f"{device.type}:{str(device.index)}"
+                #         if device.type == "cuda"
+                #         else device.type
+                #         for device in devices
+                #     ]
+                # # Otherwise, set the GPU strings to `None`.
+                # # TODO (simon): Check inside 'OfflinePreLearner'.
+                # else:
+                #     device_strings = None
+            else:
+                # Get the module state from the `Learner`(s).
+                module_state = self.learner_handles[0].get_state(
+                    component=COMPONENT_RL_MODULE,
+                )[COMPONENT_RL_MODULE]
+                # Provide the `Learner`(s) GPU devices, if needed.
+                # if not self.map_batches_uses_gpus(self.config) and self.config._validate_config:
+                #     device = self.learner_handles[0].get_device()
+                #     device_strings = [
+                #         f"{device.type}:{str(device.index)}"
+                #         if device.type == "cuda"
+                #         else device.type
+                #     ]
+                # else:
+                #     device_strings = None
+            # Constructor `kwargs` for the `OfflinePreLearner`.
+            fn_constructor_kwargs = {
+                "config": self.config,
+                "spaces": self.spaces[INPUT_ENV_SPACES],
+                "module_spec": self.module_spec,
+                "module_state": module_state,
+                # "device_strings": self.get_devices(),
+            }
+            # Map the data to run the `OfflinePreLearner`s in the data pipeline
+            # for training.
             self.data = self.data.map_batches(
                 self.prelearner_class,
                 fn_constructor_kwargs=fn_constructor_kwargs,
@@ -170,6 +197,7 @@ def sample(
         # returned now and we have already generated from the iterator, i.e.
         # `isinstance(self.batch_iterators, types.GeneratorType) == True`, we need
         # to create here a new iterator.
+        # TODO (simon): Check if this iterator could potentially be exhausted.
         if not self.batch_iterators or (
             return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
         ):
@@ -190,23 +218,25 @@ def sample(
         # Otherwise we create a simple iterator and - if necessary - initialize
         # it here.
         else:
-            # If no iterator should be returned, or if we want to return a single
-            # batch iterator, we instantiate the batch iterator once, here.
-            self.batch_iterators = self.data.iter_batches(
-                # This is important. The batch size is now 1, because the data
-                # is already run through the `OfflinePreLearner` and a single
-                # instance is a single `MultiAgentBatch` of size `num_samples`.
-                batch_size=1,
-                **self.iter_batches_kwargs,
-            )
-
-            # If there should be batches
-            if not return_iterator:
+            # Should an iterator be returned?
+            if return_iterator:
+                self.batch_iterators = self.data.iterator()
+            # Otherwise, the user wants batches returned.
+            else:
+                # If no iterator should be returned, or if we want to return a single
+                # batch iterator, we instantiate the batch iterator once, here.
+                self.batch_iterators = self.data.iter_batches(
+                    # This is important. The batch size is now 1, because the data
+                    # is already run through the `OfflinePreLearner` and a single
+                    # instance is a single `MultiAgentBatch` of size `num_samples`.
+                    batch_size=1,
+                    **self.iter_batches_kwargs,
+                )
                 self.batch_iterators = iter(self.batch_iterators)

         # Do we want to return an iterator or a single batch?
         if return_iterator:
-            return self.batch_iterators
+            return force_list(self.batch_iterators)
         else:
             # Return a single batch from the iterator.
            try:
diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py
index 8e8306751796e..1d6cbf3b23629 100644
--- a/rllib/offline/offline_prelearner.py
+++ b/rllib/offline/offline_prelearner.py
@@ -5,10 +5,8 @@
 from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING

-from ray.actor import ActorHandle
 from ray.rllib.core.columns import Columns
-from ray.rllib.core.learner import Learner
-from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
+from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec, MultiRLModule
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import (
@@ -72,9 +70,7 @@ class OfflinePreLearner:
     batches and make them 'Learner'-ready. When deriving from this class
     the `__call__` method and `_map_to_episodes` can be overridden to
     induce custom logic for the complete transformation pipeline (`__call__`) or
-    for converting to episodes only ('_map_to_episodes`). For an example
-    how this class can be used to also compute values and advantages see
-    `rllib.algorithm.marwil.marwil_prelearner.MAWRILOfflinePreLearner`.
+    for converting to episodes only (`_map_to_episodes`).

     Custom `OfflinePreLearner` classes can be passed into
     `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class
@@ -86,28 +82,21 @@ def __init__(
         self,
         *,
         config: "AlgorithmConfig",
-        learner: Union[Learner, list[ActorHandle]],
         spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
         module_spec: Optional[MultiRLModuleSpec] = None,
         module_state: Optional[Dict[ModuleID, Any]] = None,
         **kwargs: Dict[str, Any],
     ):
-        self.config = config
-        self.input_read_episodes = self.config.input_read_episodes
-        self.input_read_sample_batches = self.config.input_read_sample_batches
-        # We need this learner to run the learner connector pipeline.
-        # If it is a `Learner` instance, the `Learner` is local.
-        if isinstance(learner, Learner):
-            self._learner = learner
-            self.learner_is_remote = False
-            self._module = self._learner._module
-        # Otherwise we have remote `Learner`s.
-        else:
-            self.learner_is_remote = True
-            # Build the module from spec. Note, this will be a MultiRLModule.
-            self._module = module_spec.build()
-            self._module.set_state(module_state)
+        self.config: AlgorithmConfig = config
+        self.input_read_episodes: bool = self.config.input_read_episodes
+        self.input_read_sample_batches: bool = self.config.input_read_sample_batches
+        # Build the module from spec.
+        self._module: MultiRLModule = module_spec.build()
+        self._module.set_state(module_state)
+        # Map the module to the device, if necessary.
+        # TODO (simon): Check here if we already have a list.
+        # self._set_device(device_strings)

         # Store the observation and action space if defined, otherwise we
         # set them to `None`. Note, if `None` the `convert_from_jsonable`
@@ -121,9 +110,9 @@
         )
         # Cache the policies to be trained to update weights only for these.
         self._policies_to_train = self.config.policies_to_train
-        self._is_multi_agent = config.is_multi_agent
+        self._is_multi_agent: bool = config.is_multi_agent

         # Set the counter to zero.
- self.iter_since_last_module_update = 0 + self.iter_since_last_module_update: int = 0 # self._future = None # Set up an episode buffer, if the module is stateful or we sample from @@ -166,7 +155,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] import msgpack_numpy as mnp # Read the episodes and decode them. - episodes = [ + episodes: List[SingleAgentEpisode] = [ SingleAgentEpisode.from_state( msgpack.unpackb(state, object_hook=mnp.decode) ) @@ -190,13 +179,17 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] ) # Else, if we have old stack `SampleBatch`es. elif self.input_read_sample_batches: - episodes = OfflinePreLearner._map_sample_batch_to_episode( + episodes: List[ + SingleAgentEpisode + ] = OfflinePreLearner._map_sample_batch_to_episode( self._is_multi_agent, batch, to_numpy=True, schema=SCHEMA | self.config.input_read_schema, input_compress_columns=self.config.input_compress_columns, - )["episodes"] + )[ + "episodes" + ] # Ensure that all episodes are done and no duplicates are in the batch. episodes = self._validate_episodes(episodes) # Add the episodes to the buffer. @@ -215,7 +208,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] ) # Otherwise we map the batch to episodes. else: - episodes = self._map_to_episodes( + episodes: List[SingleAgentEpisode] = self._map_to_episodes( self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema, @@ -404,7 +397,7 @@ def convert(sample, space): if is_multi_agent: # TODO (simon): Add support for multi-agent episodes. - NotImplementedError + pass else: # Build a single-agent episode with a single row of the batch. episode = SingleAgentEpisode( @@ -521,7 +514,7 @@ def _map_sample_batch_to_episode( if is_multi_agent: # TODO (simon): Add support for multi-agent episodes. - NotImplementedError + pass else: # Unpack observations, if needed. Note, observations could # be either compressed by their entirety (the complete batch diff --git a/rllib/offline/tests/test_offline_data.py b/rllib/offline/tests/test_offline_data.py index 89586189bf47c..f872ffeeef88a 100644 --- a/rllib/offline/tests/test_offline_data.py +++ b/rllib/offline/tests/test_offline_data.py @@ -82,15 +82,16 @@ def test_sample_single_learner(self): self.assertIsInstance(algo.offline_data.learner_handles[0], Learner) # Now sample a batch from the data and ensure it is a `MultiAgentBatch`. - batch = algo.offline_data.sample(10) + batch = algo.offline_data.sample(10, num_shards=0, return_iterator=False) self.assertIsInstance(batch, MultiAgentBatch) self.assertEqual(batch.env_steps(), 10) # Now return an iterator. - iter = algo.offline_data.sample(num_samples=10, return_iterator=True) - from ray.data.iterator import _IterableFromIterator + iter = algo.offline_data.sample( + num_samples=10, num_shards=0, return_iterator=True + ) - self.assertIsInstance(iter, _IterableFromIterator) + self.assertIsInstance(iter[0], ray.data.DataIterator) # Tear down. 
algo.stop() diff --git a/rllib/offline/tests/test_offline_prelearner.py b/rllib/offline/tests/test_offline_prelearner.py index 42897f354314c..0480daf231e6a 100644 --- a/rllib/offline/tests/test_offline_prelearner.py +++ b/rllib/offline/tests/test_offline_prelearner.py @@ -8,6 +8,7 @@ from ray.rllib.algorithms.bc import BCConfig from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core import COMPONENT_RL_MODULE from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.offline.offline_prelearner import OfflinePreLearner @@ -72,10 +73,15 @@ def test_offline_prelearner_buffer_class(self): # Build the algorithm to get the learner. algo = self.config.build() - # Build the `OfflinePreLearner` and add the learner. + # Get the module state from the `Learner`(s). + module_state = algo.offline_data.learner_handles[0].get_state( + component=COMPONENT_RL_MODULE, + )[COMPONENT_RL_MODULE] + # Set up an `OfflinePreLearner` instance. oplr = OfflinePreLearner( config=self.config, - learner=algo.offline_data.learner_handles[0], + module_spec=algo.offline_data.module_spec, + module_state=module_state, ) # Ensure we have indeed a `PrioritizedEpisodeReplayBuffer` in the `PreLearner` @@ -190,10 +196,15 @@ def test_offline_prelearner_sample_from_old_sample_batch_data(self): # Build the algorithm to get the learner. algo = self.config.build() - # Build the `OfflinePreLearner` and add the learner. + # Get the module state from the `Learner`. + module_state = algo.offline_data.learner_handles[0].get_state( + component=COMPONENT_RL_MODULE, + )[COMPONENT_RL_MODULE] + # Set up an `OfflinePreLearner` instance. oplr = OfflinePreLearner( config=self.config, - learner=algo.offline_data.learner_handles[0], + module_spec=algo.offline_data.module_spec, + module_state=module_state, ) # Now, pull a batch of defined size formt he dataset. batch = algo.offline_data.data.take_batch( @@ -270,10 +281,15 @@ def test_offline_prelearner_sample_from_episode_data(self): episode_ds = ray.data.read_parquet(data_path) # Sample a batch of episodes from the episode dataset. episode_batch = episode_ds.take_batch(256) + # Get the module state from the `Learner`. + module_state = algo.offline_data.learner_handles[0].get_state( + component=COMPONENT_RL_MODULE, + )[COMPONENT_RL_MODULE] # Set up an `OfflinePreLearner` instance. oplr = OfflinePreLearner( config=self.config, - learner=algo.offline_data.learner_handles[0], + module_spec=algo.offline_data.module_spec, + module_state=module_state, spaces=algo.offline_data.spaces[INPUT_ENV_SPACES], ) # Sample a `MultiAgentBatch`. 
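Note on the `OfflinePreLearner` refactor covered by the hunks above: the `learner` argument is gone; callers now pass a `module_spec` plus a `module_state` fetched from a `Learner` via `get_state(component=COMPONENT_RL_MODULE)`. The following is a minimal sketch of the new construction contract, mirroring the updated tests; the built algorithm `algo` and the 256-row batch are illustrative assumptions, not part of this diff.

```python
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.offline.offline_prelearner import OfflinePreLearner

# Assumption: `algo` is a built offline-RL algorithm (e.g. BC), as in the
# tests above; `algo.config` is the same config the tests pass in.
# Pull the up-to-date RLModule state from the (first) Learner.
module_state = algo.offline_data.learner_handles[0].get_state(
    component=COMPONENT_RL_MODULE,
)[COMPONENT_RL_MODULE]

# The PreLearner now builds its own MultiRLModule from spec + state,
# instead of receiving a local or remote Learner handle.
oplr = OfflinePreLearner(
    config=algo.config,
    module_spec=algo.offline_data.module_spec,
    module_state=module_state,
    spaces=algo.offline_data.spaces[INPUT_ENV_SPACES],
)

# `__call__` turns a raw dataset batch into 'Learner'-ready output.
batch = algo.offline_data.data.take_batch(256)
learner_ready = oplr(batch)
```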
diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index ac40205de94ac..9645faf6e08f9 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -782,7 +782,7 @@ def _initialize_loss_from_dummy_batch( {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} ) - self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + self._loss_input_dict.update(dict(train_batch)) if log_once("loss_init"): logger.debug( diff --git a/rllib/policy/dynamic_tf_policy_v2.py b/rllib/policy/dynamic_tf_policy_v2.py index 7368696044bdc..1b127f3bef21a 100644 --- a/rllib/policy/dynamic_tf_policy_v2.py +++ b/rllib/policy/dynamic_tf_policy_v2.py @@ -733,7 +733,7 @@ def _initialize_loss_from_dummy_batch( {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} ) - self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + self._loss_input_dict.update(dict(train_batch)) if log_once("loss_init"): logger.debug( diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 36abaa36ad766..249cdb405376b 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -891,12 +891,12 @@ def _zero_pad_in_place(path, value): return self @ExperimentalAPI - def to_device(self, device, framework="torch"): + def to_device(self, device, framework: str = "torch", pin_memory: bool = False): """TODO: transfer batch to given device as framework tensor.""" if framework == "torch": assert torch is not None for k, v in self.items(): - self[k] = convert_to_torch_tensor(v, device) + self[k] = convert_to_torch_tensor(v, device, pin_memory=pin_memory) else: raise NotImplementedError return self @@ -1491,13 +1491,13 @@ def copy(self) -> "MultiAgentBatch": ) @ExperimentalAPI - def to_device(self, device, framework="torch"): + def to_device(self, device, framework="torch", pin_memory: bool = False): """TODO: transfer batch to given device as framework tensor.""" if framework == "torch": assert torch is not None for pid, policy_batch in self.policy_batches.items(): self.policy_batches[pid] = policy_batch.to_device( - device, framework=framework + device, framework=framework, pin_memory=pin_memory ) else: raise NotImplementedError diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 11c524f9c2bf4..abde75af1b55a 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -420,7 +420,7 @@ def compute_log_likelihoods( self._state_inputs, state_batches ) ) - builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) if state_batches: builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) # Prev-a and r. diff --git a/rllib/tests/test_eager_support.py b/rllib/tests/test_eager_support.py deleted file mode 100644 index 21acc3f664116..0000000000000 --- a/rllib/tests/test_eager_support.py +++ /dev/null @@ -1,110 +0,0 @@ -import unittest - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.utils.framework import try_import_tf -from ray.tune.registry import get_trainable_cls - -tf1, tf, tfv = try_import_tf() - - -def check_support(alg, config, test_eager=False, test_trace=True): - config["framework"] = "tf2" - config["log_level"] = "ERROR" - # Test both continuous and discrete actions. 
- for cont in [True, False]: - if cont and alg == "DQN": - continue - - if cont: - config["env"] = "Pendulum-v1" - else: - config["env"] = "CartPole-v1" - - a = get_trainable_cls(alg) - if test_eager: - print("tf-eager: alg={} cont.act={}".format(alg, cont)) - config["eager_tracing"] = False - tune.Tuner( - a, - param_space=config, - run_config=air.RunConfig(stop={TRAINING_ITERATION: 1}, verbose=1), - ).fit() - if test_trace: - config["eager_tracing"] = True - print("tf-eager-tracing: alg={} cont.act={}".format(alg, cont)) - tune.Tuner( - a, - param_space=config, - run_config=air.RunConfig(stop={TRAINING_ITERATION: 1}, verbose=1), - ).fit() - - -class TestEagerSupportPolicyGradient(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=4) - - def tearDown(self): - ray.shutdown() - - def test_dqn(self): - check_support( - "DQN", - { - "num_env_runners": 0, - "num_steps_sampled_before_learning_starts": 0, - }, - ) - - def test_ppo(self): - check_support("PPO", {"num_env_runners": 0}) - - def test_appo(self): - check_support("APPO", {"num_env_runners": 1, "num_gpus": 0}) - - def test_impala(self): - check_support("IMPALA", {"num_env_runners": 1, "num_gpus": 0}, test_eager=True) - - -class TestEagerSupportOffPolicy(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=4) - - def tearDown(self): - ray.shutdown() - - def test_dqn(self): - check_support( - "DQN", - { - "num_env_runners": 0, - "num_steps_sampled_before_learning_starts": 0, - }, - ) - - def test_sac(self): - check_support( - "SAC", - { - "num_env_runners": 0, - "num_steps_sampled_before_learning_starts": 0, - }, - ) - - -if __name__ == "__main__": - import sys - - # Don't test anything for version 2.x (all tests are eager anyways). - # TODO: (sven) remove entire file in the future. - if tfv == 2: - print("\tskip due to tf==2.x") - sys.exit(0) - - # One can specify the specific TestCase class to run. - # None for all unittest.TestCase classes in this file. 
- import pytest - - class_ = sys.argv[1] if len(sys.argv) > 1 else None - sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)])) diff --git a/rllib/tests/test_filters.py b/rllib/tests/test_filters.py deleted file mode 100644 index 0fde45e2efa4a..0000000000000 --- a/rllib/tests/test_filters.py +++ /dev/null @@ -1,83 +0,0 @@ -import numpy as np -import unittest - -import ray -from ray.rllib.utils.filter import RunningStat, MeanStdFilter - - -class RunningStatTest(unittest.TestCase): - def testRunningStat(self): - for shp in ((), (3,), (3, 4)): - li = [] - rs = RunningStat(shp) - for _ in range(5): - val = np.random.randn(*shp) - rs.push(val) - li.append(val) - m = np.mean(li, axis=0) - self.assertTrue(np.allclose(rs.mean, m)) - v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) - self.assertTrue(np.allclose(rs.var, v)) - - def testCombiningStat(self): - for shape in [(), (3,), (3, 4)]: - li = [] - rs1 = RunningStat(shape) - rs2 = RunningStat(shape) - rs = RunningStat(shape) - for _ in range(5): - val = np.random.randn(*shape) - rs1.push(val) - rs.push(val) - li.append(val) - for _ in range(9): - rs2.push(val) - rs.push(val) - li.append(val) - rs1.update(rs2) - assert np.allclose(rs.mean, rs1.mean) - assert np.allclose(rs.std, rs1.std) - - -class MeanStdFilterTest(unittest.TestCase): - def testBasic(self): - for shape in [(), (3,), (3, 4, 4)]: - filt = MeanStdFilter(shape) - for i in range(5): - filt(np.ones(shape)) - self.assertEqual(filt.running_stats.n, 5) - self.assertEqual(filt.buffer.n, 5) - - filt2 = MeanStdFilter(shape) - filt2.sync(filt) - self.assertEqual(filt2.running_stats.n, 5) - self.assertEqual(filt2.buffer.n, 5) - - filt.reset_buffer() - self.assertEqual(filt.buffer.n, 0) - self.assertEqual(filt2.buffer.n, 5) - - filt.apply_changes(filt2, with_buffer=False) - self.assertEqual(filt.buffer.n, 0) - self.assertEqual(filt.running_stats.n, 10) - - filt.apply_changes(filt2, with_buffer=True) - self.assertEqual(filt.buffer.n, 5) - self.assertEqual(filt.running_stats.n, 15) - - -class FilterManagerTest(unittest.TestCase): - def setUp(self): - ray.init( - num_cpus=1, object_store_memory=1000 * 1024 * 1024, ignore_reinit_error=True - ) - - def tearDown(self): - ray.shutdown() - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_gpus.py b/rllib/tests/test_gpus.py deleted file mode 100644 index 24511e14367ae..0000000000000 --- a/rllib/tests/test_gpus.py +++ /dev/null @@ -1,126 +0,0 @@ -import unittest - -import ray -from ray import air -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.utils.framework import try_import_torch -from ray import tune - -torch, _ = try_import_torch() - - -class TestGPUs(unittest.TestCase): - def test_gpus_in_non_local_mode(self): - # Non-local mode. - ray.init() - - actual_gpus = torch.cuda.device_count() - print(f"Actual GPUs found (by torch): {actual_gpus}") - - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .env_runners(num_env_runners=2) - .environment("CartPole-v1") - ) - - # Expect errors when we run a config w/ num_gpus>0 w/o a GPU - # and _fake_gpus=False. - for num_gpus in [0, 0.1, 1, actual_gpus + 4]: - # Only allow possible num_gpus_per_env_runner (so test would not - # block infinitely due to a down worker). 
- per_worker = ( - [0] if actual_gpus == 0 or actual_gpus < num_gpus else [0, 0.5, 1] - ) - for num_gpus_per_env_runner in per_worker: - for fake_gpus in [False] + ([] if num_gpus == 0 else [True]): - config.resources( - num_gpus=num_gpus, - _fake_gpus=fake_gpus, - ) - config.env_runners(num_gpus_per_env_runner=num_gpus_per_env_runner) - - print( - f"\n------------\nnum_gpus={num_gpus} " - f"num_gpus_per_env_runner={num_gpus_per_env_runner} " - f"_fake_gpus={fake_gpus}" - ) - - # Expect that Algorithm creation causes a num_gpu error. - if ( - actual_gpus < num_gpus + 2 * num_gpus_per_env_runner - and not fake_gpus - ): - # "Direct" RLlib (create Algorithm on the driver). - # Cannot run through ray.tune.Tuner().fit() as it would - # simply wait infinitely for the resources to - # become available. - print("direct RLlib") - self.assertRaisesRegex( - RuntimeError, - "Found 0 GPUs on your machine", - lambda: config.build(), - ) - # If actual_gpus >= num_gpus or faked, - # expect no error. - else: - print("direct RLlib") - algo = config.build() - algo.stop() - # Cannot run through ray.tune.Tuner().fit() w/ fake GPUs - # as it would simply wait infinitely for the - # resources to become available (even though, we - # wouldn't really need them). - if num_gpus == 0: - print("via ray.tune.Tuner().fit()") - tune.Tuner( - "PPO", - param_space=config, - run_config=air.RunConfig(stop={TRAINING_ITERATION: 0}), - ).fit() - ray.shutdown() - - def test_gpus_in_local_mode(self): - # Local mode. - ray.init(local_mode=True) - - actual_gpus_available = torch.cuda.device_count() - - config = ( - PPOConfig() - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .env_runners(num_env_runners=2) - .environment("CartPole-v1") - ) - - # Expect no errors in local mode. 
- for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]: - print(f"num_gpus={num_gpus}") - for fake_gpus in [False, True]: - print(f"_fake_gpus={fake_gpus}") - config.resources(num_gpus=num_gpus, _fake_gpus=fake_gpus) - print("direct RLlib") - algo = config.build() - algo.stop() - print("via ray.tune.Tuner().fit()") - tune.Tuner( - "PPO", - param_space=config, - run_config=air.RunConfig(stop={TRAINING_ITERATION: 0}), - ).fit() - - ray.shutdown() - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py deleted file mode 100644 index be0c8aaafafba..0000000000000 --- a/rllib/tests/test_io.py +++ /dev/null @@ -1,459 +0,0 @@ -import glob -import json -import numpy as np -import os -import shutil -import tempfile -import unittest - -import ray -from ray.tune.registry import ( - register_input, - registry_get_input, - registry_contains_input, -) -from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.offline import ( - IOContext, - JsonWriter, - JsonReader, - InputReader, - ShuffledInput, - DatasetWriter, -) -from ray.rllib.offline.json_reader import from_json_data -from ray.rllib.offline.json_writer import _to_json_dict, _to_json -from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - EVALUATION_RESULTS, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) - -SAMPLES = SampleBatch( - { - "actions": np.array([1, 2, 3, 4]), - "obs": np.array([4, 5, 6, 7]), - "eps_id": [1, 1, 2, 3], - } -) - - -def make_sample_batch(i): - return SampleBatch({"actions": np.array([i, i, i]), "obs": np.array([i, i, i])}) - - -class AgentIOTest(unittest.TestCase): - def setUp(self): - ray.init() - self.test_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.test_dir) - ray.shutdown() - - def write_outputs(self, output, fw, output_config=None): - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .framework(fw) - .training(train_batch_size=250) - .offline_data( - output=output + (fw if output != "logdir" else ""), - output_config=output_config or {}, - ) - ) - algo = config.build() - print(algo.train()) - return algo - - def test_agent_output_ok(self): - self.write_outputs(self.test_dir, "torch") - # PPO has two workers, so we expect 2 output files. - self.assertEqual(len(os.listdir(self.test_dir + "torch")), 2) - reader = JsonReader(self.test_dir + "torch" + "/*.json") - reader.next() - - def test_agent_output_logdir(self): - """Test special value 'logdir' as Agent's output.""" - agent = self.write_outputs("logdir", "torch") - # PPO has two workers, so we expect 2 output files. - self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 2) - - def test_agent_output_infos(self): - """Verify that the infos dictionary is written to the output files. - - Note, with torch this is always the case.""" - output_config = {"store_infos": True} - self.write_outputs(self.test_dir, "torch", output_config=output_config) - # PPO has two workers, so we expect 2 output files. 
- self.assertEqual(len(os.listdir(self.test_dir + "torch")), 2) - reader = JsonReader(self.test_dir + "torch" + "/*.json") - data = reader.next() - data = convert_ma_batch_to_sample_batch(data) - self.assertTrue("infos" in data) - - def test_agent_input_dir(self): - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .evaluation(off_policy_estimation_methods={}) - .training(train_batch_size=250) - ) - - self.write_outputs(self.test_dir, "torch") - config.offline_data( - input_=self.test_dir + "torch", - ) - print("WROTE TO: ", self.test_dir) - algo = config.build() - result = algo.train() - self.assertEqual( - result[f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}"], 250 - ) # read from input - self.assertTrue(np.isnan(result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN])) - - def test_split_by_episode(self): - splits = SAMPLES.split_by_episode() - self.assertEqual(len(splits), 3) - self.assertEqual(splits[0].count, 2) - self.assertEqual(splits[1].count, 1) - self.assertEqual(splits[2].count, 1) - - def test_agent_input_postprocessing_enabled(self): - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .training(train_batch_size=250) - .offline_data( - postprocess_inputs=True, # adds back 'advantages' - ) - .evaluation(off_policy_estimation_methods={}) - ) - - self.write_outputs(self.test_dir, "torch") - config.offline_data(input_=self.test_dir + "torch") - - # Rewrite the files to drop advantages and value_targets for - # testing - for path in glob.glob(self.test_dir + "torch" + "/*.json"): - out = [] - with open(path) as f: - for line in f.readlines(): - data_string = json.loads(line) - data = from_json_data(data_string, None) - data = convert_ma_batch_to_sample_batch(data) - # Data won't contain rewards as these are not included - # in the write_outputs run (not needed in the - # SampleBatch). Flip out "rewards" for "advantages" - # just for testing. 
- data["rewards"] = data["advantages"] - del data["advantages"] - if "value_targets" in data: - del data["value_targets"] - out.append(_to_json_dict(data, [])) - with open(path, "w") as f: - for data in out: - f.write(json.dumps(data)) - - algo = config.build() - result = algo.train() - self.assertEqual( - result[f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}"], 250 - ) # read from input - self.assertTrue(np.isnan(result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN])) - algo.stop() - - def test_agent_input_eval_sampler(self): - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .offline_data( - postprocess_inputs=True, # adds back 'advantages' - ) - .evaluation( - evaluation_interval=1, - evaluation_config=PPOConfig.overrides(input_="sampler"), - ) - ) - - self.write_outputs(self.test_dir, "torch") - config.offline_data(input_=self.test_dir + "torch") - algo = config.build() - result = algo.train() - assert np.isnan( - result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] - ), "episode reward should not be computed for offline data" - assert not np.isnan( - result[EVALUATION_RESULTS][ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] - ), "Did not see simulation results during evaluation" - algo.stop() - - def test_custom_input_procedure(self): - class CustomJsonReader(JsonReader): - def __init__(self, ioctx: IOContext): - super().__init__(ioctx.input_config["input_files"], ioctx) - - def input_creator(ioctx: IOContext) -> InputReader: - return ShuffledInput(CustomJsonReader(ioctx)) - - register_input("custom_input", input_creator) - test_input_procedure = [ - "custom_input", - input_creator, - "ray.rllib.examples.offline_rl.custom_input_api.CustomJsonReader", - ] - - for input_procedure in test_input_procedure: - - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .offline_data(input_=input_procedure) - .evaluation(off_policy_estimation_methods={}) - ) - - self.write_outputs(self.test_dir, "torch") - config.offline_data(input_config={"input_files": self.test_dir + "torch"}) - algo = config.build() - result = algo.train() - self.assertEqual(result[NUM_ENV_STEPS_SAMPLED_LIFETIME], 4000) - self.assertTrue(np.isnan(result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN])) - algo.stop() - - def test_multiple_output_workers(self): - ray.shutdown() - ray.init(num_cpus=4, ignore_reinit_error=True) - - config = ( - PPOConfig() - .api_stack( - enable_env_runner_and_connector_v2=False, - enable_rl_module_and_learner=False, - ) - .environment("CartPole-v1") - .env_runners(num_env_runners=2) - .training(train_batch_size=500) - .evaluation(off_policy_estimation_methods={}) - ) - - config.offline_data(output=self.test_dir + "torch") - algo = config.build() - algo.train() - self.assertEqual(len(os.listdir(self.test_dir + "torch")), 2) - reader = JsonReader(self.test_dir + "torch" + "/*.json") - reader.next() - algo.stop() - - -class JsonIOTest(unittest.TestCase): - def setUp(self): - ray.init() - self.test_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.test_dir) - ray.shutdown() - - def test_write_dataset(self): - ioctx = IOContext( - self.test_dir, - AlgorithmConfig().offline_data( - output="dataset", - output_config={ - "format": "json", - "path": self.test_dir, - "max_num_samples_per_file": 2, - }, - ), - 0, - None, - ) - writer = DatasetWriter(ioctx, compress_columns=["obs"]) - 
self.assertEqual(len(os.listdir(self.test_dir)), 0) - writer.write(SAMPLES) - writer.write(SAMPLES) - self.assertEqual(len(os.listdir(self.test_dir)), 1) - - def test_write_simple(self): - ioctx = IOContext(self.test_dir, {}, 0, None) - writer = JsonWriter( - self.test_dir, ioctx, max_file_size=1000, compress_columns=["obs"] - ) - self.assertEqual(len(os.listdir(self.test_dir)), 0) - writer.write(SAMPLES) - writer.write(SAMPLES) - self.assertEqual(len(os.listdir(self.test_dir)), 1) - - def test_write_file_uri(self): - ioctx = IOContext(self.test_dir, None, 0, None) - writer = JsonWriter( - "file://" + self.test_dir, - ioctx, - max_file_size=1000, - compress_columns=["obs"], - ) - self.assertEqual(len(os.listdir(self.test_dir)), 0) - writer.write(SAMPLES) - writer.write(SAMPLES) - self.assertEqual(len(os.listdir(self.test_dir)), 1) - - def test_write_paginate(self): - ioctx = IOContext(self.test_dir, AlgorithmConfig(), 0, None) - writer = JsonWriter( - self.test_dir, ioctx, max_file_size=5000, compress_columns=["obs"] - ) - self.assertEqual(len(os.listdir(self.test_dir)), 0) - for _ in range(100): - writer.write(SAMPLES) - num_files = len(os.listdir(self.test_dir)) - - # Pagination can't really be predicted: - # On travis, it seems to create only 2 files, but sometimes also - # 6, or 7. 12 or 13 usually on a Mac locally. - # Reasons: Different compressions, file-size interpretations, - # json writers? - assert num_files >= 2, "Expected >= 2 files, but found {} ({})".format( - num_files, os.listdir(self.test_dir) - ) - - def test_read_write(self): - ioctx = IOContext(self.test_dir, None, 0, None) - writer = JsonWriter( - self.test_dir, ioctx, max_file_size=5000, compress_columns=["obs"] - ) - for i in range(100): - writer.write(make_sample_batch(i)) - reader = JsonReader(self.test_dir + "/*.json") - seen_a = set() - seen_o = set() - for i in range(1000): - batch = reader.next() - seen_a.add(batch["actions"][0]) - seen_o.add(batch["obs"][0]) - self.assertGreater(len(seen_a), 90) - self.assertLess(len(seen_a), 101) - self.assertGreater(len(seen_o), 90) - self.assertLess(len(seen_o), 101) - - def test_skips_over_empty_lines_and_files(self): - open(self.test_dir + "/empty", "w").close() - with open(self.test_dir + "/f1", "w") as f: - f.write("\n") - f.write("\n") - f.write(_to_json(make_sample_batch(0), [])) - with open(self.test_dir + "/f2", "w") as f: - f.write(_to_json(make_sample_batch(1), [])) - f.write("\n") - reader = JsonReader( - [ - self.test_dir + "/empty", - self.test_dir + "/f1", - "file://" + self.test_dir + "/f2", - ] - ) - seen_a = set() - for i in range(100): - batch = reader.next() - seen_a.add(batch["actions"][0]) - self.assertEqual(len(seen_a), 2) - - def test_skips_over_corrupted_lines(self): - with open(self.test_dir + "/f1", "w") as f: - f.write(_to_json(make_sample_batch(0), [])) - f.write("\n") - f.write(_to_json(make_sample_batch(1), [])) - f.write("\n") - f.write(_to_json(make_sample_batch(2), [])) - f.write("\n") - f.write(_to_json(make_sample_batch(3), [])) - f.write("\n") - f.write("{..corrupted_json_record") - reader = JsonReader( - [ - self.test_dir + "/f1", - ] - ) - seen_a = set() - for i in range(10): - batch = reader.next() - seen_a.add(batch["actions"][0]) - self.assertEqual(len(seen_a), 4) - - def test_abort_on_all_empty_inputs(self): - open(self.test_dir + "/empty", "w").close() - reader = JsonReader( - [ - self.test_dir + "/empty", - ] - ) - self.assertRaises(ValueError, lambda: reader.next()) - with open(self.test_dir + "/empty1", "w") as f: - for _ in 
range(100): - f.write("\n") - with open(self.test_dir + "/empty2", "w") as f: - for _ in range(100): - f.write("\n") - reader = JsonReader( - [ - self.test_dir + "/empty1", - self.test_dir + "/empty2", - ] - ) - self.assertRaises(ValueError, lambda: reader.next()) - - def test_custom_input_registry(self): - config = AlgorithmConfig().offline_data(input_config={}) - ioctx = IOContext(self.test_dir, config, 0, None) - - class CustomInputReader(InputReader): - def __init__(self, ioctx: IOContext): - self.ioctx = ioctx - - def next(self): - return 0 - - def input_creator(ioctx: IOContext): - return ShuffledInput(CustomInputReader(ioctx)) - - register_input("custom_input", input_creator) - self.assertTrue(registry_contains_input("custom_input")) - creator = registry_get_input("custom_input") - self.assertIsNotNone(creator) - reader = creator(ioctx) - self.assertIsInstance(reader, ShuffledInput) - self.assertEqual(reader.next(), 0) - self.assertEqual(ioctx.log_dir, self.test_dir) - self.assertEqual(ioctx.config, config) - self.assertEqual(ioctx.worker_index, 0) - self.assertIsNone(ioctx.worker) - self.assertEqual(ioctx.input_config, {}) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tests/test_reproducibility.py b/rllib/tests/test_reproducibility.py deleted file mode 100644 index 682fd1984ef53..0000000000000 --- a/rllib/tests/test_reproducibility.py +++ /dev/null @@ -1,81 +0,0 @@ -import gymnasium as gym -import numpy as np -import unittest - -import ray -from ray.rllib.algorithms.dqn import DQNConfig -from ray.rllib.utils.metrics import ( - EPISODE_RETURN_MAX, - EPISODE_RETURN_MIN, - ENV_RUNNER_RESULTS, -) -from ray.tune.registry import register_env - - -class TestReproducibility(unittest.TestCase): - def test_reproducing_trajectory(self): - class PickLargest(gym.Env): - def __init__(self): - self.observation_space = gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(4,) - ) - self.action_space = gym.spaces.Discrete(4) - - def reset(self, *, seed=None, options=None): - self.obs = np.random.randn(4) - return self.obs, {} - - def step(self, action): - reward = self.obs[action] - return self.obs, reward, True, False, {} - - def env_creator(env_config): - return PickLargest() - - trajs = [] - for trial in range(3): - ray.init() - register_env("PickLargest", env_creator) - config = ( - DQNConfig() - .environment("PickLargest") - .debugging(seed=666 if trial in [0, 1] else 999) - .reporting( - min_time_s_per_iteration=0, - min_sample_timesteps_per_iteration=100, - ) - ) - algo = config.build() - - trajectory = list() - for _ in range(8): - r = algo.train() - trajectory.append(r[ENV_RUNNER_RESULTS][EPISODE_RETURN_MAX]) - trajectory.append(r[ENV_RUNNER_RESULTS][EPISODE_RETURN_MIN]) - trajs.append(trajectory) - - algo.stop() - ray.shutdown() - - # trial0 and trial1 use same seed and thus - # expect identical trajectories. - all_same = True - for v0, v1 in zip(trajs[0], trajs[1]): - if v0 != v1: - all_same = False - self.assertTrue(all_same) - - # trial1 and trial2 use different seeds and thus - # most rewards tend to be different. 
- diff_cnt = 0 - for v1, v2 in zip(trajs[1], trajs[2]): - if v1 != v2: - diff_cnt += 1 - self.assertTrue(diff_cnt > 8) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/tuned_examples/appo/cartpole-appo.yaml b/rllib/tuned_examples/appo/cartpole-appo.yaml deleted file mode 100644 index bfceaddcf02f6..0000000000000 --- a/rllib/tuned_examples/appo/cartpole-appo.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# @OldAPIStack -cartpole-appo: - env: CartPole-v1 - run: APPO - stop: - env_runners/episode_return_mean: 180 - timesteps_total: 200000 - config: - # Works for both torch and tf. - framework: torch - num_envs_per_env_runner: 5 - num_env_runners: 4 - num_gpus: 0 - observation_filter: MeanStdFilter - num_epochs: 1 - vf_loss_coeff: 0.01 - vtrace: true - model: - fcnet_hiddens: [32] - fcnet_activation: linear - vf_share_layers: true diff --git a/rllib/tuned_examples/appo/pong-appo-w-rl-modules-and-learner.yaml b/rllib/tuned_examples/appo/pong-appo-w-rl-modules-and-learner.yaml deleted file mode 100644 index 2c11e896744ed..0000000000000 --- a/rllib/tuned_examples/appo/pong-appo-w-rl-modules-and-learner.yaml +++ /dev/null @@ -1,52 +0,0 @@ -# @OldAPIStack -# This can reach 18.0 reward in ~10 minutes on 4x M60 GPUs -# with 30 rollout workers, 4 learning workers, and 8 envs per rollout worker. -appo-pongnoframeskip-v5: - env: ale_py:ALE/Pong-v5 - run: APPO - stop: - env_runners/episode_return_mean: 18.0 - timesteps_total: 20000000 - config: - # Run with Learner- and RLModule API (new stack). - enable_rl_module_and_learner: true - # Make analogous to old v4 + NoFrameskip. - env_config: - frameskip: 1 - full_action_space: false - repeat_action_probability: 0.0 - vtrace: True - use_kl_loss: False - rollout_fragment_length: 50 - train_batch_size: 4000 - lr: 0.0006 - # On a 32 CPU machine (g3.2xlarge), we use 30 CPUs for the rollout workers - # and 2 for the learner workers. - num_env_runners: 31 - broadcast_interval: 1 - max_sample_requests_in_flight_per_worker: 1 - num_envs_per_env_runner: 8 - num_epochs: 2 - vf_loss_coeff: 1.0 - clip_param: 0.3 - - grad_clip: 10.0 - grad_clip_by: global_norm - model: - dim: 42 - conv_filters: [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]] - conv_activation: relu - # Use a (256, 1)-MLP for values and (256, [num actions])-MLP - # for the policy. - post_fcnet_hiddens: [256] - - # Use N Learner worker on the GPU - num_learners: 4 - num_gpus_per_learner: 1 - num_gpus: 0 # No GPU needed for driver. - # Since we are using learner workers, the driver process does not need - # a CPU in particular. - num_cpus_for_main_process: 1 - # Need to unset this b/c we are using the RLModule API, which - # provides exploration control via the RLModule's `forward_exploration` method. - exploration_config: {} diff --git a/rllib/tuned_examples/appo/pong-appo.yaml b/rllib/tuned_examples/appo/pong-appo.yaml deleted file mode 100644 index 3b1ecd9215cba..0000000000000 --- a/rllib/tuned_examples/appo/pong-appo.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# @OldAPIStack -# This can reach 18-19 reward in ~5-7 minutes on a Titan XP GPU -# with 32 workers and 8 envs per worker. IMPALA, when ran with -# similar configurations, solved Pong in 10-12 minutes. -# APPO can also solve Pong in 2.5 million timesteps, which is -# 2x more efficient than that of IMPALA. -pong-appo: - env: ale_py:ALE/Pong-v5 - run: APPO - stop: - env_runners/episode_return_mean: 18.0 - timesteps_total: 5000000 - config: - # Works for both torch and tf. 
- framework: torch - # Make analogous to old v4 + NoFrameskip. - env_config: - frameskip: 1 # no frameskip - full_action_space: false - repeat_action_probability: 0.0 # deterministic - vtrace: true - use_kl_loss: false - rollout_fragment_length: 50 - train_batch_size: 750 - num_env_runners: 32 - broadcast_interval: 1 - max_sample_requests_in_flight_per_worker: 1 - num_multi_gpu_tower_stacks: 1 - num_envs_per_env_runner: 8 - minibatch_buffer_size: 4 - num_epochs: 2 - vf_loss_coeff: 1.0 - clip_param: 0.3 - num_gpus: 1 - grad_clip: 10 - model: - dim: 42 diff --git a/rllib/tuned_examples/appo/pong_appo.py b/rllib/tuned_examples/appo/pong_appo.py index dfed9e1bdb94f..7cf639eb6cd23 100644 --- a/rllib/tuned_examples/appo/pong_appo.py +++ b/rllib/tuned_examples/appo/pong_appo.py @@ -60,7 +60,7 @@ def _env_creator(cfg): .training( learner_connector=_make_learner_connector, train_batch_size_per_learner=500, - target_network_update_freq=4, + target_network_update_freq=2, lr=0.0005 * ((args.num_learners or 1) ** 0.5), vf_loss_coeff=1.0, entropy_coeff=[[0, 0.01], [3000000, 0.0]], # <- crucial parameter to finetune diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py index 393860f1d2231..fbfd864b754ca 100644 --- a/rllib/tuned_examples/bc/cartpole_bc.py +++ b/rllib/tuned_examples/bc/cartpole_bc.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path from ray.air.constants import TRAINING_ITERATION @@ -43,7 +44,7 @@ # Concurrency defines the number of processes that run the # `map_batches` transformations. This should be aligned with the # 'prefetch_batches' argument in 'iter_batches_kwargs'. - map_batches_kwargs={"concurrency": 2, "num_cpus": 2}, + map_batches_kwargs={"concurrency": 2, "num_cpus": 1}, # This data set is small so do not prefetch too many batches and use no # local shuffle. iter_batches_kwargs={"prefetch_batches": 1}, @@ -51,7 +52,7 @@ # mode in a single RLlib training iteration. Leave this to `None` to # run an entire epoch on the dataset during a single RLlib training # iteration. For single-learner mode, 1 is the only option. - dataset_num_iters_per_learner=1 if not args.num_learners else None, + dataset_num_iters_per_learner=5, ) .training( train_batch_size_per_learner=1024, @@ -72,6 +73,14 @@ ) ) +if not args.no_tune: + warnings.warn( + "You are running the example with Ray Tune. Offline RL uses " + "Ray Data, which does not interact seamlessly with Ray Tune. " + "If you encounter difficulties, try to run the example without " + "Ray Tune using `--no-tune`."
+ ) + stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0, TRAINING_ITERATION: 350, diff --git a/rllib/tuned_examples/cql/halfcheetah-bc.yaml b/rllib/tuned_examples/cql/halfcheetah-bc.yaml index 1036f4e25aaef..8096cc2246cc0 100644 --- a/rllib/tuned_examples/cql/halfcheetah-bc.yaml +++ b/rllib/tuned_examples/cql/halfcheetah-bc.yaml @@ -48,4 +48,3 @@ halfcheetah_bc: lagrangian: False evaluation_config: input: sampler - diff --git a/rllib/tuned_examples/cql/hopper-bc.yaml b/rllib/tuned_examples/cql/hopper-bc.yaml index 28985e7308bcd..5c82d80f4cdba 100644 --- a/rllib/tuned_examples/cql/hopper-bc.yaml +++ b/rllib/tuned_examples/cql/hopper-bc.yaml @@ -48,4 +48,3 @@ hopper_bc: lagrangian: False evaluation_config: input: sampler - diff --git a/rllib/tuned_examples/cql/hopper-cql.yaml b/rllib/tuned_examples/cql/hopper-cql.yaml index c8e04e449c8c8..690bc58fe4c4e 100644 --- a/rllib/tuned_examples/cql/hopper-cql.yaml +++ b/rllib/tuned_examples/cql/hopper-cql.yaml @@ -48,4 +48,3 @@ hopper_cql: lagrangian: False evaluation_config: input: sampler - diff --git a/rllib/tuned_examples/cql/pendulum_cql.py b/rllib/tuned_examples/cql/pendulum_cql.py index dff5446a0e0aa..33cdbafaac512 100644 --- a/rllib/tuned_examples/cql/pendulum_cql.py +++ b/rllib/tuned_examples/cql/pendulum_cql.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path from ray.rllib.algorithms.cql.cql import CQLConfig @@ -32,17 +33,12 @@ config = ( CQLConfig() .environment("Pendulum-v1") - # Use the new API stack. - .api_stack( - enable_env_runner_and_connector_v2=True, - enable_rl_module_and_learner=True, - ) .offline_data( input_=[data_path.as_posix()], # The `kwargs` for the `map_batches` method in which our # `OfflinePreLearner` is run. 2 data workers should be run # concurrently. - map_batches_kwargs={"concurrency": 2, "num_cpus": 2}, + map_batches_kwargs={"concurrency": 2, "num_cpus": 1}, # The `kwargs` for the `iter_batches` method. Due to the small # dataset we choose only a single batch to prefetch. iter_batches_kwargs={"prefetch_batches": 1}, @@ -50,7 +46,7 @@ # mode in a single RLlib training iteration. Leave this to `None` to # run an entire epoch on the dataset during a single RLlib training # iteration. For single-learner mode 1 is the only option. - dataset_num_iters_per_learner=1 if not args.num_learners else None, + dataset_num_iters_per_learner=5, # TODO (sven): Has this any influence in the connectors? actions_in_input_normalized=True, ) @@ -81,6 +77,14 @@ ) ) +if not args.no_tune: + warnings.warn( + "You are running the example with Ray Tune. Offline RL uses " + "Ray Data, which does not interact seamlessly with Ray Tune. " + "If you encounter difficulties, try to run the example without " + "Ray Tune using `--no-tune`."
+ ) + stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -700.0, NUM_ENV_STEPS_SAMPLED_LIFETIME: 800000, diff --git a/rllib/tuned_examples/dqn/cartpole-dqn-softq.yaml b/rllib/tuned_examples/dqn/cartpole-dqn-softq.yaml index a16514e3a3647..1f18d28cd7a5f 100644 --- a/rllib/tuned_examples/dqn/cartpole-dqn-softq.yaml +++ b/rllib/tuned_examples/dqn/cartpole-dqn-softq.yaml @@ -14,4 +14,4 @@ cartpole-dqn: n_step: 3 exploration_config: type: SoftQ - temperature: 0.5 \ No newline at end of file + temperature: 0.5 diff --git a/rllib/tuned_examples/marwil/cartpole_marwil.py b/rllib/tuned_examples/marwil/cartpole_marwil.py index 9536dc4b1f897..d31836b93960a 100644 --- a/rllib/tuned_examples/marwil/cartpole_marwil.py +++ b/rllib/tuned_examples/marwil/cartpole_marwil.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path from ray.rllib.algorithms.marwil import MARWILConfig @@ -32,10 +33,7 @@ config = ( MARWILConfig() .environment(env="CartPole-v1") - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) + # Evaluate every 3 training iterations. .evaluation( evaluation_interval=3, evaluation_num_env_runners=1, @@ -52,7 +50,7 @@ # The `kwargs` for the `map_batches` method in which our # `OfflinePreLearner` is run. 2 data workers should be run # concurrently. - map_batches_kwargs={"concurrency": 2, "num_cpus": 2}, + map_batches_kwargs={"concurrency": 2, "num_cpus": 1}, # The `kwargs` for the `iter_batches` method. Due to the small # dataset we choose only a single batch to prefetch. iter_batches_kwargs={"prefetch_batches": 1}, @@ -60,7 +58,7 @@ # mode in a single RLlib training iteration. Leave this to `None` to # run an entire epoch on the dataset during a single RLlib training # iteration. For single-learner mode 1 is the only option. - dataset_num_iters_per_learner=1 if not args.num_learners else None, + dataset_num_iters_per_learner=5, ) .training( beta=1.0, @@ -71,6 +69,14 @@ ) ) +if not args.no_tune: + warnings.warn( + "You are running the example with Ray Tune. Offline RL uses " + "Ray Data, which does not interact seamlessly with Ray Tune. " + "If you encounter difficulties, try to run the example without " + "Ray Tune using `--no-tune`." + ) + stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 250.0, NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000, diff --git a/rllib/utils/actor_manager.py b/rllib/utils/actor_manager.py index 1dc1401fed185..85500a4e483fa 100644 --- a/rllib/utils/actor_manager.py +++ b/rllib/utils/actor_manager.py @@ -515,7 +515,8 @@ def foreach_actor_async( for i in range(len(func)) ] # Update our round-robin pointer. - self._current_actor_id += len(func) % self.num_actors() + self._current_actor_id += len(func) + self._current_actor_id %= self.num_actors() if healthy_only: func, remote_actor_ids = self._filter_func_and_remote_actor_id_by_state( diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 44143c766f6f1..c0b9a28fa4726 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -6,6 +6,7 @@ import tree # pip install dm_tree +import ray from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import ( @@ -76,7 +77,12 @@ def get_device(config: "AlgorithmConfig", num_gpus_requested: int = 1): # the user the option to run on the gpu of their choice, so we enable that # option here through `config.local_gpu_idx`.
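The actor_manager.py change above fixes an operator-precedence bug: in `self._current_actor_id += len(func) % self.num_actors()`, the `%` binds tighter than `+=`, so only `len(func)` is reduced modulo the actor count and the round-robin pointer itself never wraps. Splitting the update into an addition followed by `%=` keeps the pointer in range. A minimal C++ sketch of the same arithmetic, with assumed values:

#include <cassert>

int main() {
  const int num_actors = 4;  // pool size (assumed value)
  const int num_calls = 6;   // calls scheduled this round (assumed value)
  const int pointer = 3;     // current round-robin position

  // Buggy form: '%' applies to num_calls alone, so the pointer can leave [0, num_actors).
  int buggy = pointer;
  buggy += num_calls % num_actors;  // 3 + (6 % 4) = 5, out of range

  // Fixed form: advance first, then wrap the sum.
  int fixed = pointer;
  fixed += num_calls;
  fixed %= num_actors;  // (3 + 6) % 4 = 1, always in [0, num_actors)

  assert(buggy == 5 && fixed == 1);
  return 0;
}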
devices = get_devices() - if len(devices) == 1: + # Note: if we have a single learner and we do not run on Ray Tune, the local + # learner is not a Ray actor and Ray does not manage devices for it. + if ( + len(devices) == 1 + and ray._private.worker._mode() == ray._private.worker.WORKER_MODE + ): return devices[0] else: assert config.local_gpu_idx < torch.cuda.device_count(), ( diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 96fd31b88da22..e7f8d4a425128 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -81,9 +81,10 @@ NUM_EPISODES_LIFETIME = "num_episodes_lifetime" TIME_BETWEEN_SAMPLING = "time_between_sampling" - +DATASET_NUM_ITERS_TRAINED = "dataset_num_iters_trained" +DATASET_NUM_ITERS_TRAINED_LIFETIME = "dataset_num_iters_trained_lifetime" MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED = "mean_num_learner_group_update_called" -MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED = "mean_num_learner_group_results_received" +MEAN_NUM_LEARNER_RESULTS_RECEIVED = "mean_num_learner_results_received" NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained" NUM_AGENT_STEPS_TRAINED_LIFETIME = "num_agent_steps_trained_lifetime" NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter" # @OldAPIStack diff --git a/rllib/utils/metrics/stats.py b/rllib/utils/metrics/stats.py index 9fe47964d6a54..1fd14a7b28349 100644 --- a/rllib/utils/metrics/stats.py +++ b/rllib/utils/metrics/stats.py @@ -682,7 +682,7 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]: if self._reduce_method is None: return values, values - # Special case: Internal values list is empty -> return NaN. + # Special case: Internal values list is empty -> return NaN or 0.0 for sum. elif len(values) == 0: if self._reduce_method in ["min", "max", "mean"]: return float("nan"), [] diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 9f739ee9aa1c8..6da24e42e2728 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -292,6 +292,13 @@ def add_rllib_example_script_args( help="The number of Learners to use. If `None`, use the algorithm's default " "value.", ) + parser.add_argument( + "--num-cpus-per-learner", + type=float, + default=None, + help="The number of CPUs per Learner to use. If `None`, use the algorithm's " + "default value.", + ) parser.add_argument( "--num-gpus-per-learner", type=float, @@ -1148,8 +1155,7 @@ def run_rllib_example_script_experiment( if num_gpus_available >= num_gpus_needed_if_available: config.learners(num_gpus_per_learner=1) else: - config.learners(num_gpus_per_learner=0, num_cpus_per_learner=1) - + config.learners(num_gpus_per_learner=0) # User hard-requires n GPUs, but they are not available -> Error. elif num_gpus_available < num_gpus_requested: raise ValueError( @@ -1164,6 +1170,10 @@ def run_rllib_example_script_experiment( else: config.learners(num_gpus_per_learner=args.num_gpus_per_learner) + # Set CPUs per Learner. + if args.num_cpus_per_learner is not None: + config.learners(num_cpus_per_learner=args.num_cpus_per_learner) + # Old stack (override only if arg was provided by user).
elif args.num_gpus is not None: config.resources(num_gpus=args.num_gpus) diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py index e3783d0583c5f..9ba343ea5c9d0 100644 --- a/rllib/utils/torch_utils.py +++ b/rllib/utils/torch_utils.py @@ -292,7 +292,7 @@ def mapping(item): if pin_memory and torch.cuda.is_available(): tensor.pin_memory() - return tensor if device is None else tensor.to(device) + return tensor if device is None else tensor.to(device, non_blocking=True) return tree.map_structure(mapping, x) diff --git a/setup_hooks.sh b/setup_hooks.sh index 1061f708c9118..0cc004056683a 100755 --- a/setup_hooks.sh +++ b/setup_hooks.sh @@ -12,4 +12,3 @@ RELATIVE_PATH="../../ci/lint" ln -sf "${RELATIVE_PATH}/pre-push" "${ROOT}/.git/hooks/pre-push" ln -sf "${RELATIVE_PATH}/prepare-commit-msg" "${ROOT}/.git/hooks/prepare-commit-msg" - diff --git a/src/ray/common/BUILD b/src/ray/common/BUILD index f5e6b9d8b9cd1..cda0f5ee1d6ee 100644 --- a/src/ray/common/BUILD +++ b/src/ray/common/BUILD @@ -74,15 +74,12 @@ ray_cc_library( ray_cc_library( name = "file_system_monitor", - srcs = [ - "file_system_monitor.cc", - ], - hdrs = [ - "file_system_monitor.h", - ], + srcs = ["file_system_monitor.cc"], + hdrs = ["file_system_monitor.h"], deps = [ ":asio", "//src/ray/util", + "//src/ray/util:event", "@com_google_googletest//:gtest_prod", ], ) @@ -139,6 +136,7 @@ ray_cc_library( "//src/ray/protobuf:common_cc_proto", "//src/ray/protobuf:gcs_cc_proto", "//src/ray/util", + "//src/ray/util:random", "@com_github_google_flatbuffers//:flatbuffers", "@msgpack", ], @@ -211,6 +209,7 @@ ray_cc_library( ":event_stats", ":ray_config", "//src/ray/util", + "//src/ray/util:array", "@boost//:asio", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", @@ -283,7 +282,12 @@ ray_cc_library( hdrs = ["status.h"], deps = [ ":source_location", - "//src/ray/util", + "//src/ray/util:logging", + "//src/ray/util:macros", + "//src/ray/util:visibility", + "@boost//:system", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", ], ) diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h index 0ce0829d0af73..a4607e60ce292 100644 --- a/src/ray/common/buffer.h +++ b/src/ray/common/buffer.h @@ -21,6 +21,7 @@ #include "ray/common/status.h" #include "ray/thirdparty/aligned_alloc.h" +#include "ray/util/logging.h" #define BUFFER_ALIGNMENT 64 diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 0542186d2388a..46f46bf99f35a 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -170,4 +170,4 @@ static inline ray::rpc::ActorTableData::ActorState StringToActorState( RAY_CHECK(false) << "Invalid actor state name:" << actor_state_name; return {}; } -} \ No newline at end of file +} diff --git a/src/ray/common/status.h b/src/ray/common/status.h index e8206ba90aaf6..c7bd759352b37 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -32,20 +32,15 @@ #include #include +#include "absl/strings/str_cat.h" #include "ray/common/source_location.h" #include "ray/util/logging.h" #include "ray/util/macros.h" #include "ray/util/visibility.h" -namespace boost { - -namespace system { - +namespace boost::system { class error_code; - -} // namespace system - -} // namespace boost +} // namespace boost::system // Return the given status if it is not OK. #define RAY_RETURN_NOT_OK(s) \ @@ -328,6 +323,12 @@ class RAY_EXPORT Status { std::string message() const { return ok() ? 
"" : state_->msg; } + template + Status &operator<<(T &&...msg) { + absl::StrAppend(&state_->msg, std::forward(msg)...); + return *this; + } + private: struct State { StatusCode code; diff --git a/src/ray/common/test/BUILD b/src/ray/common/test/BUILD index 75f15f741c6ba..08bf2a2695238 100644 --- a/src/ray/common/test/BUILD +++ b/src/ray/common/test/BUILD @@ -160,6 +160,7 @@ ray_cc_test( ray_cc_library( name = "testing", hdrs = ["testing.h"], + deps = ["//src/ray/util:macros"], testonly = True, ) diff --git a/src/ray/core_worker/transport/sequential_actor_submit_queue.cc b/src/ray/core_worker/transport/sequential_actor_submit_queue.cc index 35de4e41bb99f..3a396841e5fad 100644 --- a/src/ray/core_worker/transport/sequential_actor_submit_queue.cc +++ b/src/ray/core_worker/transport/sequential_actor_submit_queue.cc @@ -128,4 +128,4 @@ void SequentialActorSubmitQueue::MarkSeqnoCompleted(uint64_t sequence_no, } } // namespace core -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/src/ray/core_worker/transport/sequential_actor_submit_queue.h b/src/ray/core_worker/transport/sequential_actor_submit_queue.h index 3e5e42a8119a4..75cbf49ce2096 100644 --- a/src/ray/core_worker/transport/sequential_actor_submit_queue.h +++ b/src/ray/core_worker/transport/sequential_actor_submit_queue.h @@ -138,4 +138,4 @@ class SequentialActorSubmitQueue : public IActorSubmitQueue { std::map out_of_order_completed_tasks; }; } // namespace core -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/src/ray/gcs/gcs_client/accessor.cc b/src/ray/gcs/gcs_client/accessor.cc index 204552c1f30d8..b2caf9a755dec 100644 --- a/src/ray/gcs/gcs_client/accessor.cc +++ b/src/ray/gcs/gcs_client/accessor.cc @@ -1494,6 +1494,26 @@ Status AutoscalerStateAccessor::GetClusterStatus(int64_t timeout_ms, return Status::OK(); } +Status AutoscalerStateAccessor::AsyncGetClusterStatus( + int64_t timeout_ms, + const OptionalItemCallback &callback) { + rpc::autoscaler::GetClusterStatusRequest request; + rpc::autoscaler::GetClusterStatusRequest reply; + + client_impl_->GetGcsRpcClient().GetClusterStatus( + request, + [callback](const Status &status, rpc::autoscaler::GetClusterStatusReply &&reply) { + if (!status.ok()) { + callback(status, std::nullopt); + return; + } + callback(Status::OK(), std::move(reply)); + }, + timeout_ms); + + return Status::OK(); +} + Status AutoscalerStateAccessor::ReportAutoscalingState( int64_t timeout_ms, const std::string &serialized_state) { rpc::autoscaler::ReportAutoscalingStateRequest request; diff --git a/src/ray/gcs/gcs_client/accessor.h b/src/ray/gcs/gcs_client/accessor.h index dce21418099b1..2b48f45fcd2d2 100644 --- a/src/ray/gcs/gcs_client/accessor.h +++ b/src/ray/gcs/gcs_client/accessor.h @@ -22,6 +22,7 @@ #include "ray/gcs/callback.h" #include "ray/rpc/client_call.h" #include "ray/util/sequencer.h" +#include "src/ray/protobuf/autoscaler.pb.h" #include "src/ray/protobuf/gcs.pb.h" #include "src/ray/protobuf/gcs_service.pb.h" @@ -991,6 +992,10 @@ class AutoscalerStateAccessor { virtual Status GetClusterStatus(int64_t timeout_ms, std::string &serialized_reply); + virtual Status AsyncGetClusterStatus( + int64_t timeout_ms, + const OptionalItemCallback &callback); + virtual Status ReportAutoscalingState(int64_t timeout_ms, const std::string &serialized_state); diff --git a/src/ray/gcs/gcs_server/pubsub_handler.cc b/src/ray/gcs/gcs_server/pubsub_handler.cc index be224cd78299d..45fe0e0e46af2 100644 --- a/src/ray/gcs/gcs_server/pubsub_handler.cc +++ 
b/src/ray/gcs/gcs_server/pubsub_handler.cc @@ -25,8 +25,8 @@ void InternalPubSubHandler::HandleGcsPublish(rpc::GcsPublishRequest request, rpc::GcsPublishReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_LOG(DEBUG) << "received publish request: " << request.DebugString(); - for (const auto &msg : request.pub_messages()) { - gcs_publisher_.GetPublisher().Publish(msg); + for (auto &&msg : std::move(*request.mutable_pub_messages())) { + gcs_publisher_.GetPublisher().Publish(std::move(msg)); } send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/object_manager/plasma/test/stats_collector_test.cc b/src/ray/object_manager/plasma/test/stats_collector_test.cc index 03991705e8b87..46106107da304 100644 --- a/src/ray/object_manager/plasma/test/stats_collector_test.cc +++ b/src/ray/object_manager/plasma/test/stats_collector_test.cc @@ -300,4 +300,4 @@ TEST_F(ObjectStatsCollectorTest, RefCountPassThrough) { manager_->DeleteObject(id2); ExpectStatsMatch(); } -} // namespace plasma \ No newline at end of file +} // namespace plasma diff --git a/src/ray/ray_exported_symbols.lds b/src/ray/ray_exported_symbols.lds index bfbc2197795aa..9d57bf03a6fef 100644 --- a/src/ray/ray_exported_symbols.lds +++ b/src/ray/ray_exported_symbols.lds @@ -34,4 +34,4 @@ *Java_io_ray* _JNI_On* *aligned_free* -*aligned_malloc* \ No newline at end of file +*aligned_malloc* diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index de4aa7e047ee1..117d0a506615f 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -177,7 +177,12 @@ void WorkerPool::Start() { } if (RayConfig::instance().enable_worker_prestart()) { - PrestartDefaultCpuWorkers(Language::PYTHON, num_prestart_python_workers); + rpc::TaskSpec rpc_task_spec; + rpc_task_spec.set_language(Language::PYTHON); + rpc_task_spec.mutable_runtime_env_info()->set_serialized_runtime_env("{}"); + + TaskSpecification task_spec{std::move(rpc_task_spec)}; + PrestartWorkersInternal(task_spec, num_prestart_python_workers); } } @@ -898,7 +903,12 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr &driver if (!first_job_registered_ && RayConfig::instance().prestart_worker_first_driver() && !RayConfig::instance().enable_worker_prestart()) { RAY_LOG(DEBUG) << "PrestartDefaultCpuWorkers " << num_prestart_python_workers; - PrestartDefaultCpuWorkers(Language::PYTHON, num_prestart_python_workers); + rpc::TaskSpec rpc_task_spec; + rpc_task_spec.set_language(Language::PYTHON); + rpc_task_spec.mutable_runtime_env_info()->set_serialized_runtime_env("{}"); + + TaskSpecification task_spec{std::move(rpc_task_spec)}; + PrestartWorkersInternal(task_spec, num_prestart_python_workers); } // Invoke the `send_reply_callback` later to only finish driver @@ -1448,10 +1458,8 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, << task_spec.DebugString() << " has runtime env " << task_spec.HasRuntimeEnv(); if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) || - task_spec.HasRuntimeEnv() || task_spec.GetLanguage() != ray::Language::PYTHON) { + task_spec.GetLanguage() != ray::Language::PYTHON) { return; // Not handled. - // TODO(architkulkarni): We'd eventually like to prestart workers with the same - // runtime env to improve initial startup performance. 
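The worker_pool.cc hunk above removes the early-out for tasks that carry a runtime env, addressing the deleted TODO: PrestartWorkersInternal starts workers directly when the serialized runtime env is empty, and otherwise runs the async runtime-env setup first and starts the worker from its completion callback. A simplified, self-contained C++ sketch of that control flow; the stub names here are illustrative stand-ins, not Ray's actual APIs:

#include <functional>
#include <iostream>
#include <string>

// Hypothetical stand-in for the raylet's runtime-env setup callback type.
using RuntimeEnvCallback = std::function<void(bool ok, const std::string &env_context)>;

// Pretend the async setup always succeeds and hands back an empty context.
void GetOrCreateRuntimeEnvStub(const std::string &serialized_env, RuntimeEnvCallback cb) {
  cb(/*ok=*/true, /*env_context=*/"{}");
}

void StartWorkerProcessStub(const std::string &env_context) {
  std::cout << "starting worker with env context " << env_context << "\n";
}

// Mirrors the shape of PrestartWorkersInternal: empty envs skip the async
// round trip; non-empty envs start the worker inside the setup callback.
void PrestartSketch(const std::string &serialized_env, int num_needed) {
  for (int i = 0; i < num_needed; ++i) {
    if (serialized_env.empty() || serialized_env == "{}") {
      StartWorkerProcessStub("{}");
      continue;
    }
    GetOrCreateRuntimeEnvStub(serialized_env, [](bool ok, const std::string &ctx) {
      if (!ok) return;  // the real code logs the setup error here
      StartWorkerProcessStub(ctx);
    });
  }
}

int main() { PrestartSketch(R"({"env_vars":{"FOO":"bar"}})", 2); }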
} auto &state = GetStateForLanguage(task_spec.GetLanguage()); @@ -1470,21 +1478,45 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, << backlog_size << " and available CPUs " << num_available_cpus << " num idle workers " << state.idle.size() << " num registered workers " << state.registered_workers.size(); - PrestartDefaultCpuWorkers(task_spec.GetLanguage(), num_needed); + PrestartWorkersInternal(task_spec, num_needed); } } -void WorkerPool::PrestartDefaultCpuWorkers(ray::Language language, int64_t num_needed) { - // default workers don't use runtime env. - RAY_LOG(DEBUG) << "PrestartDefaultCpuWorkers " << num_needed; - for (int i = 0; i < num_needed; i++) { - PopWorkerStatus status; - StartWorkerProcess(language, - rpc::WorkerType::WORKER, - JobID::Nil(), - &status, - /*dynamic_options*/ {}, - CalculateRuntimeEnvHash("{}")); +void WorkerPool::PrestartWorkersInternal(const TaskSpecification &task_spec, + int64_t num_needed) { + RAY_LOG(DEBUG) << "PrestartWorkers " << num_needed; + for (int ii = 0; ii < num_needed; ++ii) { + // Prestart worker with no runtime env. + if (IsRuntimeEnvEmpty(task_spec.SerializedRuntimeEnv())) { + PopWorkerStatus status; + StartWorkerProcess( + task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status); + continue; + } + + // Prestart worker with runtime env. + GetOrCreateRuntimeEnv( + task_spec.SerializedRuntimeEnv(), + task_spec.RuntimeEnvConfig(), + task_spec.JobId(), + [this, task_spec = task_spec](bool successful, + const std::string &serialized_runtime_env_context, + const std::string &setup_error_message) { + if (!successful) { + RAY_LOG(ERROR) << "Fails to create or get runtime env " + << setup_error_message; + return; + } + PopWorkerStatus status; + StartWorkerProcess(task_spec.GetLanguage(), + rpc::WorkerType::WORKER, + task_spec.JobId(), + &status, + /*dynamic_options=*/{}, + task_spec.GetRuntimeEnvHash(), + serialized_runtime_env_context, + task_spec.RuntimeEnvInfo()); + }); } } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 1c114ebca8da1..cd3af644d5122 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -426,9 +426,7 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// We aim to prestart 1 worker per CPU, up to the the backlog size. void PrestartWorkers(const TaskSpecification &task_spec, int64_t backlog_size); - /// Try to prestart a number of CPU workers with the given language. - /// - void PrestartDefaultCpuWorkers(ray::Language language, int64_t num_needed); + void PrestartWorkersInternal(const TaskSpecification &task_spec, int64_t num_needed); /// Return the current size of the worker pool for the requested language. Counts only /// idle workers. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 73447daae2e7f..29d6d378fcc88 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -651,6 +651,24 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestPrestartingWorkers) { ASSERT_EQ(worker_pool_->NumWorkersStarting(), POOL_SIZE_SOFT_LIMIT); } +TEST_F(WorkerPoolDriverRegisteredTest, TestPrestartingWorkersWithRuntimeEnv) { + auto task_spec = ExampleTaskSpec(); + task_spec.GetMutableMessage().mutable_runtime_env_info()->set_serialized_runtime_env( + "{\"env_vars\": {\"FOO\": \"bar\"}}"); + // Prestarts 2 workers. + worker_pool_->PrestartWorkers(task_spec, 2); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 2); + // Prestarts 1 more worker. 
+ worker_pool_->PrestartWorkers(task_spec, 3); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 3); + // No more needed. + worker_pool_->PrestartWorkers(task_spec, 1); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 3); + // Capped by soft limit. + worker_pool_->PrestartWorkers(task_spec, 20); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), POOL_SIZE_SOFT_LIMIT); +} + TEST_F(WorkerPoolDriverRegisteredTest, HandleWorkerPushPop) { std::shared_ptr popped_worker; const auto task_spec = ExampleTaskSpec(); @@ -2124,8 +2142,7 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestIOWorkerFailureAndSpawn) { TEST_F(WorkerPoolDriverRegisteredTest, WorkerReuseForPrestartedWorker) { const auto task_spec = ExampleTaskSpec(); - - worker_pool_->PrestartDefaultCpuWorkers(ray::Language::PYTHON, 1); + worker_pool_->PrestartWorkersInternal(task_spec, /*num_needed=*/1); worker_pool_->PushWorkers(0, task_spec.JobId()); // One worker process has been prestarted. ASSERT_EQ(worker_pool_->GetProcessSize(), 1); diff --git a/src/ray/rpc/test/grpc_bench/Dockerfile b/src/ray/rpc/test/grpc_bench/Dockerfile index fea0434857d40..5e6d45f70812c 100644 --- a/src/ray/rpc/test/grpc_bench/Dockerfile +++ b/src/ray/rpc/test/grpc_bench/Dockerfile @@ -4,4 +4,3 @@ RUN apt-get update -y && apt-get install -y libjemalloc-dev COPY grpc_bench / ENTRYPOINT LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 /grpc_bench - diff --git a/src/ray/thirdparty/.clang-format b/src/ray/thirdparty/.clang-format index 37f3d57668fb9..e3845288a2aec 100644 --- a/src/ray/thirdparty/.clang-format +++ b/src/ray/thirdparty/.clang-format @@ -1 +1 @@ -DisableFormat: true \ No newline at end of file +DisableFormat: true diff --git a/src/ray/util/BUILD b/src/ray/util/BUILD index 2afc20d3fa84b..f0dedfb839f9f 100644 --- a/src/ray/util/BUILD +++ b/src/ray/util/BUILD @@ -40,26 +40,30 @@ ray_cc_library( # TODO(hjiang): filesystem and logging has interdependency, we should split them into three targets: filesystem, logging, ray_check_macros. 
ray_cc_library( name = "logging", - hdrs = [ - "filesystem.h", - "logging.h", - ], - srcs = [ - "filesystem.cc", - "logging.cc", - ], + hdrs = ["logging.h"], + srcs = ["logging.cc"], deps = [ ":event_label", ":macros", + ":string_utils", ":thread_utils", "@com_github_spdlog//:spdlog", "@com_google_absl//absl/debugging:failure_signal_handler", "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest_prod", "@nlohmann_json", ], ) +ray_cc_library( + name = "filesystem", + hdrs = ["filesystem.h"], + srcs = ["filesystem.cc"], + deps = [ + "//src/ray/common:status_or", + ], +) + ray_cc_library( name = "container_util", hdrs = ["container_util.h"], @@ -82,6 +86,7 @@ ray_cc_library( ], deps = [ ":cmd_line_utils", + ":filesystem", ":logging", ":macros", "@boost//:asio", @@ -196,27 +201,13 @@ ray_cc_library( hdrs = ["util.h"], srcs = ["util.cc"], deps = [ - ":array", ":cmd_line_utils", ":container_util", - ":counter_map", - ":event", - ":event_label", - ":exponential_backoff", ":function_traits", ":logging", ":macros", - ":memory", ":process", - ":random", - ":sample", - ":sequencer", ":string_utils", - ":timestamp_utils", - ":throttler", - ":thread_utils", - ":type_traits", - ":visibility", "//:sha256", ], ) @@ -267,6 +258,7 @@ ray_cc_library( ":stream_redirection_options", ":thread_utils", ":util", + "@boost//:iostreams", "@com_github_spdlog//:spdlog", "@com_google_absl//absl/strings", ], @@ -282,3 +274,23 @@ ray_cc_library( ":util", ], ) + +ray_cc_library( + name = "spdlog_fd_sink", + hdrs = ["spdlog_fd_sink.h"], + deps = [ + ":compat", + ":util", + "@com_github_spdlog//:spdlog", + ], +) + +ray_cc_library( + name = "temporary_directory", + hdrs = ["temporary_directory.h"], + srcs = ["temporary_directory.cc"], + deps = [ + ":util", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/src/ray/util/filesystem.cc b/src/ray/util/filesystem.cc index 9d4e2fd163d90..d8e4f3f135d3c 100644 --- a/src/ray/util/filesystem.cc +++ b/src/ray/util/filesystem.cc @@ -17,8 +17,6 @@ #include #include -#include "ray/util/logging.h" - #ifdef _WIN32 #include #endif @@ -57,17 +55,20 @@ std::string GetUserTempDir() { while (!result.empty() && IsDirSep(result.back())) { result.pop_back(); } - RAY_CHECK(!result.empty()); return result; } -std::string CompleteReadFile(const std::string &fname) { +StatusOr ReadEntireFile(const std::string &fname) { std::ifstream file(fname); - RAY_CHECK(file.good()) << "Fails to open file " << fname; + if (!file.good()) { + return Status::IOError("") << "Fails to open file " << fname; + } std::ostringstream buffer; buffer << file.rdbuf(); - RAY_CHECK(file.good()) << "Fails to read from file " << fname; + if (!file.good()) { + return Status::IOError("") << "Fails to read from file " << fname; + } std::string content = buffer.str(); file.close(); diff --git a/src/ray/util/filesystem.h b/src/ray/util/filesystem.h index beed7387b102a..bc1b8a0430b0c 100644 --- a/src/ray/util/filesystem.h +++ b/src/ray/util/filesystem.h @@ -19,6 +19,8 @@ #include #include +#include "ray/common/status_or.h" + // Filesystem and path manipulation APIs. // (NTFS stream & attribute paths are not supported.) @@ -42,29 +44,8 @@ static inline bool IsDirSep(char ch) { return result; } -/// \return The result of joining multiple path components. 
-template -std::string JoinPaths(std::string base, const Paths &...components) { - auto join = [](auto &joined_path, const auto &component) { - // if the components begin with "/" or "////", just get the path name. - if (!component.empty() && - component.front() == std::filesystem::path::preferred_separator) { - joined_path = std::filesystem::path(joined_path) - .append(std::filesystem::path(component).filename().string()) - .string(); - } else { - joined_path = std::filesystem::path(joined_path).append(component).string(); - } - }; - (join(base, std::string_view(components)), ...); - return base; -} - // Read the whole content for the given [fname], and return as string. -// If any error happens, error message will be logged and the program will exit -// immediately. -// -// TODO(hjiang): Use StatusOr as return type in the followup PR. -std::string CompleteReadFile(const std::string &fname); +// Return IO error status if open and read operation fail. +StatusOr ReadEntireFile(const std::string &fname); } // namespace ray diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 96835c5912a55..f2460b5e371f0 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -41,7 +41,7 @@ #include "absl/strings/str_format.h" #include "nlohmann/json.hpp" #include "ray/util/event_label.h" -#include "ray/util/filesystem.h" +#include "ray/util/string_utils.h" #include "ray/util/thread_utils.h" #include "spdlog/sinks/basic_file_sink.h" #include "spdlog/sinks/rotating_file_sink.h" @@ -367,7 +367,7 @@ void RayLog::InitLogFormat() { app_name_without_path = "DefaultApp"; } else { // Find the app name without the path. - std::string app_file_name = ray::GetFileName(app_name); + std::string app_file_name = std::filesystem::path(app_name).filename().string(); if (!app_file_name.empty()) { app_name_without_path = app_file_name; } diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index 8b6401f99be2a..5da393f4304e2 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -61,6 +61,7 @@ #include #include "ray/util/macros.h" +#include "ray/util/string_utils.h" #if defined(_WIN32) #ifndef _WINDOWS_ diff --git a/src/ray/util/pipe_logger.cc b/src/ray/util/pipe_logger.cc index f66a35bfd9549..97536a047d6be 100644 --- a/src/ray/util/pipe_logger.cc +++ b/src/ray/util/pipe_logger.cc @@ -33,118 +33,56 @@ namespace ray { namespace { -// Default pipe log read buffer size. -constexpr size_t kDefaultPipeLogReadBufSize = 1024; - -size_t GetPipeLogReadSizeOrDefault() { - // TODO(hjiang): Write a util function `GetEnvOrDefault`. - const char *var_value = std::getenv(kPipeLogReadBufSizeEnv.data()); - if (var_value != nullptr) { - size_t read_buf_size = 0; - if (absl::SimpleAtoi(var_value, &read_buf_size) && read_buf_size > 0) { - return read_buf_size; - } - } - return kDefaultPipeLogReadBufSize; -} - struct StreamDumper { absl::Mutex mu; bool stopped ABSL_GUARDED_BY(mu) = false; std::deque content ABSL_GUARDED_BY(mu); }; -// File descriptors which indicates standard stream. -#if defined(__APPLE__) || defined(__linux__) -struct StdStreamFd { - int stdout_fd = STDOUT_FILENO; - int stderr_fd = STDERR_FILENO; -}; -#elif defined(_WIN32) -// TODO(hjiang): not used for windows, implement later. -struct StdStreamFd { - int stdout_fd = -1; - int stderr_fd = -1; +// Used to write to dup-ed stdout and stderr; use shared pointer to make it copy +// constructible. 
+struct StdOstream { + std::shared_ptr<boost::iostreams::stream<boost::iostreams::file_descriptor_sink>> + stdout_ostream; + std::shared_ptr<boost::iostreams::stream<boost::iostreams::file_descriptor_sink>> + stderr_ostream; }; -#endif -// Read bytes from handle into [data], return number of bytes read. -// If read fails, throw exception. -#if defined(__APPLE__) || defined(__linux__) -size_t Read(int read_fd, char *data, size_t len) { - // TODO(hjiang): Notice frequent read could cause performance issue. - ssize_t bytes_read = read(read_fd, data, len); - // TODO(hjiang): Add macros which checks for syscalls. - RAY_CHECK(bytes_read != -1) << "Fails to read from pipe because " << strerror(errno); - return bytes_read; -} -#endif - -template <typename ReadFunc, typename WriteFunc, typename FlushFunc> -void StartStreamDump(ReadFunc read_func, - WriteFunc write_func, - FlushFunc flush_func, - std::function<void()> close_read_handle, - std::function<void()> on_close_completion) { +// Start two threads: +// 1. A reader thread which continuously reads from [pipe_stream] until close; +// 2. A dumper thread which writes content to sink via [write_func]. +template <typename WriteFunc, typename FlushFunc> +void StartStreamDump( + std::shared_ptr<boost::iostreams::stream<boost::iostreams::file_descriptor_source>> + pipe_instream, + WriteFunc write_func, + FlushFunc flush_func, + std::function<void()> on_close_completion) { auto stream_dumper = std::make_shared<StreamDumper>(); // Create two threads, so there's no IO operation within critical section thus no // blocking on write. - std::thread([read_func = std::move(read_func), - close_read_handle = std::move(close_read_handle), + std::thread([pipe_instream = std::move(pipe_instream), stream_dumper = stream_dumper]() { SetThreadName("PipeReaderThd"); - const size_t buf_size = GetPipeLogReadSizeOrDefault(); - // TODO(hjiang): Should resize without initialization. - std::string content(buf_size, '\0'); - // Logging are written in lines, `last_line` records part of the strings left in - // last `read` syscall. - std::string last_line; - - while (true) { - size_t bytes_read = read_func(content.data(), content.length()); - - // Bytes read of size 0 indicates write-side of pipe has been closed. - if (bytes_read == 0) { - { - absl::MutexLock lock(&stream_dumper->mu); - stream_dumper->stopped = true; - } - - // Place IO operation out of critical section. - close_read_handle(); - - return; - } + std::string newline; - std::string_view cur_content{content.data(), bytes_read}; - std::vector<std::string_view> newlines = absl::StrSplit(cur_content, '\n'); - - for (size_t idx = 0; idx < newlines.size() - 1; ++idx) { - std::string cur_new_line = std::move(last_line); - cur_new_line += newlines[idx]; - last_line.clear(); - - // We only log non-empty lines. - // - // TODO(hjiang): Newliners should also appear in the stdout/stderr/log, current - // behavior simply ignore everything. - if (!cur_new_line.empty()) { - absl::MutexLock lock(&stream_dumper->mu); - stream_dumper->content.emplace_back(std::move(cur_new_line)); - } + // Exit at pipe read EOF. + while (std::getline(*pipe_instream, newline)) { + // Backfill the newline for the current segment. + if (!pipe_instream->eof()) { + newline += '\n'; + } - // Special handle the last segment we've read. - // - // Nothing to do if we've read a complete newline. - if (content.back() == '\n') { - continue; - } + absl::MutexLock lock(&stream_dumper->mu); + stream_dumper->content.emplace_back(std::move(newline)); + } - // Otherwise record the newline so we could reuse in the next read iteration. - last_line += newlines.back(); + RAY_CHECK(pipe_instream->eof()); + { + absl::MutexLock lock(&stream_dumper->mu); + stream_dumper->stopped = true; + } }).detach(); @@ -177,7 +115,7 @@ void StartStreamDump(ReadFunc read_func, } // Perform IO operation out of critical section.
- write_func(curline); + write_func(std::move(curline)); } }).detach(); } @@ -215,57 +153,31 @@ bool ShouldUsePipeStream(const StreamRedirectionOption &stream_redirect_opt) { stream_redirect_opt.tee_to_stderr; } -#if defined(__APPLE__) || defined(__linux__) RedirectionFileHandle OpenFileForRedirection(const std::string &file_path) { - int fd = open(file_path.data(), O_WRONLY | O_CREAT, 0644); - RAY_CHECK_NE(fd, -1) << "Fails to open file " << file_path << " with failure reason " - << strerror(errno); - - auto flush_fn = [fd]() { - RAY_CHECK_EQ(fsync(fd), 0) << "Fails to flush data to disk because " - << strerror(errno); - }; - auto close_fn = [fd]() { - RAY_CHECK_EQ(fsync(fd), 0) << "Fails to flush data to disk because " - << strerror(errno); - RAY_CHECK_EQ(close(fd), 0) << "Fails to close redirection file because " - << strerror(errno); - }; - - return RedirectionFileHandle{fd, std::move(flush_fn), std::move(close_fn)}; -} + boost::iostreams::file_descriptor_sink sink{file_path, std::ios_base::out}; + auto handle = sink.handle(); + auto ostream = + std::make_shared>( + std::move(sink)); + auto flush_fn = [ostream, handle]() { + // Flush stream internal buffer to fd. + ostream->flush(); +// Flush file handle. +#if defined(__APPLE__) || defined(__linux__) + RAY_CHECK_EQ(fdatasync(handle), 0); #elif defined(_WIN32) -#include -RedirectionFileHandle OpenFileForRedirection(const std::string &file_path) { - HANDLE file_handle = CreateFile(file_path.c_str(), - GENERIC_WRITE, - 0, // No sharing - NULL, // Default security attributes - CREATE_ALWAYS, // Always create a new file - FILE_ATTRIBUTE_NORMAL, // Normal file attributes - NULL // No template file - ); - RAY_CHECK(file_handle != INVALID_HANDLE_VALUE) - << "Fails to open file " << file_path << " with error " - << std::to_string(GetLastError()); - - auto flush_fn = [file_handle]() { - RAY_CHECK(FlushFileBuffers(file_handle)) - << "Failed to flush data to disk with error: " << std::to_string(GetLastError()); + RAY_CHECK(FlushFileBuffers(handle)); +#endif }; - auto close_fn = [file_handle]() { - RAY_CHECK(FlushFileBuffers(file_handle)) - << "Failed to flush data to disk with error: " << std::to_string(GetLastError()); - RAY_CHECK(CloseHandle(file_handle)) - << "Failed to close file with error: " << std::to_string(GetLastError()); + auto close_fn = [flush_fn, ostream]() { + flush_fn(); + ostream->close(); }; - return RedirectionFileHandle{file_handle, std::move(flush_fn), std::move(close_fn)}; + return RedirectionFileHandle{ + handle, std::move(ostream), std::move(flush_fn), std::move(close_fn)}; } -#endif - } // namespace -#if defined(__APPLE__) || defined(__linux__) RedirectionFileHandle CreateRedirectionFileHandle( const StreamRedirectionOption &stream_redirect_opt) { // Case-1: only redirection, but not rotation and tee involved. @@ -283,85 +195,149 @@ RedirectionFileHandle CreateRedirectionFileHandle( // Invoked after flush and close finished. 
auto on_close_completion = [promise = promise]() { promise->set_value(); }; - StdStreamFd std_stream_fd{}; + StdOstream std_ostream{}; + +#if defined(__APPLE__) || defined(__linux__) if (stream_redirect_opt.tee_to_stdout) { - std_stream_fd.stdout_fd = dup(STDOUT_FILENO); - RAY_CHECK_NE(std_stream_fd.stdout_fd, -1) - << "Fails to duplicate stdout: " << strerror(errno); + int duped_stdout_fd = dup(STDOUT_FILENO); + RAY_CHECK_NE(duped_stdout_fd, -1) << "Fails to duplicate stdout: " << strerror(errno); + + boost::iostreams::file_descriptor_sink sink{ + duped_stdout_fd, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + std_ostream.stdout_ostream = std::make_shared< + boost::iostreams::stream>( + std::move(sink)); } if (stream_redirect_opt.tee_to_stderr) { - std_stream_fd.stderr_fd = dup(STDERR_FILENO); - RAY_CHECK_NE(std_stream_fd.stderr_fd, -1) - << "Fails to duplicate stderr: " << strerror(errno); + int duped_stderr_fd = dup(STDERR_FILENO); + RAY_CHECK_NE(duped_stderr_fd, -1) << "Fails to duplicate stderr: " << strerror(errno); + + boost::iostreams::file_descriptor_sink sink{ + duped_stderr_fd, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + std_ostream.stderr_ostream = std::make_shared< + boost::iostreams::stream>( + std::move(sink)); } - // TODO(hjiang): Use `boost::iostreams` to represent pipe write fd, which supports - // cross-platform and line-wise read. int pipefd[2] = {0}; - // TODO(hjiang): We shoud have our own syscall macro. RAY_CHECK_EQ(pipe(pipefd), 0); - int read_fd = pipefd[0]; - int write_fd = pipefd[1]; + int read_handle = pipefd[0]; + int write_handle = pipefd[1]; + boost::iostreams::file_descriptor_source pipe_read_source{ + read_handle, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + boost::iostreams::file_descriptor_sink pipe_write_sink{ + write_handle, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + +#elif defined(_WIN32) + if (stream_redirect_opt.tee_to_stdout) { + HANDLE duped_stdout_handle; + BOOL result = DuplicateHandle(GetCurrentProcess(), + GetStdHandle(STD_OUTPUT_HANDLE), + GetCurrentProcess(), + &duped_stdout_handle, + 0, + FALSE, + DUPLICATE_SAME_ACCESS); + RAY_CHECK(result) << "Fails to duplicate stdout handle"; + + boost::iostreams::file_descriptor_sink sink{ + duped_stdout_handle, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + std_ostream.stdout_ostream = std::make_shared< + boost::iostreams::stream>( + std::move(sink)); + } + if (stream_redirect_opt.tee_to_stderr) { + HANDLE duped_stderr_handle; + BOOL result = DuplicateHandle(GetCurrentProcess(), + GetStdHandle(STD_ERROR_HANDLE), + GetCurrentProcess(), + &duped_stderr_handle, + 0, + FALSE, + DUPLICATE_SAME_ACCESS); + RAY_CHECK(result) << "Fails to duplicate stderr handle"; + + boost::iostreams::file_descriptor_sink sink{ + duped_stderr_handle, /*file_descriptor_flags=*/boost::iostreams::close_handle}; + std_ostream.stderr_ostream = std::make_shared< + boost::iostreams::stream>( + std::move(sink)); + } + + HANDLE read_handle = nullptr; + HANDLE write_handle = nullptr; + SECURITY_ATTRIBUTES sa = {sizeof(SECURITY_ATTRIBUTES), nullptr, TRUE}; + RAY_CHECK(CreatePipe(&read_handle, &write_handle, &sa, 0)) << "Fails to create pipe"; + boost::iostreams::file_descriptor_source pipe_read_source{ + read_handle, + /*file_descriptor_flags=*/boost::iostreams::close_handle}; + boost::iostreams::file_descriptor_sink pipe_write_sink{ + write_handle, + /*file_descriptor_flags=*/boost::iostreams::close_handle}; + +#endif + + auto pipe_instream = std::make_shared< 
+ boost::iostreams::stream>( + std::move(pipe_read_source)); + auto pipe_ostream = + std::make_shared<boost::iostreams::stream<boost::iostreams::file_descriptor_sink>>( + std::move(pipe_write_sink)); - auto read_func = [read_fd](char *data, size_t len) { return Read(read_fd, data, len); }; - auto close_read_handle = [read_fd]() { RAY_CHECK_EQ(close(read_fd), 0); }; - auto close_fn = [write_fd, promise]() { - RAY_CHECK_EQ(close(write_fd), 0); + auto close_fn = [pipe_ostream, promise]() mutable { + pipe_ostream->close(); // Block until destruction finishes. promise->get_future().get(); }; auto logger = CreateLogger(stream_redirect_opt); - // [content] doesn't have trailing newliner. + // [content] is exactly what the application writes to the pipe, including the trailing + // newline, if any. auto write_fn = [logger, stream_redirect_opt = stream_redirect_opt, - std_stream_fd = std_stream_fd](const std::string &content) { - if (logger != nullptr) { - logger->log(spdlog::level::info, content); - } + std_ostream = std_ostream](std::string content) { if (stream_redirect_opt.tee_to_stdout) { - RAY_CHECK_EQ(write(std_stream_fd.stdout_fd, content.data(), content.length()), - static_cast(content.length())); - RAY_CHECK_EQ(write(std_stream_fd.stdout_fd, "\n", 1), 1); + std_ostream.stdout_ostream->write(content.data(), content.length()); + RAY_CHECK(std_ostream.stdout_ostream->good()); } if (stream_redirect_opt.tee_to_stderr) { - RAY_CHECK_EQ(write(std_stream_fd.stderr_fd, content.data(), content.length()), - static_cast(content.length())); - RAY_CHECK_EQ(write(std_stream_fd.stderr_fd, "\n", 1), 1); + std_ostream.stderr_ostream->write(content.data(), content.length()); + RAY_CHECK(std_ostream.stderr_ostream->good()); } - }; - auto flush_fn = [logger, - stream_redirect_opt = stream_redirect_opt, - std_stream_fd = std_stream_fd]() { if (logger != nullptr) { - logger->flush(); - } - if (stream_redirect_opt.tee_to_stdout) { - fsync(std_stream_fd.stdout_fd); + // spdlog adds a newline for every entry, so there is no need to maintain the + // application-passed one. + if (!content.empty() && content.back() == '\n') { + content.pop_back(); + } + logger->log(spdlog::level::info, content); } - // No need to sync for stderr since it's unbuffered. }; + auto flush_fn = + [logger, stream_redirect_opt = stream_redirect_opt, std_ostream = std_ostream]() { + if (logger != nullptr) { + logger->flush(); + } + if (stream_redirect_opt.tee_to_stdout) { + std_ostream.stdout_ostream->flush(); + RAY_CHECK(std_ostream.stdout_ostream->good()); + } + if (stream_redirect_opt.tee_to_stderr) { + std_ostream.stderr_ostream->flush(); + RAY_CHECK(std_ostream.stderr_ostream->good()); + } + }; - StartStreamDump(std::move(read_func), + StartStreamDump(std::move(pipe_instream), std::move(write_fn), flush_fn, - std::move(close_read_handle), std::move(on_close_completion)); RedirectionFileHandle redirection_file_handle{ - write_fd, std::move(flush_fn), std::move(close_fn)}; + write_handle, std::move(pipe_ostream), std::move(flush_fn), std::move(close_fn)}; return redirection_file_handle; } -#elif defined(_WIN32) -RedirectionFileHandle CreateRedirectionFileHandle( - const StreamRedirectionOption &stream_redirect_opt) { - // TODO(hjiang): For windows, we currently doesn't support redirection with rotation and - tee to stdout/stderr.
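The rewritten reader above relies on a small invariant: std::getline strips the trailing newline, and the loop re-appends it for every segment except a final unterminated one (detected via eof()), so the dumper thread receives the stream byte-for-byte as the application wrote it. A self-contained C++ sketch of that round-trip, assuming an in-memory stream in place of the pipe:

#include <cassert>
#include <sstream>
#include <string>

int main() {
  const std::string input = "first\nsecond\ntail-without-newline";
  std::istringstream in(input);
  std::string line, reassembled;
  while (std::getline(in, line)) {
    // Backfill the newline getline stripped, unless this was the final
    // unterminated segment (getline hit EOF instead of a delimiter).
    if (!in.eof()) line += '\n';
    reassembled += line;
  }
  assert(reassembled == input);  // the original stream is reproduced exactly
  return 0;
}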
- return OpenFileForRedirection(stream_redirect_opt.file_path); -} -#endif - } // namespace ray diff --git a/src/ray/util/pipe_logger.h b/src/ray/util/pipe_logger.h index 054e4789d1c7c..bfee9f120f3c6 100644 --- a/src/ray/util/pipe_logger.h +++ b/src/ray/util/pipe_logger.h @@ -17,6 +17,8 @@ #pragma once +#include +#include #include #include #include @@ -43,16 +45,22 @@ namespace ray { inline constexpr std::string_view kPipeLogReadBufSizeEnv = "RAY_PIPE_LOG_READ_BUF_SIZE"; // File handle requires active destruction via owner calling [Close]. +// +// TODO(hjiang): Wrap fd with spdlog sink to manage stream flush and close. class RedirectionFileHandle { public: RedirectionFileHandle() = default; // @param termination_synchronizer is used to block wait until destruction operation // finishes. - RedirectionFileHandle(MEMFD_TYPE_NON_UNIQUE write_handle, - std::function flush_fn, - std::function close_fn) + RedirectionFileHandle( + MEMFD_TYPE_NON_UNIQUE write_handle, + std::shared_ptr> + pipe_ostream, + std::function flush_fn, + std::function close_fn) : write_handle_(write_handle), + pipe_ostream_(std::move(pipe_ostream)), flush_fn_(std::move(flush_fn)), close_fn_(std::move(close_fn)) { RAY_CHECK(flush_fn_); @@ -65,6 +73,7 @@ class RedirectionFileHandle { RedirectionFileHandle(RedirectionFileHandle &&rhs) { write_handle_ = rhs.write_handle_; rhs.write_handle_ = INVALID_FD; + pipe_ostream_ = std::move(rhs.pipe_ostream_); flush_fn_ = std::move(rhs.flush_fn_); close_fn_ = std::move(rhs.close_fn_); } @@ -74,6 +83,7 @@ class RedirectionFileHandle { } write_handle_ = rhs.write_handle_; rhs.write_handle_ = INVALID_FD; + pipe_ostream_ = std::move(rhs.pipe_ostream_); flush_fn_ = std::move(rhs.flush_fn_); close_fn_ = std::move(rhs.close_fn_); return *this; @@ -82,6 +92,10 @@ class RedirectionFileHandle { if (write_handle_ != INVALID_FD) { close_fn_(); write_handle_ = INVALID_FD; + + // Unset flush and close functor to close logger and underlying file handle. + flush_fn_ = nullptr; + close_fn_ = nullptr; } } @@ -97,25 +111,16 @@ class RedirectionFileHandle { MEMFD_TYPE_NON_UNIQUE GetWriteHandle() const { return write_handle_; } - // Used to write to. - // - // TODO(hjiang): I will followup with another PR to make a `FD` class, which is not - // copiable to avoid manual `dup`. -#if defined(__APPLE__) || defined(__linux__) - void CompleteWrite(const char *data, size_t len) { - ssize_t bytes_written = write(write_handle_, data, len); - RAY_CHECK_EQ(bytes_written, static_cast(len)); - } -#elif defined(_WIN32) - void CompleteWrite(const char *data, size_t len) { - DWORD bytes_written = 0; - WriteFile(write_handle_, data, len, &bytes_written, nullptr); - } -#endif + // Write the given data into redirection handle; currently only for testing usage. + void CompleteWrite(const char *data, size_t len) { pipe_ostream_->write(data, len); } private: MEMFD_TYPE_NON_UNIQUE write_handle_; + // A high-level wrapper for [write_handle_]. + std::shared_ptr> + pipe_ostream_; + // Used to flush log message. std::function flush_fn_; diff --git a/src/ray/util/spdlog_fd_sink.h b/src/ray/util/spdlog_fd_sink.h new file mode 100644 index 0000000000000..f2130b5ee41ee --- /dev/null +++ b/src/ray/util/spdlog_fd_sink.h @@ -0,0 +1,70 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "ray/util/compat.h" +#include "ray/util/util.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#elif defined(_WIN32) +#include +#endif + +namespace ray { + +// A sink which logs to the file descriptor. +template <typename Mutex> +class non_owned_fd_sink final : public spdlog::sinks::base_sink<Mutex> { + public: + // [fd] is not owned by [non_owned_fd_sink], which means the file descriptor should be closed by + // caller. + explicit non_owned_fd_sink(MEMFD_TYPE_NON_UNIQUE fd) : fd_(fd) {} + protected: + void sink_it_(const spdlog::details::log_msg &msg) override { + spdlog::memory_buf_t formatted; + spdlog::sinks::base_sink<Mutex>::formatter_->format(msg, formatted); + +#if defined(__APPLE__) || defined(__linux__) + RAY_CHECK_EQ(write(fd_, formatted.data(), formatted.size()), + static_cast<ssize_t>(formatted.size())) + << "Fails to write because " << strerror(errno); +#elif defined(_WIN32) + DWORD bytes_written; + BOOL success = + WriteFile(fd_, formatted.data(), (DWORD)formatted.size(), &bytes_written, NULL); + RAY_CHECK(success); + RAY_CHECK_EQ((DWORD)formatted.size(), bytes_written); +#endif + } + void flush_() override { +#if defined(__APPLE__) || defined(__linux__) + RAY_CHECK_EQ(fdatasync(fd_), 0) << "Fails to flush file because " << strerror(errno); +#elif defined(_WIN32) + RAY_CHECK(FlushFileBuffers(fd_)); +#endif + } + + private: + MEMFD_TYPE_NON_UNIQUE fd_; +}; + +using non_owned_fd_sink_mt = non_owned_fd_sink<std::mutex>; +using non_owned_fd_sink_st = non_owned_fd_sink<spdlog::details::null_mutex>; + +} // namespace ray diff --git a/src/ray/util/string_utils.h b/src/ray/util/string_utils.h index e66b117397b82..6d8d2121b10aa 100644 --- a/src/ray/util/string_utils.h +++ b/src/ray/util/string_utils.h @@ -14,6 +14,7 @@ #pragma once +#include #include namespace ray { @@ -27,4 +28,22 @@ std::string StringToHex(const std::string &str); /// \return The scanned prefix of the string, if any. std::string ScanToken(std::string::const_iterator &c_str, std::string format); +/// \return The result of joining multiple path components. +template <typename... Paths> +std::string JoinPaths(std::string base, const Paths &...components) { + auto join = [](auto &joined_path, const auto &component) { + // if the components begin with "/" or "////", just get the path name. + if (!component.empty() && + component.front() == std::filesystem::path::preferred_separator) { + joined_path = std::filesystem::path(joined_path) + .append(std::filesystem::path(component).filename().string()) + .string(); + } else { + joined_path = std::filesystem::path(joined_path).append(component).string(); + } + }; + (join(base, std::string_view(components)), ...); + return base; +} + } // namespace ray diff --git a/src/ray/util/temporary_directory.cc b/src/ray/util/temporary_directory.cc new file mode 100644 index 0000000000000..1cbe3ae6ed871 --- /dev/null +++ b/src/ray/util/temporary_directory.cc @@ -0,0 +1,34 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
diff --git a/src/ray/util/temporary_directory.cc b/src/ray/util/temporary_directory.cc
new file mode 100644
index 0000000000000..1cbe3ae6ed871
--- /dev/null
+++ b/src/ray/util/temporary_directory.cc
@@ -0,0 +1,34 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/util/temporary_directory.h"
+
+#include <filesystem>
+
+#include "ray/util/util.h"
+
+namespace ray {
+
+ScopedTemporaryDirectory::ScopedTemporaryDirectory(const std::string &dir) {
+  temporary_directory_ =
+      dir.empty() ? std::filesystem::temp_directory_path() : std::filesystem::path{dir};
+  // Manually generate a unique directory name by appending a UUID.
+  temporary_directory_ = temporary_directory_ / GenerateUUIDV4();
+  RAY_CHECK(std::filesystem::create_directory(temporary_directory_));
+}
+ScopedTemporaryDirectory::~ScopedTemporaryDirectory() {
+  RAY_CHECK(std::filesystem::remove_all(temporary_directory_));
+}
+
+}  // namespace ray
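Typical usage of the new helper looks like the following hedged sketch (file and directory names are illustrative; the UUID suffix comes from `GenerateUUIDV4`):

```cpp
#include <filesystem>
#include <fstream>
#include <iostream>

#include "ray/util/temporary_directory.h"

int main() {
  std::filesystem::path dir_path;
  {
    ray::ScopedTemporaryDirectory tmp_dir;  // e.g. /tmp/<uuid> on Linux.
    dir_path = tmp_dir.GetDirectory();
    std::ofstream{dir_path / "scratch.txt"} << "scratch data";
  }  // Destructor recursively removes the directory and its contents.
  std::cout << std::boolalpha << std::filesystem::exists(dir_path) << "\n";  // false
  return 0;
}
```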
diff --git a/src/ray/util/temporary_directory.h b/src/ray/util/temporary_directory.h
new file mode 100644
index 0000000000000..7e63e71e48164
--- /dev/null
+++ b/src/ray/util/temporary_directory.h
@@ -0,0 +1,42 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This class creates a temporary directory, which gets deleted (including all files
+// inside of it) at its destruction.
+
+#pragma once
+
+#include <filesystem>
+#include <string>
+
+namespace ray {
+
+// A scoped temporary directory, which deletes all files and sub-directories recursively
+// at its destruction.
+class ScopedTemporaryDirectory {
+ public:
+  // Create a sub-directory under the given [dir].
+  // If unspecified, the new directory is created under the system temporary directory.
+  //
+  // If creation or deletion fails, the program exits after logging an error message.
+  ScopedTemporaryDirectory(const std::string &dir = "");
+  ~ScopedTemporaryDirectory();
+
+  const std::filesystem::path &GetDirectory() const { return temporary_directory_; }
+
+ private:
+  std::filesystem::path temporary_directory_;
+};
+
+}  // namespace ray
diff --git a/src/ray/util/tests/BUILD b/src/ray/util/tests/BUILD
index 7159a0f61a69f..6c76d6be8fc8e 100644
--- a/src/ray/util/tests/BUILD
+++ b/src/ray/util/tests/BUILD
@@ -6,6 +6,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//src/ray/util",
+        "//src/ray/util:array",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -53,6 +54,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//src/ray/util",
+        "//src/ray/util:counter_map",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -67,11 +69,12 @@ ray_cc_test(
         "team:core",
     ],
     deps = [
+        "//src/ray/common:ray_config",
+        "//src/ray/protobuf:gcs_cc_proto",
         "//src/ray/util",
+        "//src/ray/util:event",
         "@boost//:range",
         "@com_google_googletest//:gtest_main",
-        "//src/ray/protobuf:gcs_cc_proto",
-        "//src/ray/common:ray_config",
     ],
 )
@@ -82,6 +85,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//src/ray/util",
+        "//src/ray/util:exponential_backoff",
        "@com_google_googletest//:gtest_main",
     ],
 )
@@ -127,6 +131,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//:ray_common",
+        "//src/ray/util:sample",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -138,6 +143,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//src/ray/util",
+        "//src/ray/util:sequencer",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -160,6 +166,7 @@ ray_cc_test(
     tags = ["team:core"],
     deps = [
         "//src/ray/util",
+        "//src/ray/util:throttler",
         "@com_google_absl//absl/time",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -217,6 +224,8 @@ ray_cc_test(
     name = "pipe_logger_test",
     srcs = ["pipe_logger_test.cc"],
     deps = [
+        "//src/ray/util:temporary_directory",
+        "//src/ray/common/test:testing",
         "//src/ray/util",
         "//src/ray/util:pipe_logger",
         "@com_google_googletest//:gtest_main",
@@ -230,12 +239,20 @@ ray_cc_test(
     name = "stream_redirection_utils_test",
     srcs = ["stream_redirection_utils_test.cc"],
     deps = [
+        "//src/ray/common/test:testing",
         "//src/ray/util",
         "//src/ray/util:stream_redirection_utils",
         "@com_google_googletest//:gtest_main",
     ],
     size = "small",
-    tags = ["team:core"],
+    tags = [
+        "team:core",
+        # TSAN fails to understand the synchronization logic: the stack trace shows the
+        # ostream being flushed concurrently from the pipe dumper thread and the main
+        # thread, even though those flushes are properly ordered. Disable the whole test
+        # suite here, since it always contains exactly one test case.
+ "no_tsan", + ], ) ray_cc_test( @@ -248,3 +265,25 @@ ray_cc_test( size = "small", tags = ["team:core"], ) + +ray_cc_test( + name = "spdlog_fd_sink_test", + srcs = ["spdlog_fd_sink_test.cc"], + deps = [ + "//src/ray/util:spdlog_fd_sink", + "@com_google_googletest//:gtest_main", + ], + size = "small", + tags = ["team:core"], +) + +ray_cc_test( + name = "temporary_directory_test", + srcs = ["temporary_directory_test.cc"], + deps = [ + "//src/ray/util:temporary_directory", + "@com_google_googletest//:gtest_main", + ], + size = "small", + tags = ["team:core"], +) diff --git a/src/ray/util/tests/pipe_logger_test.cc b/src/ray/util/tests/pipe_logger_test.cc index cfa3f149d19d3..a717b49482ce1 100644 --- a/src/ray/util/tests/pipe_logger_test.cc +++ b/src/ray/util/tests/pipe_logger_test.cc @@ -23,13 +23,11 @@ #include #include "absl/cleanup/cleanup.h" +#include "ray/common/test/testing.h" #include "ray/util/filesystem.h" +#include "ray/util/temporary_directory.h" #include "ray/util/util.h" -///////////////////////////////////////////////// -// Unit test for both windows and unix platform. -///////////////////////////////////////////////// - namespace ray { namespace { @@ -37,15 +35,9 @@ namespace { constexpr std::string_view kLogLine1 = "hello\n"; constexpr std::string_view kLogLine2 = "world\n"; -class PipeLoggerTest : public ::testing::TestWithParam {}; - -TEST_P(PipeLoggerTest, NoPipeWrite) { - const size_t pipe_buffer_size = GetParam(); - setEnv(kPipeLogReadBufSizeEnv.data(), absl::StrFormat("%d", pipe_buffer_size)); - - // TODO(hjiang): We should have a better test util, which allows us to create a - // temporary testing directory. - const std::string test_file_path = absl::StrFormat("%s.out", GenerateUUIDV4()); +TEST(PipeLoggerTest, RedirectionTest) { + ScopedTemporaryDirectory scoped_directory; + const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4(); // Delete temporary file. absl::Cleanup cleanup_test_file = [&test_file_path]() { @@ -54,74 +46,22 @@ TEST_P(PipeLoggerTest, NoPipeWrite) { // Take the default option, which doesn't have rotation enabled. StreamRedirectionOption stream_redirection_opt{}; - stream_redirection_opt.file_path = test_file_path; + stream_redirection_opt.file_path = test_file_path.native(); auto stream_redirection_handle = CreateRedirectionFileHandle(stream_redirection_opt); stream_redirection_handle.CompleteWrite(kLogLine1.data(), kLogLine1.length()); stream_redirection_handle.CompleteWrite(kLogLine2.data(), kLogLine2.length()); stream_redirection_handle.Close(); // Check log content after completion. - const auto actual_content = CompleteReadFile(test_file_path); + const auto actual_content = ReadEntireFile(test_file_path.native()); + RAY_ASSERT_OK(actual_content); const std::string expected_content = absl::StrFormat("%s%s", kLogLine1, kLogLine2); - EXPECT_EQ(actual_content, expected_content); -} - -INSTANTIATE_TEST_SUITE_P(PipeLoggerTest, PipeLoggerTest, testing::Values(1024, 3)); - -} // namespace - -} // namespace ray - -///////////////////////////////////////////////// -// Unit test for both unix platform only. -///////////////////////////////////////////////// - -#if defined(__APPLE__) || defined(__linux__) - -#include - -namespace ray { - -namespace { - -TEST_P(PipeLoggerTest, PipeWrite) { - const size_t pipe_buffer_size = GetParam(); - setEnv(kPipeLogReadBufSizeEnv.data(), absl::StrFormat("%d", pipe_buffer_size)); - - // TODO(hjiang): We should have a better test util, which allows us to create a - // temporary testing directory. 
- const std::string test_file_path = absl::StrFormat("%s.out", GenerateUUIDV4()); - const std::string log_file_path1 = test_file_path; - const std::string log_file_path2 = absl::StrFormat("%s.1", test_file_path); - - // Delete temporary file. - absl::Cleanup cleanup_test_file = [&log_file_path1, &log_file_path2]() { - EXPECT_TRUE(std::filesystem::remove(log_file_path1)); - EXPECT_TRUE(std::filesystem::remove(log_file_path2)); - }; - - StreamRedirectionOption stream_redirection_opt{}; - stream_redirection_opt.file_path = test_file_path; - stream_redirection_opt.rotation_max_size = 5; - stream_redirection_opt.rotation_max_file_count = 2; - - auto stream_redirection_handle = CreateRedirectionFileHandle(stream_redirection_opt); - stream_redirection_handle.CompleteWrite(kLogLine1.data(), kLogLine1.length()); - stream_redirection_handle.CompleteWrite(kLogLine2.data(), kLogLine2.length()); - // Write empty line, which is not expected to appear. - stream_redirection_handle.CompleteWrite("\n", /*count=*/1); - // Synchronize on log flush completion. - stream_redirection_handle.Close(); - - // Check log content after completion. - EXPECT_EQ(CompleteReadFile(log_file_path1), kLogLine2); - EXPECT_EQ(CompleteReadFile(log_file_path2), kLogLine1); + EXPECT_EQ(*actual_content, expected_content); } TEST(PipeLoggerTestWithTee, RedirectionWithTee) { - // TODO(hjiang): We should have a better test util, which allows us to create a - // temporary testing directory. - const std::string test_file_path = absl::StrFormat("%s.out", GenerateUUIDV4()); + ScopedTemporaryDirectory scoped_directory; + const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4(); // Delete temporary file. absl::Cleanup cleanup_test_file = [&test_file_path]() { @@ -129,7 +69,7 @@ TEST(PipeLoggerTestWithTee, RedirectionWithTee) { }; StreamRedirectionOption stream_redirection_opt{}; - stream_redirection_opt.file_path = test_file_path; + stream_redirection_opt.file_path = test_file_path.native(); stream_redirection_opt.tee_to_stdout = true; // Capture stdout via `dup`. @@ -145,16 +85,18 @@ TEST(PipeLoggerTestWithTee, RedirectionWithTee) { EXPECT_EQ(stdout_content, absl::StrFormat("%s%s", kLogLine1, kLogLine2)); // Check log content after completion. - EXPECT_EQ(CompleteReadFile(test_file_path), - absl::StrFormat("%s%s", kLogLine1, kLogLine2)); + const auto actual_content = ReadEntireFile(test_file_path.native()); + RAY_ASSERT_OK(actual_content); + EXPECT_EQ(*actual_content, absl::StrFormat("%s%s", kLogLine1, kLogLine2)); } TEST(PipeLoggerTestWithTee, RotatedRedirectionWithTee) { - // TODO(hjiang): We should have a better test util, which allows us to create a - // temporary testing directory. - const std::string test_file_path = absl::StrFormat("%s.out", GenerateUUIDV4()); - const std::string log_file_path1 = test_file_path; - const std::string log_file_path2 = absl::StrFormat("%s.1", test_file_path); + ScopedTemporaryDirectory scoped_directory; + const auto uuid = GenerateUUIDV4(); + const auto test_file_path = scoped_directory.GetDirectory() / uuid; + const auto log_file_path1 = test_file_path; + const auto log_file_path2 = + scoped_directory.GetDirectory() / absl::StrFormat("%s.1", uuid); // Delete temporary file. 
 absl::Cleanup cleanup_test_file = [&log_file_path1, &log_file_path2]() {
@@ -163,7 +105,7 @@ TEST(PipeLoggerTestWithTee, RotatedRedirectionWithTee) {
   };
 
   StreamRedirectionOption stream_redirection_opt{};
-  stream_redirection_opt.file_path = test_file_path;
+  stream_redirection_opt.file_path = test_file_path.native();
   stream_redirection_opt.rotation_max_size = 5;
   stream_redirection_opt.rotation_max_file_count = 2;
   stream_redirection_opt.tee_to_stderr = true;
@@ -181,12 +123,198 @@
   EXPECT_EQ(stderr_content, absl::StrFormat("%s%s", kLogLine1, kLogLine2));
 
   // Check log content after completion.
-  EXPECT_EQ(CompleteReadFile(test_file_path), kLogLine2);
-  EXPECT_EQ(CompleteReadFile(log_file_path2), kLogLine1);
+  const auto actual_content1 = ReadEntireFile(log_file_path1.native());
+  RAY_ASSERT_OK(actual_content1);
+  EXPECT_EQ(*actual_content1, kLogLine2);
+
+  const auto actual_content2 = ReadEntireFile(log_file_path2.native());
+  RAY_ASSERT_OK(actual_content2);
+  EXPECT_EQ(*actual_content2, kLogLine1);
+}
+
+// Testing scenario: log to both stdout and a file; check whether the two sinks generate
+// the expected output.
+TEST(PipeLoggerCompatTest, CompatibilityTest) {
+  // Testing-1: No newline in the middle or at the end.
+  {
+    constexpr std::string_view kContent = "hello";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    // The pipe logger automatically appends a newline at the end.
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_ASSERT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-2: Newline at the end.
+  {
+    constexpr std::string_view kContent = "hello\n";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_ASSERT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-3: Newline in the middle.
+  {
+    constexpr std::string_view kContent = "hello\nworld";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    // The pipe logger automatically appends a newline at the end.
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_EXPECT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\nworld\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-4: Newline in the middle and at the end.
+  {
+    constexpr std::string_view kContent = "hello\nworld\n";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_EXPECT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\nworld\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-5: Continuous newlines at the end.
+  {
+    constexpr std::string_view kContent = "helloworld\n\n\n";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_EXPECT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "helloworld\n\n\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-6: Continuous newlines in the middle.
+  {
+    constexpr std::string_view kContent = "hello\n\n\nworld";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    // The pipe logger automatically appends a newline at the end.
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_EXPECT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\n\n\nworld\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+
+  // Testing-7: Continuous newlines in the middle and at the end.
+  {
+    constexpr std::string_view kContent = "hello\n\nworld\n\n";
+    ScopedTemporaryDirectory scoped_directory;
+    const auto test_file_path = scoped_directory.GetDirectory() / GenerateUUIDV4();
+
+    StreamRedirectionOption logging_option{};
+    logging_option.file_path = test_file_path;
+    logging_option.tee_to_stdout = true;
+
+    testing::internal::CaptureStdout();
+    auto stream_redirection_handle = CreateRedirectionFileHandle(logging_option);
+    stream_redirection_handle.CompleteWrite(kContent.data(), kContent.length());
+    stream_redirection_handle.Close();
+
+    const std::string stdout_content = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(stdout_content, kContent);
+
+    // The content already ends with a newline, so nothing extra is appended.
+    const auto actual_content = ReadEntireFile(test_file_path.native());
+    RAY_ASSERT_OK(actual_content);
+    EXPECT_EQ(*actual_content, "hello\n\nworld\n\n");
+
+    EXPECT_TRUE(std::filesystem::remove(test_file_path));
+  }
+}
 
 }  // namespace
 
 }  // namespace ray
-
-#endif
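The rotation tests above rely on spdlog's rotating-sink naming scheme, where the active file keeps the base name and older content is shifted to `<base>.1`, `<base>.2`, and so on. A hedged sketch of that scheme with a plain spdlog rotating logger (assuming the pipe logger maps `rotation_max_size`/`rotation_max_file_count` onto these parameters, which the `.1` suffix in the tests suggests; the exact rotation point depends on the formatted record size):

```cpp
#include <spdlog/sinks/rotating_file_sink.h>
#include <spdlog/spdlog.h>

int main() {
  constexpr std::size_t kMaxBytes = 5;  // Analogous to rotation_max_size.
  constexpr std::size_t kMaxFiles = 2;  // Analogous to rotation_max_file_count.
  // Records go to /tmp/app.log; once it exceeds kMaxBytes, the next write
  // rotates it to /tmp/app.log.1 and starts a fresh /tmp/app.log.
  auto logger =
      spdlog::rotating_logger_mt("rotation_demo", "/tmp/app.log", kMaxBytes, kMaxFiles);
  logger->info("hello");
  logger->info("world");
  logger->flush();
  return 0;
}
```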
diff --git a/src/ray/util/tests/spdlog_fd_sink_test.cc b/src/ray/util/tests/spdlog_fd_sink_test.cc
new file mode 100644
index 0000000000000..5d31439a409e8
--- /dev/null
+++ b/src/ray/util/tests/spdlog_fd_sink_test.cc
@@ -0,0 +1,57 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/util/spdlog_fd_sink.h"
+
+#include <gtest/gtest.h>
+
+namespace ray {
+
+namespace {
+
+#if defined(__APPLE__) || defined(__linux__)
+int GetStdoutHandle() { return STDOUT_FILENO; }
+#elif defined(_WIN32)
+HANDLE GetStdoutHandle() { return GetStdHandle(STD_OUTPUT_HANDLE); }
+#endif
+
+// Logs "helloworld" regardless of the given message; we don't care what message is
+// logged here, only whether the message is written to the given file descriptor
+// correctly.
+class HelloworldFormatter : public spdlog::formatter {
+ public:
+  void format(const spdlog::details::log_msg &msg, spdlog::memory_buf_t &dest) override {
+    dest.append(std::string{"helloworld"});
+  }
+  std::unique_ptr<spdlog::formatter> clone() const override {
+    return std::make_unique<HelloworldFormatter>();
+  }
+};
+
+TEST(SpdlogFdSinkTest, SinkWithFd) {
+  non_owned_fd_sink_st sink{GetStdoutHandle()};
+  sink.set_formatter(std::make_unique<HelloworldFormatter>());
+  spdlog::details::log_msg msg_to_log{
+      /*logger_name=*/"logger_name", spdlog::level::level_enum::info, /*msg=*/"content"};
+
+  testing::internal::CaptureStdout();
+  sink.log(msg_to_log);
+  const std::string stdout_content = testing::internal::GetCapturedStdout();
+
+  EXPECT_EQ(stdout_content, "helloworld");
+}
+
+}  // namespace
+
+}  // namespace ray
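Beyond the formatter-level test above, the sink plugs into a regular `spdlog::logger` like any other sink. A hedged usage sketch (the log path is illustrative, and POSIX `open`/`close` are assumed):

```cpp
#include <fcntl.h>
#include <unistd.h>

#include <memory>
#include <spdlog/spdlog.h>

#include "ray/util/spdlog_fd_sink.h"

int main() {
  // The sink never owns the descriptor, so the caller opens and closes it.
  int fd = open("/tmp/fd_sink_demo.log", O_WRONLY | O_CREAT | O_TRUNC, 0644);
  auto sink = std::make_shared<ray::non_owned_fd_sink_st>(fd);
  spdlog::logger logger{"fd_logger", sink};
  logger.info("routed through the fd sink");
  logger.flush();  // Invokes flush_(), i.e. fdatasync(2) on POSIX.
  close(fd);       // Close after flushing; the sink leaves fd open.
  return 0;
}
```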
diff --git a/src/ray/util/tests/stream_redirection_utils_test.cc b/src/ray/util/tests/stream_redirection_utils_test.cc
index f6995e68fde93..d528dd3c1892f 100644
--- a/src/ray/util/tests/stream_redirection_utils_test.cc
+++ b/src/ray/util/tests/stream_redirection_utils_test.cc
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#if defined(__APPLE__) || defined(__linux__)
-
 #include "ray/util/stream_redirection_utils.h"
 
 #include <gtest/gtest.h>
@@ -22,6 +20,7 @@
 #include <filesystem>
 #include <string_view>
 
+#include "ray/common/test/testing.h"
 #include "ray/util/filesystem.h"
 #include "ray/util/util.h"
 
@@ -30,15 +29,33 @@ namespace ray {
 namespace {
 constexpr std::string_view kLogLine1 = "hello\n";
 constexpr std::string_view kLogLine2 = "world\n";
+
+// Output log files to clean up at process termination.
+std::vector<std::string> log_files;
+void CleanupOutputLogFiles() {
+  for (const auto &cur_log : log_files) {
+    EXPECT_TRUE(std::filesystem::remove(cur_log));
+  }
+}
+
 }  // namespace
 
 TEST(LoggingUtilTest, RedirectStderr) {
+  const std::string test_file_path = absl::StrFormat("%s.err", GenerateUUIDV4());
+  const std::string log_file_path1 = test_file_path;
+  const std::string log_file_path2 = absl::StrFormat("%s.1", test_file_path);
+  log_files.emplace_back(log_file_path1);
+  log_files.emplace_back(log_file_path2);
+
+  // Clean up the generated log files at process termination. Loggers are closed by an
+  // exit hook, and hooked functions run in reverse order of registration, so the cleanup
+  // hook must be registered before the logger close hook.
+  ASSERT_EQ(std::atexit(CleanupOutputLogFiles), 0);
+
   // Works via `dup`, so have to execute before we redirect via `dup2` and close stderr.
   testing::internal::CaptureStderr();
 
   // Redirect stderr for testing, so we could have stdout for debugging.
-  const std::string test_file_path = absl::StrFormat("%s.err", GenerateUUIDV4());
-
   StreamRedirectionOption opts;
   opts.file_path = test_file_path;
   opts.tee_to_stderr = true;
@@ -55,23 +72,19 @@
   FlushOnRedirectedStderr();
 
   // Check log content after completion.
-  const std::string log_file_path1 = test_file_path;
-  EXPECT_EQ(CompleteReadFile(test_file_path), kLogLine2);
+  const auto actual_content1 = ReadEntireFile(log_file_path1);
+  RAY_ASSERT_OK(actual_content1);
+  EXPECT_EQ(*actual_content1, kLogLine2);
 
-  const std::string log_file_path2 = absl::StrFormat("%s.1", test_file_path);
-  EXPECT_EQ(CompleteReadFile(log_file_path2), kLogLine1);
+  const auto actual_content2 = ReadEntireFile(log_file_path2);
+  RAY_ASSERT_OK(actual_content2);
+  EXPECT_EQ(*actual_content2, kLogLine1);
 
   // Check tee-ed to stderr content.
   std::string stderr_content = testing::internal::GetCapturedStderr();
   EXPECT_EQ(stderr_content, absl::StrFormat("%s%s", kLogLine1, kLogLine2));
 
-  // Delete temporary file.
-  EXPECT_EQ(unlink(log_file_path1.data()), 0);
-  EXPECT_EQ(unlink(log_file_path2.data()), 0);
-
   // Make sure the flush hook works fine and the process terminates with no problem.
 }
 
 }  // namespace ray
-
-#endif
diff --git a/src/ray/util/tests/temporary_directory_test.cc b/src/ray/util/tests/temporary_directory_test.cc
new file mode 100644
index 0000000000000..457f86c7cce35
--- /dev/null
+++ b/src/ray/util/tests/temporary_directory_test.cc
@@ -0,0 +1,50 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/util/temporary_directory.h"
+
+#include <gtest/gtest.h>
+
+#include <filesystem>
+#include <fstream>
+
+#include "absl/strings/str_format.h"
+
+namespace ray {
+
+TEST(TemporaryDirectoryTest, CreationAndDestruction) {
+  std::filesystem::path temp_directory;
+  {
+    ScopedTemporaryDirectory dir{};
+    temp_directory = dir.GetDirectory();
+
+    // Create a file under the temporary directory.
+    std::filesystem::path empty_file = temp_directory / "empty_file";
+    std::ofstream(empty_file).close();
+    ASSERT_TRUE(std::filesystem::exists(empty_file));
+
+    // Create a sub-directory under the temporary directory.
+    std::filesystem::path internal_dir = temp_directory / "dir";
+    ASSERT_TRUE(std::filesystem::create_directory(internal_dir));
+    ASSERT_TRUE(std::filesystem::exists(internal_dir));
+
+    // Create a file under the internal directory.
+    std::filesystem::path internal_file = internal_dir / "empty_file";
+    std::ofstream(internal_file).close();
+    ASSERT_TRUE(std::filesystem::exists(internal_file));
+  }
+  ASSERT_FALSE(std::filesystem::exists(temp_directory));
+}
+
+}  // namespace ray
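Finally, the `std::atexit` registration order in the stream-redirection test is load-bearing: handlers run in reverse order of registration, so registering the file-cleanup hook before the logger close hook guarantees the logger flushes and closes first. A minimal standalone illustration:

```cpp
#include <cstdio>
#include <cstdlib>

void CleanupFiles() { std::puts("second: remove log files"); }
void CloseLogger() { std::puts("first: flush and close logger"); }

int main() {
  // atexit handlers run LIFO: CleanupFiles is registered first, so it runs
  // last, after CloseLogger has flushed and closed the logger.
  std::atexit(CleanupFiles);
  std::atexit(CloseLogger);
  return 0;
}
```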