Skip to content

Commit cee65a2

Browse files
sararobcopybara-github
authored andcommitted
chore: fix remaining non-strict mypy errors in _api_client
PiperOrigin-RevId: 739224190
1 parent 43c5379 commit cee65a2

File tree

6 files changed

+164
-70
lines changed

6 files changed

+164
-70
lines changed

google/genai/_api_client.py

+70-55
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,30 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
6868

6969

7070
def _patch_http_options(
71-
options: HttpOptionsDict, patch_options: dict[str, Any]
72-
) -> HttpOptionsDict:
73-
# use shallow copy so we don't override the original objects.
74-
copy_option = HttpOptionsDict()
75-
copy_option.update(options)
76-
for patch_key, patch_value in patch_options.items():
77-
# if both are dicts, update the copy.
78-
# This is to handle cases like merging headers.
79-
if isinstance(patch_value, dict) and isinstance(
80-
copy_option.get(patch_key, None), dict
81-
):
82-
copy_option[patch_key] = {}
83-
copy_option[patch_key].update(
84-
options[patch_key]
85-
) # shallow copy from original options.
86-
copy_option[patch_key].update(patch_value)
87-
elif patch_value is not None: # Accept empty values.
88-
copy_option[patch_key] = patch_value
89-
if copy_option['headers']:
90-
_append_library_version_headers(copy_option['headers'])
71+
options: HttpOptions, patch_options: HttpOptions
72+
) -> HttpOptions:
73+
copy_option = options.model_copy()
74+
75+
options_headers = copy_option.headers or {}
76+
patch_options_headers = patch_options.headers or {}
77+
copy_option.headers = {
78+
**options_headers,
79+
**patch_options_headers,
80+
}
81+
82+
http_options_keys = HttpOptions.model_fields.keys()
83+
84+
for key in http_options_keys:
85+
if key == 'headers':
86+
continue
87+
patch_value = getattr(patch_options, key, None)
88+
if patch_value is not None:
89+
setattr(copy_option, key, patch_value)
90+
else:
91+
setattr(copy_option, key, getattr(options, key))
92+
93+
if copy_option.headers is not None:
94+
_append_library_version_headers(copy_option.headers)
9195
return copy_option
9296

9397

@@ -200,7 +204,7 @@ async def async_segments(self) -> AsyncIterator[Any]:
200204
for chunk in self.response_stream:
201205
yield json.loads(chunk) if chunk else {}
202206
elif self.response_stream is None:
203-
async for c in []:
207+
async for c in []: # type: ignore[attr-defined]
204208
yield c
205209
else:
206210
# Iterator of objects retrieved from the API.
@@ -306,16 +310,14 @@ def __init__(
306310
)
307311

308312
# Validate http_options if it is provided.
309-
validated_http_options: dict[str, Any]
313+
validated_http_options = HttpOptions()
310314
if isinstance(http_options, dict):
311315
try:
312-
validated_http_options = HttpOptions.model_validate(
313-
http_options
314-
).model_dump()
316+
validated_http_options = HttpOptions.model_validate(http_options)
315317
except ValidationError as e:
316318
raise ValueError(f'Invalid http_options: {e}')
317319
elif isinstance(http_options, HttpOptions):
318-
validated_http_options = http_options.model_dump()
320+
validated_http_options = http_options
319321

320322
# Retrieve implicitly set values from the environment.
321323
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
@@ -326,7 +328,7 @@ def __init__(
326328
self.api_key = api_key or env_api_key
327329

328330
self._credentials = credentials
329-
self._http_options = HttpOptionsDict()
331+
self._http_options = HttpOptions()
330332
# Initialize the lock. This lock will be used to protect access to the
331333
# credentials. This is crucial for thread safety when multiple coroutines
332334
# might be accessing the credentials at the same time.
@@ -374,40 +376,40 @@ def __init__(
374376
'AI API.'
375377
)
376378
if self.api_key or self.location == 'global':
377-
self._http_options['base_url'] = f'https://aiplatform.googleapis.com/'
379+
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
378380
else:
379-
self._http_options['base_url'] = (
381+
self._http_options.base_url = (
380382
f'https://{self.location}-aiplatform.googleapis.com/'
381383
)
382-
self._http_options['api_version'] = 'v1beta1'
384+
self._http_options.api_version = 'v1beta1'
383385
else: # Implicit initialization or missing arguments.
384386
if not self.api_key:
385387
raise ValueError(
386388
'Missing key inputs argument! To use the Google AI API,'
387389
'provide (`api_key`) arguments. To use the Google Cloud API,'
388390
' provide (`vertexai`, `project` & `location`) arguments.'
389391
)
390-
self._http_options['base_url'] = (
391-
'https://generativelanguage.googleapis.com/'
392-
)
393-
self._http_options['api_version'] = 'v1beta'
392+
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
393+
self._http_options.api_version = 'v1beta'
394394
# Default options for both clients.
395-
self._http_options['headers'] = {'Content-Type': 'application/json'}
395+
self._http_options.headers = {'Content-Type': 'application/json'}
396396
if self.api_key:
397-
self._http_options['headers']['x-goog-api-key'] = self.api_key
397+
if self._http_options.headers is not None:
398+
self._http_options.headers['x-goog-api-key'] = self.api_key
398399
# Update the http options with the user provided http options.
399400
if http_options:
400401
self._http_options = _patch_http_options(
401402
self._http_options, validated_http_options
402403
)
403404
else:
404-
_append_library_version_headers(self._http_options['headers'])
405+
if self._http_options.headers is not None:
406+
_append_library_version_headers(self._http_options.headers)
405407
# Initialize the httpx client.
406408
self._httpx_client = SyncHttpxClient()
407409
self._async_httpx_client = AsyncHttpxClient()
408410

409411
def _websocket_base_url(self):
410-
url_parts = urlparse(self._http_options['base_url'])
412+
url_parts = urlparse(self._http_options.base_url)
411413
return url_parts._replace(scheme='wss').geturl()
412414

413415
def _access_token(self) -> str:
@@ -418,9 +420,7 @@ def _access_token(self) -> str:
418420
self.project = project
419421

420422
if self._credentials:
421-
if (
422-
self._credentials.expired or not self._credentials.token
423-
):
423+
if self._credentials.expired or not self._credentials.token:
424424
# Only refresh when it needs to. Default expiration is 3600 seconds.
425425
_refresh_auth(self._credentials)
426426
if not self._credentials.token:
@@ -473,11 +473,12 @@ def _build_request(
473473
if http_options:
474474
if isinstance(http_options, HttpOptions):
475475
patched_http_options = _patch_http_options(
476-
self._http_options, http_options.model_dump()
476+
self._http_options,
477+
http_options,
477478
)
478479
else:
479480
patched_http_options = _patch_http_options(
480-
self._http_options, http_options
481+
self._http_options, HttpOptions.model_validate(http_options)
481482
)
482483
else:
483484
patched_http_options = self._http_options
@@ -496,13 +497,27 @@ def _build_request(
496497
and not self.api_key
497498
):
498499
path = f'projects/{self.project}/locations/{self.location}/' + path
500+
501+
if patched_http_options.api_version is None:
502+
versioned_path = f'/{path}'
503+
else:
504+
versioned_path = f'{patched_http_options.api_version}/{path}'
505+
506+
if (
507+
patched_http_options.base_url is None
508+
or not patched_http_options.base_url
509+
):
510+
raise ValueError('Base URL must be set.')
511+
else:
512+
base_url = patched_http_options.base_url
513+
499514
url = _join_url_path(
500-
patched_http_options.get('base_url', ''),
501-
patched_http_options.get('api_version', '') + '/' + path,
515+
base_url,
516+
versioned_path,
502517
)
503518

504-
timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
505-
'timeout', None
519+
timeout_in_seconds: Optional[Union[float, int]] = (
520+
patched_http_options.timeout
506521
)
507522
if timeout_in_seconds:
508523
# HttpOptions.timeout is in milliseconds. But httpx.Client.request()
@@ -511,10 +526,12 @@ def _build_request(
511526
else:
512527
timeout_in_seconds = None
513528

529+
if patched_http_options.headers is None:
530+
raise ValueError('Request headers must be set.')
514531
return HttpRequest(
515532
method=http_method,
516533
url=url,
517-
headers=patched_http_options['headers'],
534+
headers=patched_http_options.headers,
518535
data=request_dict,
519536
timeout=timeout_in_seconds,
520537
)
@@ -526,9 +543,7 @@ def _request(
526543
) -> HttpResponse:
527544
data: Optional[Union[str, bytes]] = None
528545
if self.vertexai and not self.api_key:
529-
http_request.headers['Authorization'] = (
530-
f'Bearer {self._access_token()}'
531-
)
546+
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
532547
if self._credentials and self._credentials.quota_project_id:
533548
http_request.headers['x-goog-user-project'] = (
534549
self._credentials.quota_project_id
@@ -616,11 +631,11 @@ async def _async_request(
616631
response.headers, response if stream else [response.text]
617632
)
618633

619-
def get_read_only_http_options(self) -> HttpOptionsDict:
620-
copied = HttpOptionsDict()
634+
def get_read_only_http_options(self) -> dict[str, Any]:
621635
if isinstance(self._http_options, BaseModel):
622-
self._http_options = self._http_options.model_dump()
623-
copied.update(self._http_options)
636+
copied = self._http_options.model_dump()
637+
else:
638+
copied = self._http_options
624639
return copied
625640

626641
def request(

google/genai/live.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -873,9 +873,9 @@ async def connect(
873873

874874
if self._api_client.api_key:
875875
api_key = self._api_client.api_key
876-
version = self._api_client._http_options['api_version']
876+
version = self._api_client._http_options.api_version
877877
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
878-
headers = self._api_client._http_options['headers']
878+
headers = self._api_client._http_options.headers
879879
request_dict = _common.convert_to_dict(
880880
self._LiveSetup_to_mldev(
881881
model=transformed_model,
@@ -894,12 +894,12 @@ async def connect(
894894
auth_req = google.auth.transport.requests.Request()
895895
creds.refresh(auth_req)
896896
bearer_token = creds.token
897-
headers = self._api_client._http_options['headers']
897+
headers = self._api_client._http_options.headers
898898
if headers is not None:
899899
headers.update({
900900
'Authorization': 'Bearer {}'.format(bearer_token),
901901
})
902-
version = self._api_client._http_options['api_version']
902+
version = self._api_client._http_options.api_version
903903
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
904904
location = self._api_client.location
905905
project = self._api_client.project

google/genai/tests/client/test_client_initialization.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def test_vertexai_apikey_from_constructor(monkeypatch):
457457
assert not client.models._api_client.project
458458
assert not client.models._api_client.location
459459
assert client.models._api_client.api_key == api_key
460-
assert "aiplatform" in client._api_client._http_options["base_url"]
460+
assert "aiplatform" in client._api_client._http_options.base_url
461461
assert isinstance(client.models._api_client, api_client.BaseApiClient)
462462

463463

@@ -477,7 +477,7 @@ def test_vertexai_apikey_from_env(monkeypatch):
477477
assert client.models._api_client.api_key == api_key
478478
assert not client.models._api_client.project
479479
assert not client.models._api_client.location
480-
assert "aiplatform" in client._api_client._http_options["base_url"]
480+
assert "aiplatform" in client._api_client._http_options.base_url
481481
assert isinstance(client.models._api_client, api_client.BaseApiClient)
482482

483483

@@ -512,7 +512,7 @@ def test_vertexai_apikey_combo1(monkeypatch):
512512
assert client.models._api_client.api_key == api_key
513513
assert not client.models._api_client.project
514514
assert not client.models._api_client.location
515-
assert "aiplatform" in client._api_client._http_options["base_url"]
515+
assert "aiplatform" in client._api_client._http_options.base_url
516516
assert isinstance(client.models._api_client, api_client.BaseApiClient)
517517

518518

@@ -532,7 +532,7 @@ def test_vertexai_apikey_combo2(monkeypatch):
532532
assert not client.models._api_client.api_key
533533
assert client.models._api_client.project == project_id
534534
assert client.models._api_client.location == location
535-
assert "aiplatform" in client._api_client._http_options["base_url"]
535+
assert "aiplatform" in client._api_client._http_options.base_url
536536
assert isinstance(client.models._api_client, api_client.BaseApiClient)
537537

538538

@@ -552,7 +552,7 @@ def test_vertexai_apikey_combo3(monkeypatch):
552552
assert not client.models._api_client.api_key
553553
assert client.models._api_client.project == project_id
554554
assert client.models._api_client.location == location
555-
assert "aiplatform" in client._api_client._http_options["base_url"]
555+
assert "aiplatform" in client._api_client._http_options.base_url
556556
assert isinstance(client.models._api_client, api_client.BaseApiClient)
557557

558558

@@ -568,7 +568,7 @@ def test_vertexai_global_endpoint(monkeypatch):
568568
assert client.models._api_client.vertexai
569569
assert client.models._api_client.project == project_id
570570
assert client.models._api_client.location == location
571-
assert client.models._api_client._http_options["base_url"] == (
571+
assert client.models._api_client._http_options.base_url == (
572572
"https://aiplatform.googleapis.com/"
573573
)
574574
assert isinstance(client.models._api_client, api_client.BaseApiClient)

google/genai/tests/client/test_client_requests.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_build_request_appends_to_user_agent_headers(monkeypatch):
9797
'test/path',
9898
{'key': 'value'},
9999
api_client.HttpOptionsDict(
100-
url='test/url',
100+
base_url='test/url',
101101
api_version='1',
102102
headers={'user-agent': 'test-user-agent'},
103103
),
@@ -115,7 +115,7 @@ def test_build_request_appends_to_goog_api_client_headers(monkeypatch):
115115
'test/path',
116116
{'key': 'value'},
117117
api_client.HttpOptionsDict(
118-
url='test/url',
118+
base_url='test/url',
119119
api_version='1',
120120
headers={'x-goog-api-client': 'test-goog-api-client'},
121121
),
@@ -136,7 +136,7 @@ def test_build_request_keeps_sdk_version_headers(monkeypatch):
136136
'test/path',
137137
{'key': 'value'},
138138
api_client.HttpOptionsDict(
139-
url='test/url',
139+
base_url='test/url',
140140
api_version='1',
141141
headers=headers_to_inject,
142142
),

0 commit comments

Comments
 (0)