Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,12 @@ def dedicated_endpoint_enabled(self) -> bool:
return getattr(self._gca_resource, "dedicated_endpoint_enabled", False)
return False

@property
def enable_private_model_server(self) -> bool:
    """Whether this Endpoint's model server is isolated from external traffic.

    Reads the ``private_model_server_enabled`` field off the backing GAPIC
    resource; returns False when the field is absent (e.g. when running
    against an older API surface that predates the flag).
    """
    resource = self._gca_resource
    return getattr(resource, "private_model_server_enabled", False)

@classmethod
def create(
cls,
Expand All @@ -834,6 +840,7 @@ def create(
request_response_logging_bq_destination_table: Optional[str] = None,
dedicated_endpoint_enabled=False,
inference_timeout: Optional[int] = None,
enable_private_model_server=False,
) -> "Endpoint":
"""Creates a new endpoint.

Expand Down Expand Up @@ -907,7 +914,11 @@ def create(
latency will be reduced.
inference_timeout (int):
Optional. It defines the prediction timeout, in seconds, for online predictions using cloud-based endpoints. This applies to either PSC endpoints, when private_service_connect_config is set, or dedicated endpoints, when dedicated_endpoint_enabled is true.

enable_private_model_server (bool):
Optional. If enabled, a private model server will be created and
the model server will be isolated from the external traffic.
By default, set to false, which means the model server will have
access to the external traffic.
Returns:
endpoint (aiplatform.Endpoint):
Created endpoint.
Expand Down Expand Up @@ -964,6 +975,7 @@ def create(
predict_request_response_logging_config=predict_request_response_logging_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
client_connection_config=client_connection_config,
enable_private_model_server=enable_private_model_server,
)

@classmethod
Expand Down Expand Up @@ -993,6 +1005,7 @@ def _create(
client_connection_config: Optional[
gca_endpoint_compat.ClientConnectionConfig
] = None,
enable_private_model_server: bool = False,
) -> "Endpoint":
"""Creates a new endpoint by calling the API client.

Expand Down Expand Up @@ -1065,7 +1078,11 @@ def _create(
latency will be reduced.
client_connection_config (aiplatform.endpoint.ClientConnectionConfig):
Optional. The inference timeout which is applied on cloud-based (PSC, or dedicated) endpoints for online prediction.

enable_private_model_server (bool):
Optional. If enabled, a private model server will be created and
the model server will be isolated from the external traffic.
By default, set to false, which means the model server will have
access to the external traffic.
Returns:
endpoint (aiplatform.Endpoint):
Created endpoint.
Expand All @@ -1085,6 +1102,7 @@ def _create(
private_service_connect_config=private_service_connect_config,
dedicated_endpoint_enabled=dedicated_endpoint_enabled,
client_connection_config=client_connection_config,
private_model_server_enabled=enable_private_model_server,
)

operation_future = api_client.create_endpoint(
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/aiplatform_v1/types/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class Endpoint(proto.Message):
native RAG integration can be configured.
Currently, only Model Garden models are
supported.
private_model_server_enabled (bool):
If true, the model server will be isolated from the external
network. By default, set to false.
"""

name: str = proto.Field(
Expand Down Expand Up @@ -274,6 +277,10 @@ class Endpoint(proto.Message):
number=29,
message="GenAiAdvancedFeaturesConfig",
)
private_model_server_enabled: bool = proto.Field(
proto.BOOL,
number=30,
)


class DeployedModel(proto.Message):
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/aiplatform_v1beta1/types/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ class Endpoint(proto.Message):
native RAG integration can be configured.
Currently, only Model Garden models are
supported.
private_model_server_enabled (bool):
If true, the model server will be isolated from the external
network. By default, set to false.
"""

name: str = proto.Field(
Expand Down Expand Up @@ -279,6 +282,10 @@ class Endpoint(proto.Message):
number=29,
message="GenAiAdvancedFeaturesConfig",
)
private_model_server_enabled: bool = proto.Field(
proto.BOOL,
number=30,
)


class DeployedModel(proto.Message):
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,18 @@ def get_endpoint_alt_location_mock():
)
yield get_endpoint_mock

@pytest.fixture
def get_endpoint_with_private_model_server_enabled_mock():
    """Patch ``EndpointServiceClient.get_endpoint`` to return an endpoint
    whose ``private_model_server_enabled`` flag is set."""
    # NOTE(review): the name reuses _TEST_ENDPOINT_NAME_ALT_LOCATION, which
    # looks copied from the alt-location fixture — confirm this is intended.
    canned_endpoint = gca_endpoint.Endpoint(
        display_name=_TEST_DISPLAY_NAME,
        name=_TEST_ENDPOINT_NAME_ALT_LOCATION,
        private_model_server_enabled=True,
    )
    with mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "get_endpoint"
    ) as mocked_get_endpoint:
        mocked_get_endpoint.return_value = canned_endpoint
        yield mocked_get_endpoint


@pytest.fixture
def get_endpoint_with_models_mock():
Expand Down Expand Up @@ -1479,6 +1491,33 @@ def test_create_dedicated_endpoint_with_timeout(
endpoint_id=None,
)

@pytest.mark.usefixtures("get_endpoint_with_private_model_server_enabled_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_endpoint_with_private_model_server_enabled(
    self, create_endpoint_mock, sync
):
    """``Endpoint.create(enable_private_model_server=True)`` forwards the
    flag as ``private_model_server_enabled`` on the CreateEndpoint request."""
    endpoint = models.Endpoint.create(
        display_name=_TEST_DISPLAY_NAME,
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        enable_private_model_server=True,
        sync=sync,
    )
    # In async mode, block until the LRO resolves so the client mock
    # has actually been invoked before we assert on it.
    if not sync:
        endpoint.wait()

    create_endpoint_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        endpoint=gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            private_model_server_enabled=True,
        ),
        metadata=(),
        timeout=None,
        endpoint_id=None,
    )

@pytest.mark.usefixtures("get_empty_endpoint_mock")
def test_accessing_properties_with_no_resource_raises(
self,
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/gapic/aiplatform_v1/test_endpoint_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7140,6 +7140,7 @@ def test_create_endpoint_rest_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -7765,6 +7766,7 @@ def test_update_endpoint_rest_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -9605,6 +9607,7 @@ async def test_create_endpoint_rest_asyncio_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -10287,6 +10290,7 @@ async def test_update_endpoint_rest_asyncio_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8364,6 +8364,7 @@ def test_create_endpoint_rest_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -9009,6 +9010,7 @@ def test_update_endpoint_rest_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -11172,6 +11174,7 @@ async def test_create_endpoint_rest_asyncio_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down Expand Up @@ -11874,6 +11877,7 @@ async def test_update_endpoint_rest_asyncio_call_success(request_type):
"satisfies_pzs": True,
"satisfies_pzi": True,
"gen_ai_advanced_features_config": {"rag_config": {"enable_rag": True}},
"private_model_server_enabled": True,
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
Expand Down