@@ -68,26 +68,30 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
68
68
69
69
70
70
def _patch_http_options (
71
- options : HttpOptionsDict , patch_options : dict [str , Any ]
72
- ) -> HttpOptionsDict :
73
- # use shallow copy so we don't override the original objects.
74
- copy_option = HttpOptionsDict ()
75
- copy_option .update (options )
76
- for patch_key , patch_value in patch_options .items ():
77
- # if both are dicts, update the copy.
78
- # This is to handle cases like merging headers.
79
- if isinstance (patch_value , dict ) and isinstance (
80
- copy_option .get (patch_key , None ), dict
81
- ):
82
- copy_option [patch_key ] = {}
83
- copy_option [patch_key ].update (
84
- options [patch_key ]
85
- ) # shallow copy from original options.
86
- copy_option [patch_key ].update (patch_value )
87
- elif patch_value is not None : # Accept empty values.
88
- copy_option [patch_key ] = patch_value
89
- if copy_option ['headers' ]:
90
- _append_library_version_headers (copy_option ['headers' ])
71
+ options : HttpOptions , patch_options : HttpOptions
72
+ ) -> HttpOptions :
73
+ copy_option = options .model_copy ()
74
+
75
+ options_headers = copy_option .headers or {}
76
+ patch_options_headers = patch_options .headers or {}
77
+ copy_option .headers = {
78
+ ** options_headers ,
79
+ ** patch_options_headers ,
80
+ }
81
+
82
+ http_options_keys = HttpOptions .model_fields .keys ()
83
+
84
+ for key in http_options_keys :
85
+ if key == 'headers' :
86
+ continue
87
+ patch_value = getattr (patch_options , key , None )
88
+ if patch_value is not None :
89
+ setattr (copy_option , key , patch_value )
90
+ else :
91
+ setattr (copy_option , key , getattr (options , key ))
92
+
93
+ if copy_option .headers is not None :
94
+ _append_library_version_headers (copy_option .headers )
91
95
return copy_option
92
96
93
97
@@ -200,7 +204,7 @@ async def async_segments(self) -> AsyncIterator[Any]:
200
204
for chunk in self .response_stream :
201
205
yield json .loads (chunk ) if chunk else {}
202
206
elif self .response_stream is None :
203
- async for c in []:
207
+ async for c in []: # type: ignore[attr-defined]
204
208
yield c
205
209
else :
206
210
# Iterator of objects retrieved from the API.
@@ -306,16 +310,14 @@ def __init__(
306
310
)
307
311
308
312
# Validate http_options if it is provided.
309
- validated_http_options : dict [ str , Any ]
313
+ validated_http_options = HttpOptions ()
310
314
if isinstance (http_options , dict ):
311
315
try :
312
- validated_http_options = HttpOptions .model_validate (
313
- http_options
314
- ).model_dump ()
316
+ validated_http_options = HttpOptions .model_validate (http_options )
315
317
except ValidationError as e :
316
318
raise ValueError (f'Invalid http_options: { e } ' )
317
319
elif isinstance (http_options , HttpOptions ):
318
- validated_http_options = http_options . model_dump ()
320
+ validated_http_options = http_options
319
321
320
322
# Retrieve implicitly set values from the environment.
321
323
env_project = os .environ .get ('GOOGLE_CLOUD_PROJECT' , None )
@@ -326,7 +328,7 @@ def __init__(
326
328
self .api_key = api_key or env_api_key
327
329
328
330
self ._credentials = credentials
329
- self ._http_options = HttpOptionsDict ()
331
+ self ._http_options = HttpOptions ()
330
332
# Initialize the lock. This lock will be used to protect access to the
331
333
# credentials. This is crucial for thread safety when multiple coroutines
332
334
# might be accessing the credentials at the same time.
@@ -374,40 +376,40 @@ def __init__(
374
376
'AI API.'
375
377
)
376
378
if self .api_key or self .location == 'global' :
377
- self ._http_options [ ' base_url' ] = f'https://aiplatform.googleapis.com/'
379
+ self ._http_options . base_url = f'https://aiplatform.googleapis.com/'
378
380
else :
379
- self ._http_options [ ' base_url' ] = (
381
+ self ._http_options . base_url = (
380
382
f'https://{ self .location } -aiplatform.googleapis.com/'
381
383
)
382
- self ._http_options [ ' api_version' ] = 'v1beta1'
384
+ self ._http_options . api_version = 'v1beta1'
383
385
else : # Implicit initialization or missing arguments.
384
386
if not self .api_key :
385
387
raise ValueError (
386
388
'Missing key inputs argument! To use the Google AI API,'
387
389
'provide (`api_key`) arguments. To use the Google Cloud API,'
388
390
' provide (`vertexai`, `project` & `location`) arguments.'
389
391
)
390
- self ._http_options ['base_url' ] = (
391
- 'https://generativelanguage.googleapis.com/'
392
- )
393
- self ._http_options ['api_version' ] = 'v1beta'
392
+ self ._http_options .base_url = 'https://generativelanguage.googleapis.com/'
393
+ self ._http_options .api_version = 'v1beta'
394
394
# Default options for both clients.
395
- self ._http_options [ ' headers' ] = {'Content-Type' : 'application/json' }
395
+ self ._http_options . headers = {'Content-Type' : 'application/json' }
396
396
if self .api_key :
397
- self ._http_options ['headers' ]['x-goog-api-key' ] = self .api_key
397
+ if self ._http_options .headers is not None :
398
+ self ._http_options .headers ['x-goog-api-key' ] = self .api_key
398
399
# Update the http options with the user provided http options.
399
400
if http_options :
400
401
self ._http_options = _patch_http_options (
401
402
self ._http_options , validated_http_options
402
403
)
403
404
else :
404
- _append_library_version_headers (self ._http_options ['headers' ])
405
+ if self ._http_options .headers is not None :
406
+ _append_library_version_headers (self ._http_options .headers )
405
407
# Initialize the httpx client.
406
408
self ._httpx_client = SyncHttpxClient ()
407
409
self ._async_httpx_client = AsyncHttpxClient ()
408
410
409
411
def _websocket_base_url (self ):
410
- url_parts = urlparse (self ._http_options [ ' base_url' ] )
412
+ url_parts = urlparse (self ._http_options . base_url )
411
413
return url_parts ._replace (scheme = 'wss' ).geturl ()
412
414
413
415
def _access_token (self ) -> str :
@@ -418,9 +420,7 @@ def _access_token(self) -> str:
418
420
self .project = project
419
421
420
422
if self ._credentials :
421
- if (
422
- self ._credentials .expired or not self ._credentials .token
423
- ):
423
+ if self ._credentials .expired or not self ._credentials .token :
424
424
# Only refresh when it needs to. Default expiration is 3600 seconds.
425
425
_refresh_auth (self ._credentials )
426
426
if not self ._credentials .token :
@@ -473,11 +473,12 @@ def _build_request(
473
473
if http_options :
474
474
if isinstance (http_options , HttpOptions ):
475
475
patched_http_options = _patch_http_options (
476
- self ._http_options , http_options .model_dump ()
476
+ self ._http_options ,
477
+ http_options ,
477
478
)
478
479
else :
479
480
patched_http_options = _patch_http_options (
480
- self ._http_options , http_options
481
+ self ._http_options , HttpOptions . model_validate ( http_options )
481
482
)
482
483
else :
483
484
patched_http_options = self ._http_options
@@ -496,13 +497,27 @@ def _build_request(
496
497
and not self .api_key
497
498
):
498
499
path = f'projects/{ self .project } /locations/{ self .location } /' + path
500
+
501
+ if patched_http_options .api_version is None :
502
+ versioned_path = f'/{ path } '
503
+ else :
504
+ versioned_path = f'{ patched_http_options .api_version } /{ path } '
505
+
506
+ if (
507
+ patched_http_options .base_url is None
508
+ or not patched_http_options .base_url
509
+ ):
510
+ raise ValueError ('Base URL must be set.' )
511
+ else :
512
+ base_url = patched_http_options .base_url
513
+
499
514
url = _join_url_path (
500
- patched_http_options . get ( ' base_url' , '' ) ,
501
- patched_http_options . get ( 'api_version' , '' ) + '/' + path ,
515
+ base_url ,
516
+ versioned_path ,
502
517
)
503
518
504
- timeout_in_seconds : Optional [Union [float , int ]] = patched_http_options . get (
505
- ' timeout' , None
519
+ timeout_in_seconds : Optional [Union [float , int ]] = (
520
+ patched_http_options . timeout
506
521
)
507
522
if timeout_in_seconds :
508
523
# HttpOptions.timeout is in milliseconds. But httpx.Client.request()
@@ -511,10 +526,12 @@ def _build_request(
511
526
else :
512
527
timeout_in_seconds = None
513
528
529
+ if patched_http_options .headers is None :
530
+ raise ValueError ('Request headers must be set.' )
514
531
return HttpRequest (
515
532
method = http_method ,
516
533
url = url ,
517
- headers = patched_http_options [ ' headers' ] ,
534
+ headers = patched_http_options . headers ,
518
535
data = request_dict ,
519
536
timeout = timeout_in_seconds ,
520
537
)
@@ -526,9 +543,7 @@ def _request(
526
543
) -> HttpResponse :
527
544
data : Optional [Union [str , bytes ]] = None
528
545
if self .vertexai and not self .api_key :
529
- http_request .headers ['Authorization' ] = (
530
- f'Bearer { self ._access_token ()} '
531
- )
546
+ http_request .headers ['Authorization' ] = f'Bearer { self ._access_token ()} '
532
547
if self ._credentials and self ._credentials .quota_project_id :
533
548
http_request .headers ['x-goog-user-project' ] = (
534
549
self ._credentials .quota_project_id
@@ -616,11 +631,11 @@ async def _async_request(
616
631
response .headers , response if stream else [response .text ]
617
632
)
618
633
619
- def get_read_only_http_options (self ) -> HttpOptionsDict :
620
- copied = HttpOptionsDict ()
634
+ def get_read_only_http_options (self ) -> dict [str , Any ]:
621
635
if isinstance (self ._http_options , BaseModel ):
622
- self ._http_options = self ._http_options .model_dump ()
623
- copied .update (self ._http_options )
636
+ copied = self ._http_options .model_dump ()
637
+ else :
638
+ copied = self ._http_options
624
639
return copied
625
640
626
641
def request (
0 commit comments