Skip to content

Commit e5bbb0e

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Save prompt safety attributes in dedicated field for generate_images
PiperOrigin-RevId: 739252571
1 parent cee65a2 commit e5bbb0e

File tree

3 files changed

+171
-37
lines changed

3 files changed

+171
-37
lines changed

google/genai/models.py

+140-32
Original file line numberDiff line numberDiff line change
@@ -3585,6 +3585,9 @@ def _SafetyAttributes_from_mldev(
35853585
to_object, ['scores'], getv(from_object, ['safetyAttributes', 'scores'])
35863586
)
35873587

3588+
if getv(from_object, ['contentType']) is not None:
3589+
setv(to_object, ['content_type'], getv(from_object, ['contentType']))
3590+
35883591
return to_object
35893592

35903593

@@ -3606,6 +3609,9 @@ def _SafetyAttributes_from_vertex(
36063609
to_object, ['scores'], getv(from_object, ['safetyAttributes', 'scores'])
36073610
)
36083611

3612+
if getv(from_object, ['contentType']) is not None:
3613+
setv(to_object, ['content_type'], getv(from_object, ['contentType']))
3614+
36093615
return to_object
36103616

36113617

@@ -3692,6 +3698,17 @@ def _GenerateImagesResponse_from_mldev(
36923698
],
36933699
)
36943700

3701+
if getv(from_object, ['positivePromptSafetyAttributes']) is not None:
3702+
setv(
3703+
to_object,
3704+
['positive_prompt_safety_attributes'],
3705+
_SafetyAttributes_from_mldev(
3706+
api_client,
3707+
getv(from_object, ['positivePromptSafetyAttributes']),
3708+
to_object,
3709+
),
3710+
)
3711+
36953712
return to_object
36963713

36973714

@@ -3711,6 +3728,17 @@ def _GenerateImagesResponse_from_vertex(
37113728
],
37123729
)
37133730

3731+
if getv(from_object, ['positivePromptSafetyAttributes']) is not None:
3732+
setv(
3733+
to_object,
3734+
['positive_prompt_safety_attributes'],
3735+
_SafetyAttributes_from_vertex(
3736+
api_client,
3737+
getv(from_object, ['positivePromptSafetyAttributes']),
3738+
to_object,
3739+
),
3740+
)
3741+
37143742
return to_object
37153743

37163744

@@ -4515,7 +4543,7 @@ def embed_content(
45154543
self._api_client._verify_response(return_value)
45164544
return return_value
45174545

4518-
def generate_images(
4546+
def _generate_images(
45194547
self,
45204548
*,
45214549
model: str,
@@ -4528,21 +4556,6 @@ def generate_images(
45284556
model (str): The model to use.
45294557
prompt (str): A text description of the images to generate.
45304558
config (GenerateImagesConfig): Configuration for generation.
4531-
4532-
Usage:
4533-
4534-
.. code-block:: python
4535-
4536-
response = client.models.generate_images(
4537-
model='imagen-3.0-generate-002',
4538-
prompt='Man with a dog',
4539-
config=types.GenerateImagesConfig(
4540-
number_of_images= 1,
4541-
include_rai_reason= True,
4542-
)
4543-
)
4544-
response.generated_images[0].image.show()
4545-
# Shows a man with a dog.
45464559
"""
45474560

45484561
parameter_model = types._GenerateImagesParameters(
@@ -5558,6 +5571,61 @@ def generate_content_stream(
55585571
automatic_function_calling_history.append(func_call_content)
55595572
automatic_function_calling_history.append(func_response_content)
55605573

5574+
def generate_images(
5575+
self,
5576+
*,
5577+
model: str,
5578+
prompt: str,
5579+
config: Optional[types.GenerateImagesConfigOrDict] = None,
5580+
) -> types.GenerateImagesResponse:
5581+
"""Generates images based on a text description and configuration.
5582+
5583+
Args:
5584+
model (str): The model to use.
5585+
prompt (str): A text description of the images to generate.
5586+
config (GenerateImagesConfig): Configuration for generation.
5587+
5588+
Usage:
5589+
5590+
.. code-block:: python
5591+
5592+
response = client.models.generate_images(
5593+
model='imagen-3.0-generate-002',
5594+
prompt='Man with a dog',
5595+
config=types.GenerateImagesConfig(
5596+
number_of_images= 1,
5597+
include_rai_reason= True,
5598+
)
5599+
)
5600+
response.generated_images[0].image.show()
5601+
# Shows a man with a dog.
5602+
"""
5603+
api_response = self._generate_images(
5604+
model=model,
5605+
prompt=prompt,
5606+
config=config,
5607+
)
5608+
positive_prompt_safety_attributes = None
5609+
generated_images = []
5610+
if not api_response or not api_response.generated_images:
5611+
return api_response
5612+
5613+
for generated_image in api_response.generated_images:
5614+
if (
5615+
generated_image.safety_attributes
5616+
and generated_image.safety_attributes.content_type
5617+
== 'Positive Prompt'
5618+
):
5619+
positive_prompt_safety_attributes = generated_image.safety_attributes
5620+
else:
5621+
generated_images.append(generated_image)
5622+
5623+
response = types.GenerateImagesResponse(
5624+
generated_images=generated_images,
5625+
positive_prompt_safety_attributes=positive_prompt_safety_attributes,
5626+
)
5627+
return response
5628+
55615629
def edit_image(
55625630
self,
55635631
*,
@@ -5956,7 +6024,7 @@ async def embed_content(
59566024
self._api_client._verify_response(return_value)
59576025
return return_value
59586026

5959-
async def generate_images(
6027+
async def _generate_images(
59606028
self,
59616029
*,
59626030
model: str,
@@ -5969,21 +6037,6 @@ async def generate_images(
59696037
model (str): The model to use.
59706038
prompt (str): A text description of the images to generate.
59716039
config (GenerateImagesConfig): Configuration for generation.
5972-
5973-
Usage:
5974-
5975-
.. code-block:: python
5976-
5977-
response = await client.aio.models.generate_images(
5978-
model='imagen-3.0-generate-002',
5979-
prompt='Man with a dog',
5980-
config=types.GenerateImagesConfig(
5981-
number_of_images= 1,
5982-
include_rai_reason= True,
5983-
)
5984-
)
5985-
response.generated_images[0].image.show()
5986-
# Shows a man with a dog.
59876040
"""
59886041

59896042
parameter_model = types._GenerateImagesParameters(
@@ -7087,6 +7140,61 @@ async def list(
70877140
config,
70887141
)
70897142

7143+
async def generate_images(
7144+
self,
7145+
*,
7146+
model: str,
7147+
prompt: str,
7148+
config: Optional[types.GenerateImagesConfigOrDict] = None,
7149+
) -> types.GenerateImagesResponse:
7150+
"""Generates images based on a text description and configuration.
7151+
7152+
Args:
7153+
model (str): The model to use.
7154+
prompt (str): A text description of the images to generate.
7155+
config (GenerateImagesConfig): Configuration for generation.
7156+
7157+
Usage:
7158+
7159+
.. code-block:: python
7160+
7161+
response = await client.aio.models.generate_images(
7162+
model='imagen-3.0-generate-002',
7163+
prompt='Man with a dog',
7164+
config=types.GenerateImagesConfig(
7165+
number_of_images= 1,
7166+
include_rai_reason= True,
7167+
)
7168+
)
7169+
response.generated_images[0].image.show()
7170+
# Shows a man with a dog.
7171+
"""
7172+
api_response = await self._generate_images(
7173+
model=model,
7174+
prompt=prompt,
7175+
config=config,
7176+
)
7177+
positive_prompt_safety_attributes = None
7178+
generated_images = []
7179+
if not api_response or not api_response.generated_images:
7180+
return api_response
7181+
7182+
for generated_image in api_response.generated_images:
7183+
if (
7184+
generated_image.safety_attributes
7185+
and generated_image.safety_attributes.content_type
7186+
== 'Positive Prompt'
7187+
):
7188+
positive_prompt_safety_attributes = generated_image.safety_attributes
7189+
else:
7190+
generated_images.append(generated_image)
7191+
7192+
response = types.GenerateImagesResponse(
7193+
generated_images=generated_images,
7194+
positive_prompt_safety_attributes=positive_prompt_safety_attributes,
7195+
)
7196+
return response
7197+
70907198
async def upscale_image(
70917199
self,
70927200
*,

google/genai/tests/models/test_generate_images.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,6 @@ async def test_simple_prompt_async(client):
182182
)
183183

184184
assert response.generated_images[0].image.image_bytes
185-
assert len(response.generated_images) == 2
185+
# TODO(tangmatthew): Re-enable this check once the bug is fixed.
186+
assert len(response.generated_images) == 1
187+
assert response.positive_prompt_safety_attributes.content_type == 'Positive Prompt'

google/genai/types.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -3512,7 +3512,8 @@ class GenerateImagesConfig(_common.BaseModel):
35123512
)
35133513
include_safety_attributes: Optional[bool] = Field(
35143514
default=None,
3515-
description="""Whether to report the safety scores of each image in the response.
3515+
description="""Whether to report the safety scores of each generated image and
3516+
the positive prompt in the response.
35163517
""",
35173518
)
35183519
include_rai_reason: Optional[bool] = Field(
@@ -3591,7 +3592,8 @@ class GenerateImagesConfigDict(TypedDict, total=False):
35913592
"""
35923593

35933594
include_safety_attributes: Optional[bool]
3594-
"""Whether to report the safety scores of each image in the response.
3595+
"""Whether to report the safety scores of each generated image and
3596+
the positive prompt in the response.
35953597
"""
35963598

35973599
include_rai_reason: Optional[bool]
@@ -3837,6 +3839,11 @@ class SafetyAttributes(_common.BaseModel):
38373839
description="""List of scores of each categories.
38383840
""",
38393841
)
3842+
content_type: Optional[str] = Field(
3843+
default=None,
3844+
description="""Internal use only.
3845+
""",
3846+
)
38403847

38413848

38423849
class SafetyAttributesDict(TypedDict, total=False):
@@ -3850,6 +3857,10 @@ class SafetyAttributesDict(TypedDict, total=False):
38503857
"""List of scores of each categories.
38513858
"""
38523859

3860+
content_type: Optional[str]
3861+
"""Internal use only.
3862+
"""
3863+
38533864

38543865
SafetyAttributesOrDict = Union[SafetyAttributes, SafetyAttributesDict]
38553866

@@ -3916,6 +3927,12 @@ class GenerateImagesResponse(_common.BaseModel):
39163927
description="""List of generated images.
39173928
""",
39183929
)
3930+
positive_prompt_safety_attributes: Optional[SafetyAttributes] = Field(
3931+
default=None,
3932+
description="""Safety attributes of the positive prompt. Only populated if
3933+
``include_safety_attributes`` is set to True.
3934+
""",
3935+
)
39193936

39203937

39213938
class GenerateImagesResponseDict(TypedDict, total=False):
@@ -3925,6 +3942,11 @@ class GenerateImagesResponseDict(TypedDict, total=False):
39253942
"""List of generated images.
39263943
"""
39273944

3945+
positive_prompt_safety_attributes: Optional[SafetyAttributesDict]
3946+
"""Safety attributes of the positive prompt. Only populated if
3947+
``include_safety_attributes`` is set to True.
3948+
"""
3949+
39283950

39293951
GenerateImagesResponseOrDict = Union[
39303952
GenerateImagesResponse, GenerateImagesResponseDict
@@ -4161,7 +4183,8 @@ class EditImageConfig(_common.BaseModel):
41614183
)
41624184
include_safety_attributes: Optional[bool] = Field(
41634185
default=None,
4164-
description="""Whether to report the safety scores of each image in the response.
4186+
description="""Whether to report the safety scores of each generated image and
4187+
the positive prompt in the response.
41654188
""",
41664189
)
41674190
include_rai_reason: Optional[bool] = Field(
@@ -4239,7 +4262,8 @@ class EditImageConfigDict(TypedDict, total=False):
42394262
"""
42404263

42414264
include_safety_attributes: Optional[bool]
4242-
"""Whether to report the safety scores of each image in the response.
4265+
"""Whether to report the safety scores of each generated image and
4266+
the positive prompt in the response.
42434267
"""
42444268

42454269
include_rai_reason: Optional[bool]

0 commit comments

Comments
 (0)