100 | 100 | }
101 | 101 |
102 | 102 |
| 103 | +def _expand_input_ids_with_image_tokens(
| 104 | +    text_input_ids,
| 105 | +    prompt_attention_mask,
| 106 | +    max_sequence_length,
| 107 | +    image_token_index,
| 108 | +    image_emb_len,
| 109 | +    image_emb_start,
| 110 | +    image_emb_end,
| 111 | +    pad_token_id,
| 112 | +):
| 113 | +    special_image_token_mask = text_input_ids == image_token_index
| 114 | +    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
| 115 | +    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
| 116 | +
| 117 | +    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
| 118 | +    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
| 119 | +    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
| 120 | +
| 121 | +    expanded_input_ids = torch.full(
| 122 | +        (text_input_ids.shape[0], max_expanded_length),
| 123 | +        pad_token_id,
| 124 | +        dtype=text_input_ids.dtype,
| 125 | +        device=text_input_ids.device,
| 126 | +    )
| 127 | +    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
| 128 | +    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
| 129 | +
| 130 | +    expanded_attention_mask = torch.zeros(
| 131 | +        (text_input_ids.shape[0], max_expanded_length),
| 132 | +        dtype=prompt_attention_mask.dtype,
| 133 | +        device=prompt_attention_mask.device,
| 134 | +    )
| 135 | +    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
| 136 | +    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
| 137 | +    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
| 138 | +    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
| 139 | +
| 140 | +    return {
| 141 | +        "input_ids": expanded_input_ids,
| 142 | +        "attention_mask": expanded_attention_mask,
| 143 | +        "position_ids": position_ids,
| 144 | +    }
| 145 | +
| 146 | +
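To make the expansion concrete, here is a minimal sketch of what the new helper returns for a toy batch, assuming the function above is in scope. The token ids, the placeholder id 32000, and the 4-slot image embedding below are made up for illustration; in the pipeline the real values come from `prompt_template` (576 slots at positions 5..581) and from `text_encoder.config`.

```python
import torch

# Toy values for illustration only; the pipeline reads the real ones from
# prompt_template and text_encoder.config.
pad_token_id = 0
image_token_index = 32000      # hypothetical <image> placeholder id
image_emb_len = 4              # each placeholder expands into 4 slots (576 in the real template)
image_emb_start, image_emb_end = 1, 5
max_sequence_length = 6

# One prompt containing a single <image> placeholder, right-padded to max_sequence_length.
text_input_ids = torch.tensor([[101, 32000, 7592, 2088, 0, 0]])
prompt_attention_mask = (text_input_ids != pad_token_id).long()

out = _expand_input_ids_with_image_tokens(
    text_input_ids,
    prompt_attention_mask,
    max_sequence_length,
    image_token_index,
    image_emb_len,
    image_emb_start,
    image_emb_end,
    pad_token_id,
)

# The sequence grows from 6 to 6 + 1 * (4 - 1) = 9 positions:
# out["input_ids"]      -> [[101, 32000, 32000, 32000, 32000, 7592, 2088, 0, 0]]
# out["attention_mask"] -> [[1, 1, 1, 1, 1, 1, 1, 0, 0]]
# out["position_ids"]   -> [[0, 1, 2, 3, 4, 5, 6, 1, 1]]
```

The text tokens keep their relative order but are shifted right so that positions `image_emb_start:image_emb_end` hold the image placeholder id, which a Llava-style text encoder then swaps for the projected image features.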
103 | 147 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104 | 148 | def retrieve_timesteps(
105 | 149 |     scheduler,
@@ -251,6 +295,12 @@ def _get_llama_prompt_embeds(
251 | 295 |         prompt = [prompt_template["template"].format(p) for p in prompt]
252 | 296 |
253 | 297 |         crop_start = prompt_template.get("crop_start", None)
| 298 | +
| 299 | +        image_emb_len = prompt_template.get("image_emb_len", 576)
| 300 | +        image_emb_start = prompt_template.get("image_emb_start", 5)
| 301 | +        image_emb_end = prompt_template.get("image_emb_end", 581)
| 302 | +        double_return_token_id = prompt_template.get("double_return_token_id", 271)
| 303 | +
254 | 304 |         if crop_start is None:
255 | 305 |             prompt_template_input = self.tokenizer(
256 | 306 |                 prompt_template["template"],
@@ -280,19 +330,25 @@ def _get_llama_prompt_embeds(
280 | 330 |
281 | 331 |         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
282 | 332 |
| 333 | +        image_token_index = self.text_encoder.config.image_token_index
| 334 | +        pad_token_id = self.text_encoder.config.pad_token_id
| 335 | +        expanded_inputs = _expand_input_ids_with_image_tokens(
| 336 | +            text_input_ids,
| 337 | +            prompt_attention_mask,
| 338 | +            max_sequence_length,
| 339 | +            image_token_index,
| 340 | +            image_emb_len,
| 341 | +            image_emb_start,
| 342 | +            image_emb_end,
| 343 | +            pad_token_id,
| 344 | +        )
283 | 345 |         prompt_embeds = self.text_encoder(
284 | | -            input_ids=text_input_ids,
285 | | -            attention_mask=prompt_attention_mask,
286 | | -            pixel_values=image_embeds,
| 346 | +            **expanded_inputs,
| 347 | +            pixel_values=image_embeds,
287 | 348 |             output_hidden_states=True,
288 | 349 |         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
289 | 350 |         prompt_embeds = prompt_embeds.to(dtype=dtype)
290 | 351 |
291 | | -        image_emb_len = prompt_template.get("image_emb_len", 576)
292 | | -        image_emb_start = prompt_template.get("image_emb_start", 5)
293 | | -        image_emb_end = prompt_template.get("image_emb_end", 581)
294 | | -        double_return_token_id = prompt_template.get("double_return_token_id", 271)
295 | | -
296 | 352 |         if crop_start is not None and crop_start > 0:
297 | 353 |             text_crop_start = crop_start - 1 + image_emb_len
298 | 354 |             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
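For reference, the template defaults that this diff moves above the text-encoder call are mutually consistent: image_emb_end - image_emb_start = 581 - 5 = 576 = image_emb_len, so the image-embedding slots occupy 576 positions starting right after the first 5 template tokens, and the crop offset above shifts by exactly that amount (text_crop_start = crop_start - 1 + image_emb_len).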