Skip to content

Commit 01b6e51

Browse files
authored
Properly close VertexAI HTTP client at end of live test (#2768)
1 parent 66fa21a commit 01b6e51

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

tests/test_live.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,30 @@ def gemini(_: httpx.AsyncClient, _tmp_path: Path) -> Model:
3434
return GoogleModel('gemini-1.5-pro')
3535

3636

37-
def vertexai(_: httpx.AsyncClient, tmp_path: Path) -> Model:
37+
def vertexai(http_client: httpx.AsyncClient, tmp_path: Path) -> Model:
3838
from google.oauth2 import service_account
3939

4040
from pydantic_ai.models.google import GoogleModel
4141
from pydantic_ai.providers.google import GoogleProvider
4242

43-
service_account_content = os.environ['GOOGLE_SERVICE_ACCOUNT_CONTENT']
44-
project_id = json.loads(service_account_content)['project_id']
45-
service_account_path = tmp_path / 'service_account.json'
46-
service_account_path.write_text(service_account_content)
43+
if service_account_path := os.environ.get('GOOGLE_APPLICATION_CREDENTIALS'):
44+
project_id = json.loads(Path(service_account_path).read_text())['project_id']
45+
elif service_account_content := os.environ.get('GOOGLE_SERVICE_ACCOUNT_CONTENT'):
46+
project_id = json.loads(service_account_content)['project_id']
47+
service_account_path = tmp_path / 'service_account.json'
48+
service_account_path.write_text(service_account_content)
49+
else:
50+
pytest.skip(
51+
'VertexAI live test requires GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_SERVICE_ACCOUNT_CONTENT to be set'
52+
)
4753

4854
credentials = service_account.Credentials.from_service_account_file( # type: ignore[reportUnknownReturnType]
4955
service_account_path,
5056
scopes=['https://www.googleapis.com/auth/cloud-platform'],
5157
)
5258
provider = GoogleProvider(credentials=credentials, project=project_id)
53-
return GoogleModel('gemini-1.5-flash', provider=provider)
59+
provider.client.aio._api_client._async_httpx_client = http_client # type: ignore
60+
return GoogleModel('gemini-2.0-flash', provider=provider)
5461

5562

5663
def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:

0 commit comments

Comments
 (0)