Skip to content

Commit daf0a4d

Browse files
cleop-googlecopybara-github
authored andcommitted
feat: GenAI SDK client(multimodal) - Allow fetching dataset with dataset_id in get_multimodal_dataset.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#6572 from googleapis:release-please--branches--main 1448130 PiperOrigin-RevId: 899573600
1 parent 9722998 commit daf0a4d

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def test_get_dataset_from_public_method(client):
4141
assert dataset.display_name == "test-display-name"
4242

4343

44+
def test_get_dataset_by_id(client):
45+
dataset = client.datasets.get_multimodal_dataset(
46+
dataset_id="8810841321427173376",
47+
)
48+
assert isinstance(dataset, types.MultimodalDataset)
49+
assert dataset.name == DATASET
50+
assert dataset.display_name == "test-display-name"
51+
52+
4453
pytestmark = pytest_helper.setup(
4554
file=__file__,
4655
globals_for_file=globals(),
@@ -67,3 +76,13 @@ async def test_get_dataset_from_public_method_async(client):
6776
assert isinstance(dataset, types.MultimodalDataset)
6877
assert dataset.name == DATASET
6978
assert dataset.display_name == "test-display-name"
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_get_dataset_by_id_async(client):
83+
dataset = await client.aio.datasets.get_multimodal_dataset(
84+
dataset_id="8810841321427173376",
85+
)
86+
assert isinstance(dataset, types.MultimodalDataset)
87+
assert dataset.name == DATASET
88+
assert dataset.display_name == "test-display-name"

vertexai/_genai/datasets.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,15 +1123,19 @@ def update_multimodal_dataset(
11231123
def get_multimodal_dataset(
11241124
self,
11251125
*,
1126-
name: str,
1126+
name: Optional[str] = None,
1127+
dataset_id: Optional[str] = None,
11271128
config: Optional[types.VertexBaseConfigOrDict] = None,
11281129
) -> types.MultimodalDataset:
11291130
"""Gets a multimodal dataset.
11301131
11311132
Args:
11321133
name:
1133-
Required. name of a multimodal dataset. The name should be in
1134+
Optional. name of a multimodal dataset. The name should be in
11341135
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
1136+
dataset_id:
1137+
Optional. The ID of the dataset to get. If provided, the full name
1138+
will be constructed using the client's project and location.
11351139
config:
11361140
Optional. A configuration for getting the multimodal dataset. If not
11371141
provided, the default configuration will be used.
@@ -1145,6 +1149,14 @@ def get_multimodal_dataset(
11451149
elif not config:
11461150
config = types.VertexBaseConfig()
11471151

1152+
if not name and not dataset_id:
1153+
raise ValueError("Either 'name' or 'dataset_id' must be provided.")
1154+
if name and dataset_id:
1155+
raise ValueError("Only one of 'name' or 'dataset_id' can be provided.")
1156+
1157+
if dataset_id:
1158+
name = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/{dataset_id}"
1159+
11481160
return self._get_multimodal_dataset(config=config, name=name)
11491161

11501162
def delete_multimodal_dataset(
@@ -2345,15 +2357,19 @@ async def update_multimodal_dataset(
23452357
async def get_multimodal_dataset(
23462358
self,
23472359
*,
2348-
name: str,
2360+
name: Optional[str] = None,
2361+
dataset_id: Optional[str] = None,
23492362
config: Optional[types.VertexBaseConfigOrDict] = None,
23502363
) -> types.MultimodalDataset:
23512364
"""Gets a multimodal dataset.
23522365
23532366
Args:
23542367
name:
2355-
Required. name of a multimodal dataset. The name should be in
2368+
Optional. name of a multimodal dataset. The name should be in
23562369
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
2370+
dataset_id:
2371+
Optional. The ID of the dataset to get. If provided, the full name
2372+
will be constructed using the client's project and location.
23572373
config:
23582374
Optional. A configuration for getting the multimodal dataset. If not
23592375
provided, the default configuration will be used.
@@ -2367,6 +2383,14 @@ async def get_multimodal_dataset(
23672383
elif not config:
23682384
config = types.VertexBaseConfig()
23692385

2386+
if not name and not dataset_id:
2387+
raise ValueError("Either 'name' or 'dataset_id' must be provided.")
2388+
if name and dataset_id:
2389+
raise ValueError("Only one of 'name' or 'dataset_id' can be provided.")
2390+
2391+
if dataset_id:
2392+
name = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/{dataset_id}"
2393+
23702394
return await self._get_multimodal_dataset(config=config, name=name)
23712395

23722396
async def delete_multimodal_dataset(

0 commit comments

Comments
 (0)