|
56 | 56 | _TEST_CLAUDE_MODEL_RESOURCE_NAME = (
|
57 | 57 | f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}"
|
58 | 58 | )
|
| 59 | +_TEST_GPT_MODEL_NAME = "gpt-oss-120b-maas" |
| 60 | +_TEST_GPT_MODEL_RESOURCE_NAME = f"publishers/openai/models/{_TEST_GPT_MODEL_NAME}" |
| 61 | +_TEST_QWEN_MODEL_NAME = "qwen3-235b-a22b-instruct-2507-maas" |
| 62 | +_TEST_QWEN_MODEL_RESOURCE_NAME = f"publishers/qwen/models/{_TEST_QWEN_MODEL_NAME}" |
| 63 | +_TEST_DEEPSEEK_MODEL_NAME = "deepseek-r1-0528-maas" |
| 64 | +_TEST_DEEPSEEK_MODEL_RESOURCE_NAME = ( |
| 65 | + f"publishers/deepseek-ai/models/{_TEST_DEEPSEEK_MODEL_NAME}" |
| 66 | +) |
| 67 | +_TEST_E5_MODEL_NAME = "multilingual-e5-small-maas" |
| 68 | +_TEST_E5_MODEL_RESOURCE_NAME = f"publishers/intfloat/models/{_TEST_E5_MODEL_NAME}" |
59 | 69 | _TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME = (
|
60 | 70 | "publishers/google/models/gemma@gemma-2b-it"
|
61 | 71 | )
|
@@ -170,6 +180,74 @@ def get_batch_prediction_job_with_claude_model_mock():
|
170 | 180 | yield get_job_mock
|
171 | 181 |
|
172 | 182 |
|
| 183 | +@pytest.fixture |
| 184 | +def get_batch_prediction_job_with_gpt_model_mock(): |
| 185 | + with mock.patch.object( |
| 186 | + job_service_client.JobServiceClient, "get_batch_prediction_job" |
| 187 | + ) as get_job_mock: |
| 188 | + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 189 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 190 | + display_name=_TEST_DISPLAY_NAME, |
| 191 | + model=_TEST_GPT_MODEL_RESOURCE_NAME, |
| 192 | + state=_TEST_JOB_STATE_SUCCESS, |
| 193 | + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( |
| 194 | + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX |
| 195 | + ), |
| 196 | + ) |
| 197 | + yield get_job_mock |
| 198 | + |
| 199 | + |
| 200 | +@pytest.fixture |
| 201 | +def get_batch_prediction_job_with_qwen_model_mock(): |
| 202 | + with mock.patch.object( |
| 203 | + job_service_client.JobServiceClient, "get_batch_prediction_job" |
| 204 | + ) as get_job_mock: |
| 205 | + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 206 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 207 | + display_name=_TEST_DISPLAY_NAME, |
| 208 | + model=_TEST_QWEN_MODEL_RESOURCE_NAME, |
| 209 | + state=_TEST_JOB_STATE_SUCCESS, |
| 210 | + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( |
| 211 | + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX |
| 212 | + ), |
| 213 | + ) |
| 214 | + yield get_job_mock |
| 215 | + |
| 216 | + |
| 217 | +@pytest.fixture |
| 218 | +def get_batch_prediction_job_with_deepseek_model_mock(): |
| 219 | + with mock.patch.object( |
| 220 | + job_service_client.JobServiceClient, "get_batch_prediction_job" |
| 221 | + ) as get_job_mock: |
| 222 | + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 223 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 224 | + display_name=_TEST_DISPLAY_NAME, |
| 225 | + model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, |
| 226 | + state=_TEST_JOB_STATE_SUCCESS, |
| 227 | + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( |
| 228 | + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX |
| 229 | + ), |
| 230 | + ) |
| 231 | + yield get_job_mock |
| 232 | + |
| 233 | + |
| 234 | +@pytest.fixture |
| 235 | +def get_batch_prediction_job_with_e5_model_mock(): |
| 236 | + with mock.patch.object( |
| 237 | + job_service_client.JobServiceClient, "get_batch_prediction_job" |
| 238 | + ) as get_job_mock: |
| 239 | + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 240 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 241 | + display_name=_TEST_DISPLAY_NAME, |
| 242 | + model=_TEST_E5_MODEL_RESOURCE_NAME, |
| 243 | + state=_TEST_JOB_STATE_SUCCESS, |
| 244 | + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( |
| 245 | + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX |
| 246 | + ), |
| 247 | + ) |
| 248 | + yield get_job_mock |
| 249 | + |
| 250 | + |
173 | 251 | @pytest.fixture
|
174 | 252 | def get_batch_prediction_job_with_tuned_gemini_model_mock():
|
175 | 253 | with mock.patch.object(
|
@@ -315,6 +393,46 @@ def test_init_batch_prediction_job_with_claude_model(
|
315 | 393 | name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
|
316 | 394 | )
|
317 | 395 |
|
| 396 | + def test_init_batch_prediction_job_with_gpt_model( |
| 397 | + self, |
| 398 | + get_batch_prediction_job_with_gpt_model_mock, |
| 399 | + ): |
| 400 | + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
| 401 | + |
| 402 | + get_batch_prediction_job_with_gpt_model_mock.assert_called_once_with( |
| 403 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY |
| 404 | + ) |
| 405 | + |
| 406 | + def test_init_batch_prediction_job_with_qwen_model( |
| 407 | + self, |
| 408 | + get_batch_prediction_job_with_qwen_model_mock, |
| 409 | + ): |
| 410 | + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
| 411 | + |
| 412 | + get_batch_prediction_job_with_qwen_model_mock.assert_called_once_with( |
| 413 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY |
| 414 | + ) |
| 415 | + |
| 416 | + def test_init_batch_prediction_job_with_deepseek_model( |
| 417 | + self, |
| 418 | + get_batch_prediction_job_with_deepseek_model_mock, |
| 419 | + ): |
| 420 | + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
| 421 | + |
| 422 | + get_batch_prediction_job_with_deepseek_model_mock.assert_called_once_with( |
| 423 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY |
| 424 | + ) |
| 425 | + |
| 426 | + def test_init_batch_prediction_job_with_e5_model( |
| 427 | + self, |
| 428 | + get_batch_prediction_job_with_e5_model_mock, |
| 429 | + ): |
| 430 | + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
| 431 | + |
| 432 | + get_batch_prediction_job_with_e5_model_mock.assert_called_once_with( |
| 433 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY |
| 434 | + ) |
| 435 | + |
318 | 436 | def test_init_batch_prediction_job_with_tuned_gemini_model(
|
319 | 437 | self,
|
320 | 438 | get_batch_prediction_job_with_tuned_gemini_model_mock,
|
@@ -576,6 +694,138 @@ def test_submit_batch_prediction_job_with_claude_model(
|
576 | 694 | timeout=None,
|
577 | 695 | )
|
578 | 696 |
|
| 697 | + def test_submit_batch_prediction_job_with_gpt_model( |
| 698 | + self, |
| 699 | + create_batch_prediction_job_mock, |
| 700 | + ): |
| 701 | + job = batch_prediction.BatchPredictionJob.submit( |
| 702 | + source_model=_TEST_GPT_MODEL_RESOURCE_NAME, |
| 703 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 704 | + ) |
| 705 | + |
| 706 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 707 | + |
| 708 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 709 | + display_name=_TEST_DISPLAY_NAME, |
| 710 | + model=_TEST_GPT_MODEL_RESOURCE_NAME, |
| 711 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 712 | + instances_format="bigquery", |
| 713 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 714 | + input_uri=_TEST_BQ_INPUT_URI |
| 715 | + ), |
| 716 | + ), |
| 717 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 718 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 719 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 720 | + ), |
| 721 | + predictions_format="bigquery", |
| 722 | + ), |
| 723 | + ) |
| 724 | + create_batch_prediction_job_mock.assert_called_once_with( |
| 725 | + parent=_TEST_PARENT, |
| 726 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 727 | + timeout=None, |
| 728 | + ) |
| 729 | + |
| 730 | + def test_submit_batch_prediction_job_with_qwen_model( |
| 731 | + self, |
| 732 | + create_batch_prediction_job_mock, |
| 733 | + ): |
| 734 | + job = batch_prediction.BatchPredictionJob.submit( |
| 735 | + source_model=_TEST_QWEN_MODEL_RESOURCE_NAME, |
| 736 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 737 | + ) |
| 738 | + |
| 739 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 740 | + |
| 741 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 742 | + display_name=_TEST_DISPLAY_NAME, |
| 743 | + model=_TEST_QWEN_MODEL_RESOURCE_NAME, |
| 744 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 745 | + instances_format="bigquery", |
| 746 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 747 | + input_uri=_TEST_BQ_INPUT_URI |
| 748 | + ), |
| 749 | + ), |
| 750 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 751 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 752 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 753 | + ), |
| 754 | + predictions_format="bigquery", |
| 755 | + ), |
| 756 | + ) |
| 757 | + create_batch_prediction_job_mock.assert_called_once_with( |
| 758 | + parent=_TEST_PARENT, |
| 759 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 760 | + timeout=None, |
| 761 | + ) |
| 762 | + |
| 763 | + def test_submit_batch_prediction_job_with_deepseek_model( |
| 764 | + self, |
| 765 | + create_batch_prediction_job_mock, |
| 766 | + ): |
| 767 | + job = batch_prediction.BatchPredictionJob.submit( |
| 768 | + source_model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, |
| 769 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 770 | + ) |
| 771 | + |
| 772 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 773 | + |
| 774 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 775 | + display_name=_TEST_DISPLAY_NAME, |
| 776 | + model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, |
| 777 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 778 | + instances_format="bigquery", |
| 779 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 780 | + input_uri=_TEST_BQ_INPUT_URI |
| 781 | + ), |
| 782 | + ), |
| 783 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 784 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 785 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 786 | + ), |
| 787 | + predictions_format="bigquery", |
| 788 | + ), |
| 789 | + ) |
| 790 | + create_batch_prediction_job_mock.assert_called_once_with( |
| 791 | + parent=_TEST_PARENT, |
| 792 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 793 | + timeout=None, |
| 794 | + ) |
| 795 | + |
| 796 | + def test_submit_batch_prediction_job_with_e5_model( |
| 797 | + self, |
| 798 | + create_batch_prediction_job_mock, |
| 799 | + ): |
| 800 | + job = batch_prediction.BatchPredictionJob.submit( |
| 801 | + source_model=_TEST_E5_MODEL_RESOURCE_NAME, |
| 802 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 803 | + ) |
| 804 | + |
| 805 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 806 | + |
| 807 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 808 | + display_name=_TEST_DISPLAY_NAME, |
| 809 | + model=_TEST_E5_MODEL_RESOURCE_NAME, |
| 810 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 811 | + instances_format="bigquery", |
| 812 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 813 | + input_uri=_TEST_BQ_INPUT_URI |
| 814 | + ), |
| 815 | + ), |
| 816 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 817 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 818 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 819 | + ), |
| 820 | + predictions_format="bigquery", |
| 821 | + ), |
| 822 | + ) |
| 823 | + create_batch_prediction_job_mock.assert_called_once_with( |
| 824 | + parent=_TEST_PARENT, |
| 825 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 826 | + timeout=None, |
| 827 | + ) |
| 828 | + |
579 | 829 | @pytest.mark.usefixtures("create_batch_prediction_job_mock")
|
580 | 830 | def test_submit_batch_prediction_job_with_tuned_model(
|
581 | 831 | self,
|
|
0 commit comments