Skip to content

Commit 0977ae4

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI - Added GPT, Qwen, and DeepSeek models support in GenAI batch prediction
PiperOrigin-RevId: 803950323
1 parent 36013de commit 0977ae4

File tree

2 files changed

+274
-0
lines changed

2 files changed

+274
-0
lines changed

tests/unit/vertexai/test_batch_prediction.py

Lines changed: 250 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,16 @@
5656
_TEST_CLAUDE_MODEL_RESOURCE_NAME = (
5757
f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}"
5858
)
59+
_TEST_GPT_MODEL_NAME = "gpt-oss-120b-maas"
60+
_TEST_GPT_MODEL_RESOURCE_NAME = f"publishers/openai/models/{_TEST_GPT_MODEL_NAME}"
61+
_TEST_QWEN_MODEL_NAME = "qwen3-235b-a22b-instruct-2507-maas"
62+
_TEST_QWEN_MODEL_RESOURCE_NAME = f"publishers/qwen/models/{_TEST_QWEN_MODEL_NAME}"
63+
_TEST_DEEPSEEK_MODEL_NAME = "deepseek-r1-0528-maas"
64+
_TEST_DEEPSEEK_MODEL_RESOURCE_NAME = (
65+
f"publishers/deepseek-ai/models/{_TEST_DEEPSEEK_MODEL_NAME}"
66+
)
67+
_TEST_E5_MODEL_NAME = "multilingual-e5-small-maas"
68+
_TEST_E5_MODEL_RESOURCE_NAME = f"publishers/intfloat/models/{_TEST_E5_MODEL_NAME}"
5969
_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME = (
6070
"publishers/google/models/gemma@gemma-2b-it"
6171
)
@@ -170,6 +180,74 @@ def get_batch_prediction_job_with_claude_model_mock():
170180
yield get_job_mock
171181

172182

183+
def _patched_get_batch_prediction_job(model_resource_name):
    """Patch ``JobServiceClient.get_batch_prediction_job`` for one model.

    Shared generator behind the GPT, Qwen, DeepSeek and E5 fixtures, which
    previously repeated this body verbatim. While the patch is active the
    mocked method returns a ``BatchPredictionJob`` in state
    ``_TEST_JOB_STATE_SUCCESS`` whose ``model`` field is
    ``model_resource_name`` and whose output directory is
    ``_TEST_GCS_OUTPUT_PREFIX``.

    Args:
        model_resource_name: Value stored in the returned job's ``model``
            field.

    Yields:
        The mock object that replaces ``get_batch_prediction_job``.
    """
    with mock.patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    ) as get_job_mock:
        get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
            name=_TEST_BATCH_PREDICTION_JOB_NAME,
            display_name=_TEST_DISPLAY_NAME,
            model=model_resource_name,
            state=_TEST_JOB_STATE_SUCCESS,
            output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
                gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
            ),
        )
        # The patch is undone when the consuming fixture's generator resumes
        # after the test and this ``with`` block exits.
        yield get_job_mock


@pytest.fixture
def get_batch_prediction_job_with_gpt_model_mock():
    """Yield a ``get_batch_prediction_job`` mock returning a GPT model job."""
    yield from _patched_get_batch_prediction_job(_TEST_GPT_MODEL_RESOURCE_NAME)


@pytest.fixture
def get_batch_prediction_job_with_qwen_model_mock():
    """Yield a ``get_batch_prediction_job`` mock returning a Qwen model job."""
    yield from _patched_get_batch_prediction_job(_TEST_QWEN_MODEL_RESOURCE_NAME)


@pytest.fixture
def get_batch_prediction_job_with_deepseek_model_mock():
    """Yield a ``get_batch_prediction_job`` mock returning a DeepSeek model job."""
    yield from _patched_get_batch_prediction_job(_TEST_DEEPSEEK_MODEL_RESOURCE_NAME)


@pytest.fixture
def get_batch_prediction_job_with_e5_model_mock():
    """Yield a ``get_batch_prediction_job`` mock returning an E5 model job."""
    yield from _patched_get_batch_prediction_job(_TEST_E5_MODEL_RESOURCE_NAME)
249+
250+
173251
@pytest.fixture
174252
def get_batch_prediction_job_with_tuned_gemini_model_mock():
175253
with mock.patch.object(
@@ -315,6 +393,46 @@ def test_init_batch_prediction_job_with_claude_model(
315393
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
316394
)
317395

396+
def test_init_batch_prediction_job_with_gpt_model(
    self, get_batch_prediction_job_with_gpt_model_mock
):
    """Constructing a job from its ID looks the job up exactly once.

    The GPT fixture replaces ``get_batch_prediction_job``; the lookup must
    use the full job name and the default retry policy.
    """
    expected_lookup = {
        "name": _TEST_BATCH_PREDICTION_JOB_NAME,
        "retry": aiplatform_base._DEFAULT_RETRY,
    }

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    get_batch_prediction_job_with_gpt_model_mock.assert_called_once_with(
        **expected_lookup
    )
405+
406+
def test_init_batch_prediction_job_with_qwen_model(
    self, get_batch_prediction_job_with_qwen_model_mock
):
    """Constructing a job from its ID looks the job up exactly once.

    The Qwen fixture replaces ``get_batch_prediction_job``; the lookup must
    use the full job name and the default retry policy.
    """
    expected_lookup = {
        "name": _TEST_BATCH_PREDICTION_JOB_NAME,
        "retry": aiplatform_base._DEFAULT_RETRY,
    }

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    get_batch_prediction_job_with_qwen_model_mock.assert_called_once_with(
        **expected_lookup
    )
415+
416+
def test_init_batch_prediction_job_with_deepseek_model(
    self, get_batch_prediction_job_with_deepseek_model_mock
):
    """Constructing a job from its ID looks the job up exactly once.

    The DeepSeek fixture replaces ``get_batch_prediction_job``; the lookup
    must use the full job name and the default retry policy.
    """
    expected_lookup = {
        "name": _TEST_BATCH_PREDICTION_JOB_NAME,
        "retry": aiplatform_base._DEFAULT_RETRY,
    }

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    get_batch_prediction_job_with_deepseek_model_mock.assert_called_once_with(
        **expected_lookup
    )
425+
426+
def test_init_batch_prediction_job_with_e5_model(
    self, get_batch_prediction_job_with_e5_model_mock
):
    """Constructing a job from its ID looks the job up exactly once.

    The E5 fixture replaces ``get_batch_prediction_job``; the lookup must
    use the full job name and the default retry policy.
    """
    expected_lookup = {
        "name": _TEST_BATCH_PREDICTION_JOB_NAME,
        "retry": aiplatform_base._DEFAULT_RETRY,
    }

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    get_batch_prediction_job_with_e5_model_mock.assert_called_once_with(
        **expected_lookup
    )
435+
318436
def test_init_batch_prediction_job_with_tuned_gemini_model(
319437
self,
320438
get_batch_prediction_job_with_tuned_gemini_model_mock,
@@ -576,6 +694,138 @@ def test_submit_batch_prediction_job_with_claude_model(
576694
timeout=None,
577695
)
578696

697+
def _assert_submit_with_bigquery_input(
    self, create_job_mock, model_resource_name
):
    """Submit a job for ``model_resource_name`` and verify the create call.

    Shared body of the GPT, Qwen, DeepSeek and E5 submit tests, which were
    previously four copies differing only in the model constant.

    Args:
        create_job_mock: The ``create_batch_prediction_job_mock`` fixture
            value whose call arguments are checked.
        model_resource_name: Publisher model resource name passed as
            ``source_model`` and expected in the request's ``model`` field.
    """
    job = batch_prediction.BatchPredictionJob.submit(
        source_model=model_resource_name,
        input_dataset=_TEST_BQ_INPUT_URI,
    )

    assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(
            input_uri=_TEST_BQ_INPUT_URI
        ),
    )
    # NOTE(review): no output location is passed to ``submit`` above, so the
    # expected ``_TEST_BQ_OUTPUT_PREFIX`` is presumably the default that
    # ``submit`` derives from the input -- confirm against the implementation.
    expected_output_config = job_proto.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX
        ),
        predictions_format="bigquery",
    )
    expected_gapic_batch_prediction_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=model_resource_name,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )

    create_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_gapic_batch_prediction_job,
        timeout=None,
    )


def test_submit_batch_prediction_job_with_gpt_model(
    self,
    create_batch_prediction_job_mock,
):
    """Submitting with a GPT publisher model sends the expected create request."""
    self._assert_submit_with_bigquery_input(
        create_batch_prediction_job_mock, _TEST_GPT_MODEL_RESOURCE_NAME
    )


def test_submit_batch_prediction_job_with_qwen_model(
    self,
    create_batch_prediction_job_mock,
):
    """Submitting with a Qwen publisher model sends the expected create request."""
    self._assert_submit_with_bigquery_input(
        create_batch_prediction_job_mock, _TEST_QWEN_MODEL_RESOURCE_NAME
    )


def test_submit_batch_prediction_job_with_deepseek_model(
    self,
    create_batch_prediction_job_mock,
):
    """Submitting with a DeepSeek publisher model sends the expected create request."""
    self._assert_submit_with_bigquery_input(
        create_batch_prediction_job_mock, _TEST_DEEPSEEK_MODEL_RESOURCE_NAME
    )


def test_submit_batch_prediction_job_with_e5_model(
    self,
    create_batch_prediction_job_mock,
):
    """Submitting with an E5 publisher model sends the expected create request."""
    self._assert_submit_with_bigquery_input(
        create_batch_prediction_job_mock, _TEST_E5_MODEL_RESOURCE_NAME
    )
828+
579829
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
580830
def test_submit_batch_prediction_job_with_tuned_model(
581831
self,

vertexai/batch_prediction/_batch_prediction.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,10 @@
3535
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
3636
_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama"
3737
_CLAUDE_MODEL_PATTERN = r"publishers/anthropic/models/claude"
38+
_GPT_MODEL_PATTERN = r"publishers/openai/models/gpt"
39+
_QWEN_MODEL_PATTERN = r"publishers/qwen/models/qwen"
40+
_DEEPSEEK_MODEL_PATTERN = r"publishers/deepseek-ai/models/deepseek"
41+
_E5_MODEL_PATTERN = r"publishers/intfloat/models/multilingual"
3842
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"
3943

4044

@@ -318,6 +322,10 @@ def _reconcile_model_name(cls, model_name: str) -> str:
318322
or model_name.startswith("publishers/google/models/")
319323
or model_name.startswith("publishers/meta/models/")
320324
or model_name.startswith("publishers/anthropic/models/")
325+
or model_name.startswith("publishers/openai/models/")
326+
or model_name.startswith("publishers/qwen/models/")
327+
or model_name.startswith("publishers/deepseek-ai/models/")
328+
or model_name.startswith("publishers/intfloat/models/")
321329
or re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
322330
):
323331
return model_name
@@ -348,6 +356,22 @@ def _is_genai_model(cls, model_name: str) -> bool:
348356
# Model is a claude model.
349357
return True
350358

359+
if re.search(_GPT_MODEL_PATTERN, model_name):
360+
# Model is a GPT model.
361+
return True
362+
363+
if re.search(_QWEN_MODEL_PATTERN, model_name):
364+
# Model is a Qwen model.
365+
return True
366+
367+
if re.search(_DEEPSEEK_MODEL_PATTERN, model_name):
368+
# Model is a DeepSeek model.
369+
return True
370+
371+
if re.search(_E5_MODEL_PATTERN, model_name):
372+
# Model is an E5 model.
373+
return True
374+
351375
if re.match(
352376
r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$",
353377
model_name,

0 commit comments

Comments
 (0)