Skip to content

Commit 84a8689

Browse files
sararobcopybara-github
authored andcommitted
feat: Add GenAI client (experimental)
PiperOrigin-RevId: 766750716
1 parent c81eb72 commit 84a8689

File tree

10 files changed

+208
-85
lines changed

10 files changed

+208
-85
lines changed

noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
"pytest-asyncio",
6868
# Preventing: py.test: error: unrecognized arguments: -n=auto --dist=loadscope
6969
"pytest-xdist",
70+
# "pandas",
71+
# "tqdm",
7072
]
7173
UNIT_TEST_EXTERNAL_DEPENDENCIES = []
7274
UNIT_TEST_LOCAL_DEPENDENCIES = []

tests/unit/architecture/test_vertexai_import.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,28 @@ def test_vertexai_import():
9090
new_modules = modules_after_vertexai - modules_before_aip
9191
_test_external_imports(new_modules)
9292

93+
# Tests the GenAI client module is lazy loaded.
94+
from vertexai._genai import client as _ # noqa: F401,F811
95+
96+
modules_after_genai_client_import = set(sys.modules)
97+
98+
# The evals module has additional required deps that should not be required
99+
# to instantiate a client.
100+
assert "pandas" not in modules_after_genai_client_import
101+
102+
# Tests the evals module is lazy loaded.
103+
eval_module_deps_subset = ("pydantic", "mcp", "pandas")
104+
105+
assert not all(module in new_modules for module in eval_module_deps_subset)
106+
107+
from vertexai._genai import evals as _ # noqa: F401,F811
108+
109+
modules_after_eval_import = set(sys.modules)
110+
111+
assert all(
112+
module in modules_after_eval_import for module in eval_module_deps_subset
113+
)
114+
93115

94116
def _test_external_imports(new_modules: list):
95117
builtin_modules = {

tests/unit/vertexai/genai/test_evals.py

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,29 @@
1818
import json
1919
import os
2020
from unittest import mock
21+
import warnings
2122

2223
from google.cloud import aiplatform
2324
import vertexai
2425
from google.cloud.aiplatform import initializer as aiplatform_initializer
2526
from vertexai import _genai
27+
from vertexai._genai import evals
2628
from vertexai._genai import types as vertexai_genai_types
29+
from google.genai import client
30+
from google.genai import errors as genai_errors
2731
from google.genai import types as genai_types
28-
import google.genai.errors as genai_errors
2932
import pandas as pd
3033
import pytest
31-
import warnings
34+
3235

3336
_TEST_PROJECT = "test-project"
3437
_TEST_LOCATION = "us-central1"
3538

3639

40+
_genai.evals._lazy_load_evals_common()
41+
_evals_common = _genai.evals._evals_common
42+
_evals_utils = _genai._evals_utils
43+
3744
pytestmark = pytest.mark.usefixtures("google_auth_mock")
3845

3946

@@ -48,47 +55,41 @@ def setup_method(self):
4855
project=_TEST_PROJECT,
4956
location=_TEST_LOCATION,
5057
)
58+
self.client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
5159

5260
@pytest.mark.usefixtures("google_auth_mock")
53-
def test_evaluate_instances(self):
54-
test_client = _genai.client.Client(
55-
project=_TEST_PROJECT, location=_TEST_LOCATION
56-
)
61+
@mock.patch.object(client.Client, "_get_api_client")
62+
@mock.patch.object(evals.Evals, "_evaluate_instances")
63+
def test_evaluate_instances(self, mock_evaluate, mock_get_api_client):
5764
with warnings.catch_warnings(record=True) as captured_warnings:
5865
warnings.simplefilter("always")
59-
with mock.patch.object(
60-
test_client.evals, "_evaluate_instances"
61-
) as mock_evaluate:
62-
test_client.evals._evaluate_instances(
63-
bleu_input=_genai.types.BleuInput()
64-
)
65-
mock_evaluate.assert_called_once_with(
66-
bleu_input=_genai.types.BleuInput()
67-
)
68-
assert captured_warnings[0].category == genai_errors.ExperimentalWarning
66+
self.client.evals._evaluate_instances(
67+
bleu_input=vertexai_genai_types.BleuInput()
68+
)
69+
mock_evaluate.assert_called_once_with(
70+
bleu_input=vertexai_genai_types.BleuInput()
71+
)
72+
assert captured_warnings[0].category == genai_errors.ExperimentalWarning
6973

7074
@pytest.mark.usefixtures("google_auth_mock")
7175
def test_eval_run(self):
72-
test_client = _genai.client.Client(
73-
project=_TEST_PROJECT, location=_TEST_LOCATION
74-
)
76+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
7577
with pytest.raises(NotImplementedError):
7678
test_client.evals.run()
7779

7880
@pytest.mark.usefixtures("google_auth_mock")
79-
def test_eval_batch_eval(self):
80-
test_client = _genai.client.Client(
81-
project=_TEST_PROJECT, location=_TEST_LOCATION
82-
)
83-
with mock.patch.object(test_client.evals, "batch_eval") as mock_batch_eval:
84-
test_client.evals.batch_eval(
85-
dataset=_genai.types.EvaluationDataset(),
86-
metrics=[_genai.types.Metric()],
87-
output_config=_genai.types.OutputConfig(),
88-
autorater_config=_genai.types.AutoraterConfig(),
89-
config=_genai.types.EvaluateDatasetConfig(),
90-
)
91-
mock_batch_eval.assert_called_once()
81+
@mock.patch.object(client.Client, "_get_api_client")
82+
@mock.patch.object(evals.Evals, "batch_eval")
83+
def test_eval_batch_eval(self, mock_evaluate, mock_get_api_client):
84+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
85+
test_client.evals.batch_eval(
86+
dataset=vertexai_genai_types.EvaluationDataset(),
87+
metrics=[vertexai_genai_types.Metric()],
88+
output_config=vertexai_genai_types.OutputConfig(),
89+
autorater_config=vertexai_genai_types.AutoraterConfig(),
90+
config=vertexai_genai_types.EvaluateDatasetConfig(),
91+
)
92+
mock_evaluate.assert_called_once()
9293

9394

9495
class TestEvalsClientInference:
@@ -99,20 +100,16 @@ def setup_method(self):
99100
importlib.reload(aiplatform)
100101
importlib.reload(vertexai)
101102
importlib.reload(_genai.client)
102-
importlib.reload(_genai.types)
103+
importlib.reload(vertexai_genai_types)
103104
importlib.reload(_genai.evals)
104-
importlib.reload(_genai._evals_utils)
105-
importlib.reload(_genai._evals_common)
106105
vertexai.init(
107106
project=_TEST_PROJECT,
108107
location=_TEST_LOCATION,
109108
)
110-
self.client = _genai.client.Client(
111-
project=_TEST_PROJECT, location=_TEST_LOCATION
112-
)
109+
self.client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
113110

114-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
115-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
111+
@mock.patch.object(_evals_common, "Models")
112+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
116113
def test_inference_with_string_model_success(
117114
self, mock_eval_dataset_loader, mock_models
118115
):
@@ -153,7 +150,7 @@ def test_inference_with_string_model_success(
153150
),
154151
)
155152

156-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
153+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
157154
def test_inference_with_callable_model_success(self, mock_eval_dataset_loader):
158155
mock_df = pd.DataFrame({"prompt": ["test prompt"]})
159156
mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
@@ -178,8 +175,8 @@ def mock_model_fn(contents):
178175
),
179176
)
180177

181-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
182-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
178+
@mock.patch.object(_evals_common, "Models")
179+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
183180
def test_inference_with_prompt_template(
184181
self, mock_eval_dataset_loader, mock_models
185182
):
@@ -202,7 +199,7 @@ def test_inference_with_prompt_template(
202199
mock_generate_content_response
203200
)
204201

205-
config = _genai.types.EvalRunInferenceConfig(
202+
config = vertexai_genai_types.EvalRunInferenceConfig(
206203
prompt_template="Hello {text_input}"
207204
)
208205
result_df = self.client.evals.run_inference(
@@ -223,9 +220,9 @@ def test_inference_with_prompt_template(
223220
),
224221
)
225222

226-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
227-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
228-
@mock.patch(f"{_genai._evals_utils.__name__}.GcsUtils")
223+
@mock.patch.object(_evals_common, "Models")
224+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
225+
@mock.patch.object(_evals_utils, "GcsUtils")
229226
def test_inference_with_gcs_destination(
230227
self, mock_gcs_utils, mock_eval_dataset_loader, mock_models
231228
):
@@ -249,7 +246,7 @@ def test_inference_with_gcs_destination(
249246
)
250247

251248
gcs_dest_path = "gs://bucket/output.jsonl"
252-
config = _genai.types.EvalRunInferenceConfig(dest=gcs_dest_path)
249+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=gcs_dest_path)
253250

254251
result_df = self.client.evals.run_inference(
255252
model="gemini-pro", src=mock_df, config=config
@@ -270,8 +267,8 @@ def test_inference_with_gcs_destination(
270267
)
271268
pd.testing.assert_frame_equal(result_df, expected_df_to_save)
272269

273-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
274-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
270+
@mock.patch.object(_evals_common, "Models")
271+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
275272
@mock.patch("pandas.DataFrame.to_json")
276273
@mock.patch("os.makedirs")
277274
def test_inference_with_local_destination(
@@ -301,7 +298,7 @@ def test_inference_with_local_destination(
301298
)
302299

303300
local_dest_path = "/tmp/test/output_dir/results.jsonl"
304-
config = _genai.types.EvalRunInferenceConfig(dest=local_dest_path)
301+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_path)
305302

306303
result_df = self.client.evals.run_inference(
307304
model="gemini-pro", src=mock_df, config=config
@@ -319,8 +316,8 @@ def test_inference_with_local_destination(
319316
)
320317
pd.testing.assert_frame_equal(result_df, expected_df)
321318

322-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
323-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
319+
@mock.patch.object(_evals_common, "Models")
320+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
324321
def test_inference_from_request_column_save_locally(
325322
self, mock_eval_dataset_loader, mock_models
326323
):
@@ -359,7 +356,7 @@ def test_inference_from_request_column_save_locally(
359356
)
360357

361358
local_dest_path = "/tmp/output.jsonl"
362-
config = _genai.types.EvalRunInferenceConfig(dest=local_dest_path)
359+
config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_path)
363360

364361
result_df = self.client.evals.run_inference(
365362
model="gemini-pro", src=mock_df, config=config
@@ -396,7 +393,7 @@ def test_inference_from_request_column_save_locally(
396393
assert saved_records == expected_records
397394
os.remove(local_dest_path)
398395

399-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
396+
@mock.patch.object(_evals_common, "Models")
400397
def test_inference_from_local_jsonl_file(self, mock_models):
401398
# Create a temporary JSONL file
402399
local_src_path = "/tmp/input.jsonl"
@@ -450,7 +447,7 @@ def test_inference_from_local_jsonl_file(self, mock_models):
450447
pd.testing.assert_frame_equal(result_df, expected_df)
451448
os.remove(local_src_path)
452449

453-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
450+
@mock.patch.object(_evals_common, "Models")
454451
def test_inference_from_local_csv_file(self, mock_models):
455452
# Create a temporary CSV file
456453
local_src_path = "/tmp/input.csv"
@@ -501,8 +498,8 @@ def test_inference_from_local_csv_file(self, mock_models):
501498
pd.testing.assert_frame_equal(result_df, expected_df)
502499
os.remove(local_src_path)
503500

504-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
505-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
501+
@mock.patch.object(_evals_common, "Models")
502+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
506503
def test_inference_with_row_level_config_overrides(
507504
self, mock_eval_dataset_loader, mock_models
508505
):
@@ -584,8 +581,8 @@ def test_inference_with_row_level_config_overrides(
584581
)
585582
pd.testing.assert_frame_equal(result_df, expected_df)
586583

587-
@mock.patch(f"{_genai._evals_common.__name__}.Models")
588-
@mock.patch(f"{_genai._evals_utils.__name__}.EvalDatasetLoader")
584+
@mock.patch.object(_evals_common, "Models")
585+
@mock.patch.object(_evals_utils, "EvalDatasetLoader")
589586
def test_inference_with_multimodal_content(
590587
self, mock_eval_dataset_loader, mock_models
591588
):
@@ -623,7 +620,7 @@ def test_inference_with_multimodal_content(
623620
mock_generate_content_response
624621
)
625622

626-
config = _genai.types.EvalRunInferenceConfig(
623+
config = vertexai_genai_types.EvalRunInferenceConfig(
627624
prompt_template="multimodal prompt: {media_content}{text_input}"
628625
)
629626
result_df = self.client.evals.run_inference(

tests/unit/vertexai/genai/test_genai_client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from google.cloud import aiplatform
2121
import vertexai
2222
from google.cloud.aiplatform import initializer as aiplatform_initializer
23-
from vertexai import _genai
2423
import pytest
2524

2625
_TEST_PROJECT = "test-project"
@@ -44,9 +43,7 @@ def setup_method(self):
4443

4544
@pytest.mark.usefixtures("google_auth_mock")
4645
def test_genai_client(self):
47-
test_client = _genai.client.Client(
48-
project=_TEST_PROJECT, location=_TEST_LOCATION
49-
)
46+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
5047
assert test_client is not None
5148
assert test_client._api_client.vertexai
5249
assert test_client._api_client.project == _TEST_PROJECT

vertexai/__init__.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222

2323
from google.cloud.aiplatform import init
2424

25-
__all__ = [
26-
"init",
27-
"preview",
28-
]
25+
_genai_client = None
26+
_genai_types = None
2927

3028

3129
def __getattr__(name):
@@ -39,4 +37,24 @@ def __getattr__(name):
3937
# `import google.cloud.aiplatform.vertexai.preview as vertexai_preview`
4038
return importlib.import_module(".preview", __name__)
4139

40+
if name == "Client":
41+
global _genai_client
42+
if _genai_client is None:
43+
_genai_client = importlib.import_module("._genai.client", __name__)
44+
return getattr(_genai_client, name)
45+
46+
if name == "types":
47+
global _genai_types
48+
if _genai_types is None:
49+
_genai_types = importlib.import_module("._genai.types", __name__)
50+
return _genai_types
51+
4252
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
53+
54+
55+
__all__ = [
56+
"init",
57+
"preview",
58+
"Client",
59+
"types",
60+
]

vertexai/_genai/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,29 @@
1414
#
1515
"""The vertexai module."""
1616

17-
from . import evals
17+
import importlib
18+
1819
from .client import Client
1920

21+
_evals = None
22+
23+
24+
def __getattr__(name):
25+
if name == "evals":
26+
global _evals
27+
if _evals is None:
28+
try:
29+
_evals = importlib.import_module(".evals", __package__)
30+
except ImportError as e:
31+
raise ImportError(
32+
"The 'evals' module requires 'pandas' and 'tqdm'. "
33+
"Please install them using pip install "
34+
"google-cloud-aiplatform[evaluation]"
35+
) from e
36+
return _evals
37+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
38+
39+
2040
__all__ = [
2141
"Client",
2242
"evals",

0 commit comments

Comments
 (0)