Improve llava support & add llava_qwen2 #1324

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions src/configs.js
@@ -120,6 +120,7 @@ function getNormalizedConfig(config) {
case 'phi':
case 'phi3':
case 'phi3_v':
case 'llava_qwen2':
mapping['num_heads'] = 'num_key_value_heads';
mapping['num_layers'] = 'num_hidden_layers';
mapping['hidden_size'] = 'hidden_size';
131 changes: 44 additions & 87 deletions src/models.js
@@ -887,8 +887,26 @@ function createPositionIds(model_inputs, past_key_values = null, start_index = 0) {
}

function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
const past_length = model_inputs.past_key_values
? Object.values(model_inputs.past_key_values)[0].dims.at(-2)
: 0;

if (!model_inputs.attention_mask) {
// If the attention mask is not provided, we attempt to infer it from the provided inputs
let dims;
for (const key of ['input_ids', 'inputs_embeds', 'position_ids']) {
if (model_inputs[key]) {
dims = model_inputs[key].dims;
break;
}
}
if (!dims) {
throw new Error("attention_mask is not provided, and unable to infer its shape from model inputs.");
}
model_inputs.attention_mask = ones([dims[0], past_length + dims[1]]);
}

if (model_inputs.past_key_values) {
const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
const { input_ids, attention_mask } = model_inputs;

// Keep only the unprocessed tokens:
@@ -909,24 +927,7 @@ function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
}
// 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
else {
if (
// NOTE: Only used by VLMs (!= so that null matches undefined)
self.config.image_token_index != null &&
// Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
input_ids.data.some(x => x == self.config.image_token_index)
) {
// TODO: Support multiple image tokens
const num_image_tokens = self.config.num_image_tokens;
if (!num_image_tokens) {
throw new Error('`num_image_tokens` is missing in the model configuration.');
}

const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);

// TODO: The attention mask should be formed from the attention mask passed in model_inputs
model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
}
}
}
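
For context, the fallback added above builds an attention mask of ones whose width is the KV-cache length plus the length of whichever input tensor is present. A standalone illustration of the shape arithmetic (plain arrays instead of the library's tensors, so nothing below is transformers.js API):

// past_length = tokens already in the KV cache; dims = [batch_size, new_tokens]
function inferredMaskShape(dims, past_length) {
  const [batch_size, new_tokens] = dims;
  return [batch_size, past_length + new_tokens];
}

// e.g. a cache of 5 tokens plus 1 freshly sampled token needs a [1, 6] mask of ones
console.log(inferredMaskShape([1, 1], 5)); // [ 1, 6 ]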

@@ -2016,17 +2017,7 @@ export class PreTrainedModel extends Callable {

async encode_image({ pixel_values }) {
// image_inputs === { pixel_values }
const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
// @ts-expect-error TS2339
if (!this.config.num_image_tokens) {
console.warn(
'The number of image tokens was not set in the model configuration. ' +
`Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`
)
// @ts-expect-error TS2339
this.config.num_image_tokens = features.dims[1];
}
return features;
return (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
}

async encode_text({ input_ids }) {
@@ -3640,65 +3631,16 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
* The LLAVA model which consists of a vision backbone and a language model.
*/
export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
_merge_input_ids_with_image_features(kwargs) {
const vision_hidden_size = kwargs.image_features.dims.at(-1);
const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);

_merge_input_ids_with_image_features({
inputs_embeds,
image_features,
input_ids,
attention_mask,
}) {

// @ts-expect-error TS2339
const image_token_index = this.config.image_token_index;

const idsList = input_ids.tolist();

// NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number
const indexOfImage = idsList.map(x => x.findIndex(x => x == image_token_index));

const noImages = indexOfImage.every(x => x === -1);
const allImages = indexOfImage.every(x => x !== -1);
if (!noImages && !allImages) {
// Check for padding reasons
throw new Error('Every input should contain either 0 or 1 image token.');
}

if (noImages) {
return {
inputs_embeds,
attention_mask,
}
}

const stacked = [];
const stacked_attention_mask = [];
for (let i = 0; i < indexOfImage.length; ++i) {
const index = indexOfImage[i];

const e = inputs_embeds[i];
const im = image_features[i];
const am = attention_mask[i];
stacked.push(
cat([
e.slice([0, index]),
im,
e.slice([index + 1, e.dims[0]]),
], 0)
);

stacked_attention_mask.push(
cat([
am.slice([0, index]),
ones([im.dims[0]]),
am.slice([index + 1, am.dims[0]])
], 0)
)
}

return {
inputs_embeds: stack(stacked, 0),
attention_mask: stack(stacked_attention_mask, 0),
}
return default_merge_input_ids_with_image_features({
// @ts-ignore
image_token_id: this.config.image_token_index,
...kwargs,
image_features: reshaped_image_hidden_states,
})
}
}
//////////////////////////////////////////////////
@@ -3839,6 +3781,20 @@ export class PaliGemmaForConditionalGeneration extends PaliGemmaPreTrainedModel
}
}

export class LlavaQwen2ForCausalLM extends LlavaPreTrainedModel {
_merge_input_ids_with_image_features(kwargs) {
const vision_hidden_size = kwargs.image_features.dims.at(-1);
const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);

return default_merge_input_ids_with_image_features({
// @ts-ignore
image_token_id: this.config.image_token_index,
...kwargs,
image_features: reshaped_image_hidden_states,
})
}
}
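
A rough end-to-end sketch of how the new LlavaQwen2ForCausalLM class could be exercised, following the pattern of the existing model tests; the model id, image URL, and prompt format below are placeholders rather than anything this PR ships, and the class is assumed to be re-exported from the package entry point:

import { AutoProcessor, LlavaQwen2ForCausalLM, RawImage } from '@huggingface/transformers';

// Hypothetical ONNX checkpoint with model_type "llava_qwen2"
const model_id = 'onnx-community/example-llava-qwen2';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await LlavaQwen2ForCausalLM.from_pretrained(model_id);

const image = await RawImage.fromURL('https://example.com/cat.png');
const inputs = await processor(image, '<image>\nDescribe this image.');

// Generate and decode, reusing the processor's tokenizer
const output = await model.generate({ ...inputs, max_new_tokens: 32 });
console.log(processor.tokenizer.batch_decode(output, { skip_special_tokens: true }));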

//////////////////////////////////////////////////
// Idefics3 Models
export class Idefics3PreTrainedModel extends PreTrainedModel {
@@ -7842,6 +7798,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
['llava_qwen2', ['LlavaQwen2ForCausalLM', LlavaQwen2ForCausalLM]],
]);

const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
2 changes: 1 addition & 1 deletion src/models/florence2/processing_florence2.js
@@ -121,7 +121,7 @@ export class Florence2Processor extends Processor {
}

const image_inputs = await this.image_processor(images, kwargs);
const text_inputs = text ? this.tokenizer(text, kwargs) : {};
const text_inputs = text ? this.tokenizer(this.construct_prompts(text), kwargs) : {};

return {
...image_inputs,
44 changes: 44 additions & 0 deletions src/models/llava/processing_llava.js
@@ -0,0 +1,44 @@

import { Processor } from "../../base/processing_utils.js";
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
import { AutoTokenizer } from "../../tokenizers.js";

export class LlavaProcessor extends Processor {
static tokenizer_class = AutoTokenizer
static image_processor_class = AutoImageProcessor
static uses_processor_config = true;

/**
* @typedef {import('../../utils/image.js').RawImage} RawImage
*/

// `images` is required, `text` is optional
async _call(/** @type {RawImage|RawImage[]} */ images, text = null, kwargs = {}) {

const image_inputs = await this.image_processor(images, kwargs);

if (text) {
const [height, width] = image_inputs.pixel_values.dims.slice(-2);

const {image_token, patch_size, num_additional_image_tokens} = this.config;
const num_image_tokens = Math.floor(
height / patch_size
) * Math.floor(width / patch_size) + num_additional_image_tokens;

text = structuredClone(text); // Avoid modifying the original text input
if (!Array.isArray(text)) {
text = [text];
}
for (let i = 0; i < text.length; ++i) {
text[i] = text[i].replace(image_token, image_token.repeat(num_image_tokens));
}
}

const text_inputs = text ? this.tokenizer(text, kwargs) : {};

return {
...image_inputs,
...text_inputs,
}
}
}
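
To make the token-count arithmetic in _call concrete, here is the same computation with typical LLaVA/CLIP-style values (a 336x336 preprocessed image, patch_size 14, one additional image token); these numbers are assumptions for illustration, not values read from this PR:

const height = 336, width = 336;       // pixel_values spatial dims after preprocessing (assumed)
const patch_size = 14;                 // from the processor config (assumed)
const num_additional_image_tokens = 1; // e.g. the CLS token (assumed)

const num_image_tokens =
  Math.floor(height / patch_size) * Math.floor(width / patch_size) + num_additional_image_tokens; // 24 * 24 + 1 = 577

// The single "<image>" placeholder in the prompt is expanded to 577 consecutive copies before tokenization:
const prompt = '<image>\nWhat is shown here?'.replace('<image>', '<image>'.repeat(num_image_tokens));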
1 change: 1 addition & 0 deletions src/models/processors.js
@@ -3,6 +3,7 @@ export * from './grounding_dino/processing_grounding_dino.js';
export * from './idefics3/processing_idefics3.js';
export * from './janus/processing_janus.js';
export * from './jina_clip/processing_jina_clip.js';
export * from './llava/processing_llava.js';
export * from './mgp_str/processing_mgp_str.js';
export * from './moonshine/processing_moonshine.js';
export * from './owlvit/processing_owlvit.js';
6 changes: 3 additions & 3 deletions tests/models/florence2/test_modeling_florence2.js
@@ -46,7 +46,7 @@ export default () => {
{
const inputs = await processor(image, texts[0]);
const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
expect(generate_ids.tolist()).toEqual([[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n]]);
expect(generate_ids.tolist()).toEqual([[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n]]);
}
},
MAX_TEST_EXECUTION_TIME,
@@ -68,8 +68,8 @@

const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
expect(generate_ids.tolist()).toEqual([
[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
]);
}
},
37 changes: 36 additions & 1 deletion tests/models/florence2/test_processor_florence2.js
@@ -2,7 +2,7 @@ import { AutoProcessor, Florence2Processor } from "../../../src/transformers.js"
import { MAX_TEST_EXECUTION_TIME, MAX_PROCESSOR_LOAD_TIME } from "../../init.js";
import { load_cached_image } from "../../asset_cache.js";
export default () => {
describe("FlorenceProcessor", () => {
describe("Florence2Processor", () => {
const model_id = "Xenova/tiny-random-Florence2ForConditionalGeneration";

/** @type {Florence2Processor} */
@@ -14,9 +14,44 @@
images = {
beetle: await load_cached_image("beetle"),
book_cover: await load_cached_image("book_cover"),
white_image: await load_cached_image("white_image"),
};
}, MAX_PROCESSOR_LOAD_TIME);

describe("Processing", () => {
it(
"Process image and text (no task)",
async () => {
const inputs = await processor(images.white_image, "describe");
expect(inputs.input_ids.dims).toEqual([1, 4]);
expect(inputs.input_ids.tolist()).toEqual([[0n, 45091n, 21700n, 2n]]);

expect(inputs.attention_mask.dims).toEqual([1, 4]);
expect(inputs.attention_mask.tolist()).toEqual([[1n, 1n, 1n, 1n]]);

expect(inputs.pixel_values.dims).toEqual([1, 3, 768, 768]);
expect(inputs.pixel_values.mean().item()).toBeCloseTo(2.439159870147705, 1);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"Process image and text (with task)",
async () => {
const inputs = await processor(images.white_image, "<OPEN_VOCABULARY_DETECTION>cat");
expect(inputs.input_ids.dims).toEqual([1, 9]);
expect(inputs.input_ids.tolist()).toEqual([[0n, 574n, 22486n, 4758n, 11n, 5n, 2274n, 4n, 2n]]);

expect(inputs.attention_mask.dims).toEqual([1, 9]);
expect(inputs.attention_mask.tolist()).toEqual([[1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n]]);

expect(inputs.pixel_values.dims).toEqual([1, 3, 768, 768]);
expect(inputs.pixel_values.mean().item()).toBeCloseTo(2.439159870147705, 1);
},
MAX_TEST_EXECUTION_TIME,
);
});

describe("Prompt construction", () => {
it(
"Construct prompt",
@@ -32,7 +32,7 @@ export default () => {
expect(pred_boxes.dims).toEqual([1, num_queries, 4]);
expect(logits.max().item()).toBeCloseTo(56.237613677978516, 2);
expect(logits.min().item()).toEqual(-Infinity);
expect(pred_boxes.mean().item()).toEqual(0.2500016987323761);
expect(pred_boxes.mean().item()).toBeCloseTo(0.2500016987323761, 6);
},
MAX_TEST_EXECUTION_TIME,
);