diff --git a/src/helm/clients/vertexai_client.py b/src/helm/clients/vertexai_client.py
index a7992722490..32a4764d8e7 100644
--- a/src/helm/clients/vertexai_client.py
+++ b/src/helm/clients/vertexai_client.py
@@ -4,7 +4,6 @@
 from typing import Any, Dict, Optional, List, Union
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hlog
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
@@ -131,12 +130,6 @@ def do_it() -> Dict[str, Any]:
 class VertexAIChatClient(VertexAIClient):
     """Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""
 
-    # Set the finish reason to this if the prompt violates the content policy
-    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."
-
-    # Gemini returns this error for certain valid requests
-    CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."
-
     # Enum taken from:
     # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
     # We don't directly import this enum because it can differ between different Vertex AI library versions.
@@ -149,7 +142,7 @@ class VertexAIChatClient(VertexAIClient):
     ]
 
     @staticmethod
-    def get_model(model_name: str) -> Any:
+    def get_model(model_name: str) -> GenerativeModel:
         global _models_lock
         global _models
 
@@ -202,21 +195,22 @@ def do_it() -> Dict[str, Any]:
             )
 
             candidates: List[Candidate] = response.candidates
-            # Depending on the version of the Vertex AI library and the type of content blocking,
-            # content blocking can show up in many ways, so this defensively handles most of these ways
+            # Depending on the version of the Vertex AI library and the type of prompt blocking,
+            # prompt blocking can show up in many ways, so this defensively handles most of these ways
+            if response.prompt_feedback.block_reason:
+                raise VertexAIContentBlockedError(
+                    f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                )
             if not candidates:
-                raise VertexAIContentBlockedError("No candidates in response due to content blocking")
+                raise VertexAIContentBlockedError(f"No candidates in response: {response}")
             predictions: List[Dict[str, Any]] = []
             for candidate in candidates:
-                if (
-                    candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS
-                    or not candidate.content.parts
-                ):
-                    # The prediction was either blocked due to safety settings or the model stopped and returned
-                    # nothing (which also happens when the model is blocked).
-                    # For now, we don't cache blocked requests, because we are trying to get the
-                    # content blocking removed.
-                    raise VertexAIContentBlockedError("Content has no parts due to content blocking")
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # content blocking can show up in many ways, so this defensively handles most of these ways
+                if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                    raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content.parts:
+                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
                 predictions.append({"text": candidate.content.text})
                 # TODO: Extract more information from the response
             return {"predictions": predictions}
@@ -234,11 +228,11 @@ def do_it() -> Dict[str, Any]:
             )
 
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except VertexAIContentBlockedError:
+        except VertexAIContentBlockedError as e:
             return RequestResult(
                 success=False,
                 cached=False,
-                error="Response was empty due to content moderation filter",
+                error=f"Content blocked: {str(e)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -252,7 +246,7 @@ def do_it() -> Dict[str, Any]:
             return RequestResult(
                 success=False,
                 cached=False,
-                error="Response was empty due to content moderation filter",
+                error=f"Content blocked error in cached response: {str(response)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -266,7 +260,7 @@ def do_it() -> Dict[str, Any]:
                 return RequestResult(
                     success=False,
                     cached=False,
-                    error="Response was empty due to content moderation filter",
+                    error=f"Content blocked error in cached prediction: {str(prediction)}",
                     completions=[],
                     embedding=[],
                     error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -291,21 +285,6 @@ def do_it() -> Dict[str, Any]:
         )
 
     def _make_multimodal_request(self, request: Request) -> RequestResult:
-        def complete_for_valid_error(error_message: str) -> RequestResult:
-            empty_completion = GeneratedOutput(
-                text="",
-                logprob=0,
-                tokens=[],
-                finish_reason={"reason": error_message},
-            )
-            return RequestResult(
-                success=True,
-                cached=False,
-                request_time=0,
-                completions=[empty_completion] * request.num_completions,
-                embedding=[],
-            )
-
         # Contents can either be text or a list of multimodal content made up of text, images or other content
         contents: Union[str, List[Union[str, Any]]] = request.prompt
         # Used to generate a unique cache key for this specific request
@@ -349,10 +328,22 @@ def do_it() -> Dict[str, Any]:
                 raw_response = model.generate_content(
                     contents, generation_config=parameters, safety_settings=self.safety_settings
                 )
-                if raw_response._raw_response.prompt_feedback.block_reason != 0:
-                    hlog(f"Content blocked for prompt: {request.multimodal_prompt}")
-                    return {"error": self.CONTENT_POLICY_VIOLATED_FINISH_REASON}
-
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # prompt blocking can show up in many ways, so this defensively handles most of these ways
+                if raw_response.prompt_feedback.block_reason:
+                    raise VertexAIContentBlockedError(
+                        f"Prompt blocked with reason: {raw_response.prompt_feedback.block_reason}"
+                    )
+                if not raw_response.candidates:
+                    raise VertexAIContentBlockedError(f"No candidates in response: {response}")
+                for candidate in raw_response.candidates:
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # content blocking can show up in many ways, so this defensively handles most of these ways
+                    if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                        raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                    if not candidate.content.parts:
+                        raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+                assert len(raw_response.candidates) == 1
                 return {"predictions": [{"text": raw_response.candidates[0].text}]}
 
             raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
@@ -361,15 +352,27 @@ def do_it() -> Dict[str, Any]:
 
             cache_key = CachingClient.make_cache_key(raw_cache_key, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except (requests.exceptions.RequestException, ValueError) as e:
-            if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
-                return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)
-
-            error: str = f"Gemini Vision error: {e}"
-            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+        except VertexAIContentBlockedError as e:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=f"Content blocked: {str(e)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
 
         if "error" in response:
-            return complete_for_valid_error(response["error"])
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=f"Content blocked error in cached response: {str(response)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                request_time=response["request_time"],
+                request_datetime=response["request_datetime"],
+            )
 
         response_text = response["predictions"][0]["text"]
         completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])
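
Every blocking path in this diff follows the same pattern: inspect `prompt_feedback.block_reason`, the candidate list, each candidate's `finish_reason`, and `content.parts`; raise `VertexAIContentBlockedError` on any block signal; and have the caller convert that exception into a failed `RequestResult` marked non-retriable and non-fatal via `ErrorFlags`. The sketch below mirrors that flow outside HELM. The `Fake*` dataclasses, the helper names, and the finish-reason values are hypothetical stand-ins for illustration, not the real `vertexai` or `helm` objects.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


class VertexAIContentBlockedError(Exception):
    """Stand-in for helm's VertexAIContentBlockedError (illustration only)."""


# Illustrative values only; the client's actual CONTENT_BLOCKED_FINISH_REASONS list lives in vertexai_client.py.
CONTENT_BLOCKED_FINISH_REASONS: List[int] = [3, 4]  # e.g. SAFETY, RECITATION


@dataclass
class FakeContent:
    parts: List[str] = field(default_factory=list)
    text: str = ""


@dataclass
class FakeCandidate:
    finish_reason: int = 1  # stands in for a normal STOP finish
    content: FakeContent = field(default_factory=FakeContent)


@dataclass
class FakePromptFeedback:
    block_reason: int = 0  # falsy means the prompt was not blocked


@dataclass
class FakeResponse:
    prompt_feedback: FakePromptFeedback = field(default_factory=FakePromptFeedback)
    candidates: List[FakeCandidate] = field(default_factory=list)


def extract_predictions(response: FakeResponse) -> Dict[str, Any]:
    """Defensively check every place a block can surface, mirroring the checks inside do_it()."""
    if response.prompt_feedback.block_reason:
        raise VertexAIContentBlockedError(f"Prompt blocked with reason: {response.prompt_feedback.block_reason}")
    if not response.candidates:
        raise VertexAIContentBlockedError(f"No candidates in response: {response}")
    predictions: List[Dict[str, Any]] = []
    for candidate in response.candidates:
        if candidate.finish_reason in CONTENT_BLOCKED_FINISH_REASONS:
            raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
        if not candidate.content.parts:
            raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
        predictions.append({"text": candidate.content.text})
    return {"predictions": predictions}


def make_request(response: FakeResponse) -> Dict[str, Any]:
    """Turn a block into a failed, non-retriable, non-fatal result, as the new except clauses do."""
    try:
        return {"success": True, **extract_predictions(response)}
    except VertexAIContentBlockedError as e:
        return {
            "success": False,
            "error": f"Content blocked: {e}",
            "error_flags": {"is_retriable": False, "is_fatal": False},
        }


if __name__ == "__main__":
    ok = FakeResponse(candidates=[FakeCandidate(content=FakeContent(parts=["hi"], text="hi"))])
    blocked = FakeResponse(prompt_feedback=FakePromptFeedback(block_reason=2))
    print(make_request(ok))       # success with one prediction
    print(make_request(blocked))  # failed, non-retriable, non-fatal result
```

Because the block is surfaced as a failed result with `is_retriable=False` and `is_fatal=False`, the run can record the error and continue instead of retrying or aborting, and blocked requests are never written to the cache.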