Skip to content

Commit d789295

Browse files
cleop-googlecopybara-github
authored andcommitted
feat: GenAI SDK client(multimodal) - Allow passing dataset ID in addition to full resource name in dataset methods.
PiperOrigin-RevId: 899573600
1 parent f5dc932 commit d789295

File tree

3 files changed

+88
-5
lines changed

3 files changed

+88
-5
lines changed

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def test_get_dataset_from_public_method(client):
4141
assert dataset.display_name == "test-display-name"
4242

4343

44+
def test_get_dataset_by_id(client):
45+
dataset = client.datasets.get_multimodal_dataset(
46+
name="8810841321427173376",
47+
)
48+
assert isinstance(dataset, types.MultimodalDataset)
49+
assert dataset.name == DATASET
50+
assert dataset.display_name == "test-display-name"
51+
52+
4453
pytestmark = pytest_helper.setup(
4554
file=__file__,
4655
globals_for_file=globals(),
@@ -67,3 +76,13 @@ async def test_get_dataset_from_public_method_async(client):
6776
assert isinstance(dataset, types.MultimodalDataset)
6877
assert dataset.name == DATASET
6978
assert dataset.display_name == "test-display-name"
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_get_dataset_by_id_async(client):
83+
dataset = await client.aio.datasets.get_multimodal_dataset(
84+
name="8810841321427173376",
85+
)
86+
assert isinstance(dataset, types.MultimodalDataset)
87+
assert dataset.name == DATASET
88+
assert dataset.display_name == "test-display-name"

vertexai/_genai/_datasets_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,11 @@ async def save_dataframe_to_bigquery_async(
262262
)
263263
await asyncio.to_thread(copy_job.result)
264264
await asyncio.to_thread(bq_client.delete_table, temp_table_id)
265+
266+
267+
def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str:
268+
"""Resolves a dataset name or ID to a full resource name."""
269+
resource_prefix = f"projects/{project}/locations/{location}/datasets/"
270+
if not resource_name_or_id.startswith(resource_prefix):
271+
return resource_prefix + resource_name_or_id
272+
return resource_name_or_id

vertexai/_genai/datasets.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,8 +1130,8 @@ def get_multimodal_dataset(
11301130
11311131
Args:
11321132
name:
1133-
Required. name of a multimodal dataset. The name should be in
1134-
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
1133+
Required. A fully-qualified resource name or ID of the dataset.
1134+
Example: "projects/.../locations/.../datasets/123" or "123".
11351135
config:
11361136
Optional. A configuration for getting the multimodal dataset. If not
11371137
provided, the default configuration will be used.
@@ -1145,6 +1145,10 @@ def get_multimodal_dataset(
11451145
elif not config:
11461146
config = types.VertexBaseConfig()
11471147

1148+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1149+
if not name.startswith(resource_prefix):
1150+
name = resource_prefix + name
1151+
11481152
return self._get_multimodal_dataset(config=config, name=name)
11491153

11501154
def delete_multimodal_dataset(
@@ -1172,6 +1176,10 @@ def delete_multimodal_dataset(
11721176
elif not config:
11731177
config = types.VertexBaseConfig()
11741178

1179+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1180+
if not name.startswith(resource_prefix):
1181+
name = resource_prefix + name
1182+
11751183
return self._delete_multimodal_dataset(config=config, name=name)
11761184

11771185
def assemble(
@@ -1207,6 +1215,10 @@ def assemble(
12071215
elif not config:
12081216
config = types.AssembleDatasetConfig()
12091217

1218+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1219+
if not name.startswith(resource_prefix):
1220+
name = resource_prefix + name
1221+
12101222
operation = self._assemble_multimodal_dataset(
12111223
name=name,
12121224
gemini_request_read_config=gemini_request_read_config,
@@ -1255,6 +1267,10 @@ def assess_tuning_resources(
12551267
elif not config:
12561268
config = types.AssessDatasetConfig()
12571269

1270+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1271+
if not dataset_name.startswith(resource_prefix):
1272+
dataset_name = resource_prefix + dataset_name
1273+
12581274
operation = self._assess_multimodal_dataset(
12591275
name=dataset_name,
12601276
tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig(
@@ -1316,6 +1332,10 @@ def assess_tuning_validity(
13161332
elif not config:
13171333
config = types.AssessDatasetConfig()
13181334

1335+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1336+
if not dataset_name.startswith(resource_prefix):
1337+
dataset_name = resource_prefix + dataset_name
1338+
13191339
operation = self._assess_multimodal_dataset(
13201340
name=dataset_name,
13211341
tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
@@ -1376,6 +1396,10 @@ def assess_batch_prediction_resources(
13761396
elif not config:
13771397
config = types.AssessDatasetConfig()
13781398

1399+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1400+
if not dataset_name.startswith(resource_prefix):
1401+
dataset_name = resource_prefix + dataset_name
1402+
13791403
operation = self._assess_multimodal_dataset(
13801404
name=dataset_name,
13811405
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
@@ -1435,6 +1459,10 @@ def assess_batch_prediction_validity(
14351459
elif not config:
14361460
config = types.AssessDatasetConfig()
14371461

1462+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
1463+
if not dataset_name.startswith(resource_prefix):
1464+
dataset_name = resource_prefix + dataset_name
1465+
14381466
operation = self._assess_multimodal_dataset(
14391467
name=dataset_name,
14401468
batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(
@@ -2352,21 +2380,25 @@ async def get_multimodal_dataset(
23522380
23532381
Args:
23542382
name:
2355-
Required. name of a multimodal dataset. The name should be in
2356-
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
2383+
Required. A fully-qualified resource name or ID of the dataset.
2384+
Example: "projects/.../locations/.../datasets/123" or "123".
23572385
config:
23582386
Optional. A configuration for getting the multimodal dataset. If not
23592387
provided, the default configuration will be used.
23602388
23612389
Returns:
2362-
A types.MultimodalDataset object representing the updated multimodal
2390+
A types.MultimodalDataset object representing the retrieved multimodal
23632391
dataset.
23642392
"""
23652393
if isinstance(config, dict):
23662394
config = types.VertexBaseConfig(**config)
23672395
elif not config:
23682396
config = types.VertexBaseConfig()
23692397

2398+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2399+
if not name.startswith(resource_prefix):
2400+
name = resource_prefix + name
2401+
23702402
return await self._get_multimodal_dataset(config=config, name=name)
23712403

23722404
async def delete_multimodal_dataset(
@@ -2394,6 +2426,10 @@ async def delete_multimodal_dataset(
23942426
elif not config:
23952427
config = types.VertexBaseConfig()
23962428

2429+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2430+
if not name.startswith(resource_prefix):
2431+
name = resource_prefix + name
2432+
23972433
return await self._delete_multimodal_dataset(config=config, name=name)
23982434

23992435
async def assemble(
@@ -2429,6 +2465,10 @@ async def assemble(
24292465
elif not config:
24302466
config = types.AssembleDatasetConfig()
24312467

2468+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2469+
if not name.startswith(resource_prefix):
2470+
name = resource_prefix + name
2471+
24322472
operation = await self._assemble_multimodal_dataset(
24332473
name=name,
24342474
gemini_request_read_config=gemini_request_read_config,
@@ -2477,6 +2517,10 @@ async def assess_tuning_resources(
24772517
elif not config:
24782518
config = types.AssessDatasetConfig()
24792519

2520+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2521+
if not dataset_name.startswith(resource_prefix):
2522+
dataset_name = resource_prefix + dataset_name
2523+
24802524
operation = await self._assess_multimodal_dataset(
24812525
name=dataset_name,
24822526
tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig(
@@ -2538,6 +2582,10 @@ async def assess_tuning_validity(
25382582
elif not config:
25392583
config = types.AssessDatasetConfig()
25402584

2585+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2586+
if not dataset_name.startswith(resource_prefix):
2587+
dataset_name = resource_prefix + dataset_name
2588+
25412589
operation = await self._assess_multimodal_dataset(
25422590
name=dataset_name,
25432591
tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
@@ -2598,6 +2646,10 @@ async def assess_batch_prediction_resources(
25982646
elif not config:
25992647
config = types.AssessDatasetConfig()
26002648

2649+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2650+
if not dataset_name.startswith(resource_prefix):
2651+
dataset_name = resource_prefix + dataset_name
2652+
26012653
operation = await self._assess_multimodal_dataset(
26022654
name=dataset_name,
26032655
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
@@ -2657,6 +2709,10 @@ async def assess_batch_prediction_validity(
26572709
elif not config:
26582710
config = types.AssessDatasetConfig()
26592711

2712+
resource_prefix = f"projects/{self._api_client.project}/locations/{self._api_client.location}/datasets/"
2713+
if not dataset_name.startswith(resource_prefix):
2714+
dataset_name = resource_prefix + dataset_name
2715+
26602716
operation = await self._assess_multimodal_dataset(
26612717
name=dataset_name,
26622718
batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(

0 commit comments

Comments
 (0)