Improve and unify Gemini content blocking handling
yifanmai committed May 3, 2024
1 parent fdaaeec commit 0c8cae1
Showing 1 changed file with 53 additions and 50 deletions.
103 changes: 53 additions & 50 deletions src/helm/clients/vertexai_client.py
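
For orientation before the diff: the change unifies several scattered checks into one pattern, detect any form of content blocking, raise VertexAIContentBlockedError, and translate that exception into a failed, non-retriable RequestResult. The following is a minimal standalone sketch of that detection pattern, not code from the commit; the raise_if_blocked helper, the duck-typed response argument, and the enum values are illustrative assumptions.

# Illustrative sketch only -- not part of the commit. It mirrors the checks
# that the diff below adds inside do_it(), applied to a duck-typed `response`
# shaped like a Vertex AI generation response.
from typing import Any, List


class VertexAIContentBlockedError(Exception):
    """Raised when Gemini blocks the prompt or the generated content."""


# Hypothetical stand-in for VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS;
# the real values come from the Vertex AI Candidate.FinishReason enum.
CONTENT_BLOCKED_FINISH_REASONS: List[int] = [3, 4]  # SAFETY, RECITATION (assumed values)


def raise_if_blocked(response: Any) -> None:
    """Raise VertexAIContentBlockedError on any known form of content blocking."""
    # 1. The prompt itself was blocked.
    if response.prompt_feedback.block_reason:
        raise VertexAIContentBlockedError(
            f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
        )
    # 2. Blocking can also surface as an empty candidate list.
    if not response.candidates:
        raise VertexAIContentBlockedError(f"No candidates in response: {response}")
    for candidate in response.candidates:
        # 3. ...or as a safety/recitation finish reason on a candidate.
        if candidate.finish_reason in CONTENT_BLOCKED_FINISH_REASONS:
            raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
        # 4. ...or as a candidate whose content has no parts.
        if not candidate.content.parts:
            raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
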
@@ -4,7 +4,6 @@
from typing import Any, Dict, Optional, List, Union

from helm.common.cache import CacheConfig
from helm.common.hierarchical_logger import hlog
from helm.common.media_object import TEXT_TYPE
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
@@ -131,12 +130,6 @@ def do_it() -> Dict[str, Any]:
class VertexAIChatClient(VertexAIClient):
"""Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""

# Set the finish reason to this if the prompt violates the content policy
CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."

# Gemini returns this error for certain valid requests
CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."

# Enum taken from:
# https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
# We don't directly import this enum because it can differ between different Vertex AI library versions.
@@ -149,7 +142,7 @@ class VertexAIChatClient(VertexAIClient):
]

@staticmethod
def get_model(model_name: str) -> Any:
def get_model(model_name: str) -> GenerativeModel:
global _models_lock
global _models
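
The hunk above only shows the tightened return type of get_model; the _models and _models_lock globals it references point to a lock-guarded, per-model-name cache of GenerativeModel instances. Below is a rough sketch of that presumed shape, not code from the commit: the actual body is not visible in this diff, in the client the function sits behind a @staticmethod, and the import path is an assumption.

# Presumed shape of the module-level model cache used by get_model();
# reconstructed from the globals referenced above, not copied from the commit.
import threading
from typing import Dict

from vertexai.preview.generative_models import GenerativeModel  # import path assumed

_models_lock = threading.Lock()
_models: Dict[str, GenerativeModel] = {}


def get_model(model_name: str) -> GenerativeModel:
    global _models_lock
    global _models
    with _models_lock:
        # Create each GenerativeModel once and reuse it for later requests.
        if model_name not in _models:
            _models[model_name] = GenerativeModel(model_name)
        return _models[model_name]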

@@ -202,21 +195,22 @@ def do_it() -> Dict[str, Any]:
)
candidates: List[Candidate] = response.candidates

# Depending on the version of the Vertex AI library and the type of content blocking,
# content blocking can show up in many ways, so this defensively handles most of these ways
# Depending on the version of the Vertex AI library and the type of prompt blocking,
# prompt blocking can show up in many ways, so this defensively handles most of these ways
if response.prompt_feedback.block_reason:
raise VertexAIContentBlockedError(
f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
)
if not candidates:
raise VertexAIContentBlockedError("No candidates in response due to content blocking")
raise VertexAIContentBlockedError(f"No candidates in response: {response}")
predictions: List[Dict[str, Any]] = []
for candidate in candidates:
if (
candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS
or not candidate.content.parts
):
# The prediction was either blocked due to safety settings or the model stopped and returned
# nothing (which also happens when the model is blocked).
# For now, we don't cache blocked requests, because we are trying to get the
# content blocking removed.
raise VertexAIContentBlockedError("Content has no parts due to content blocking")
# Depending on the version of the Vertex AI library and the type of prompt blocking,
# content blocking can show up in many ways, so this defensively handles most of these ways
if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
if not candidate.content.parts:
raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
predictions.append({"text": candidate.content.text})
# TODO: Extract more information from the response
return {"predictions": predictions}
@@ -234,11 +228,11 @@ def do_it() -> Dict[str, Any]:
)

response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
except VertexAIContentBlockedError:
except VertexAIContentBlockedError as e:
return RequestResult(
success=False,
cached=False,
error="Response was empty due to content moderation filter",
error=f"Content blocked: {str(e)}",
completions=[],
embedding=[],
error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -252,7 +246,7 @@ def do_it() -> Dict[str, Any]:
return RequestResult(
success=False,
cached=False,
error="Response was empty due to content moderation filter",
error=f"Content blocked error in cached response: {str(response)}",
completions=[],
embedding=[],
error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -266,7 +260,7 @@ def do_it() -> Dict[str, Any]:
return RequestResult(
success=False,
cached=False,
error="Response was empty due to content moderation filter",
error=f"Content blocked error in cached prediction: {str(prediction)}",
completions=[],
embedding=[],
error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
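
These two cached-entry checks exist because blocked results used to be written into the cache as an error payload (see the removed multimodal code further down, which returned a dict with an "error" key) rather than being raised. Roughly, the shapes being distinguished look like the sketch below; the timing keys come from wrap_request_time and the literal values are made up for illustration.

# Rough illustration of the cache entry shapes the checks above distinguish;
# timing keys are added by wrap_request_time, values shown are invented.
legacy_blocked_entry = {
    "error": "The prompt violates Google's content policy.",
    "request_time": 0.42,
    "request_datetime": 1714694400,
}

current_entry = {
    "predictions": [{"text": "...model output..."}],
    "request_time": 0.42,
    "request_datetime": 1714694400,
}
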
@@ -291,21 +285,6 @@ def do_it() -> Dict[str, Any]:
)

def _make_multimodal_request(self, request: Request) -> RequestResult:
def complete_for_valid_error(error_message: str) -> RequestResult:
empty_completion = GeneratedOutput(
text="",
logprob=0,
tokens=[],
finish_reason={"reason": error_message},
)
return RequestResult(
success=True,
cached=False,
request_time=0,
completions=[empty_completion] * request.num_completions,
embedding=[],
)

# Contents can either be text or a list of multimodal content made up of text, images or other content
contents: Union[str, List[Union[str, Any]]] = request.prompt
# Used to generate a unique cache key for this specific request
@@ -349,10 +328,22 @@ def do_it() -> Dict[str, Any]:
raw_response = model.generate_content(
contents, generation_config=parameters, safety_settings=self.safety_settings
)
if raw_response._raw_response.prompt_feedback.block_reason != 0:
hlog(f"Content blocked for prompt: {request.multimodal_prompt}")
return {"error": self.CONTENT_POLICY_VIOLATED_FINISH_REASON}

# Depending on the version of the Vertex AI library and the type of prompt blocking,
# prompt blocking can show up in many ways, so this defensively handles most of these ways
if raw_response.prompt_feedback.block_reason:
raise VertexAIContentBlockedError(
f"Prompt blocked with reason: {raw_response.prompt_feedback.block_reason}"
)
if not raw_response.candidates:
raise VertexAIContentBlockedError(f"No candidates in response: {response}")
for candidate in raw_response.candidates:
# Depending on the version of the Vertex AI library and the type of prompt blocking,
# content blocking can show up in many ways, so this defensively handles most of these ways
if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
if not candidate.content.parts:
raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
assert len(raw_response.candidates) == 1
return {"predictions": [{"text": raw_response.candidates[0].text}]}

raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
@@ -361,15 +352,27 @@ def do_it() -> Dict[str, Any]:

cache_key = CachingClient.make_cache_key(raw_cache_key, request)
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
except (requests.exceptions.RequestException, ValueError) as e:
if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)

error: str = f"Gemini Vision error: {e}"
return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
except VertexAIContentBlockedError as e:
return RequestResult(
success=False,
cached=False,
error=f"Content blocked: {str(e)}",
completions=[],
embedding=[],
error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
)

if "error" in response:
return complete_for_valid_error(response["error"])
return RequestResult(
success=False,
cached=False,
error=f"Content blocked error in cached response: {str(response)}",
completions=[],
embedding=[],
error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
request_time=response["request_time"],
request_datetime=response["request_datetime"],
)

response_text = response["predictions"][0]["text"]
completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])
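
With this change a blocked request no longer surfaces as a successful result with empty completions; callers instead get a failed RequestResult whose error_flags mark it as neither retriable nor fatal. The sketch below shows how downstream code might branch on that; the field names are taken from the diff, while the describe helper and the surrounding harness are assumptions, not part of the commit.

# Illustrative caller-side handling; `result` is the RequestResult returned
# by the client (e.g. from a make_request call -- method name assumed).
from helm.common.request import RequestResult


def describe(result: RequestResult) -> str:
    if result.success:
        return result.completions[0].text
    if result.error_flags and not result.error_flags.is_retriable and not result.error_flags.is_fatal:
        # Content blocking: skip this instance rather than retrying or aborting the run.
        return f"[blocked] {result.error}"
    return f"[error] {result.error}"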