
[Model builder] Add option to exclude cache in inputs and outputs #1162


Closed · wants to merge 1 commit
54 changes: 32 additions & 22 deletions src/python/py/models/builder.py
@@ -111,6 +111,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
elif self.include_hidden_states:
self.output_names = ["hidden_states"] + self.output_names

self.exclude_cache = "exclude_cache" in extra_options
Suggested change (review comment from a Contributor):
self.exclude_cache = "exclude_cache" in extra_options
self.exclude_cache = extra_options.get("exclude_cache", False)
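The two checks differ when a caller supplies an explicit falsy value; a minimal sketch on a plain dict (the builder's actual handling of extra_options values may differ):

    extra_options = {"exclude_cache": False}
    "exclude_cache" in extra_options             # True: only key presence is tested
    extra_options.get("exclude_cache", False)    # False: the stored value is respected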


# Store names of nodes already created
self.node_names = set()

@@ -557,19 +559,20 @@ def make_inputs_and_outputs(self):
shape = self.output_shapes[name]
outputs.append(helper.make_tensor_value_info(name, dtype, shape=shape))

# Add KV cache to inputs and outputs
for i in range(self.num_layers):
# Add KV cache to inputs
key_name = f"past_key_values.{i}.key"
inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
value_name = f"past_key_values.{i}.value"
inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))

# Add KV cache to outputs
key_name = f"present.{i}.key"
outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
value_name = f"present.{i}.value"
outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))
if not self.exclude_cache:
# Add KV cache to inputs and outputs
for i in range(self.num_layers):
# Add KV cache to inputs
key_name = f"past_key_values.{i}.key"
inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
value_name = f"past_key_values.{i}.value"
inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))

# Add KV cache to outputs
key_name = f"present.{i}.key"
outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
value_name = f"present.{i}.value"
outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))

self.inputs = inputs
self.outputs = outputs
@@ -1536,15 +1539,20 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=self.attention_attrs["k_path"], position_ids=kwargs.get("position_ids", "position_ids"))
self.attention_attrs["k_path"] = f"{k_rotary_name}/output_0"

# Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
past_k = f"past_key_values.{layer_id}.key"
past_v = f"past_key_values.{layer_id}.value"
present_k = f"present.{layer_id}.key"
present_v = f"present.{layer_id}.value"
if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k)
self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v)
past_k, past_v, present_k, present_v = "", "", "", ""
if not self.exclude_cache:
# Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
past_k = f"past_key_values.{layer_id}.key"
past_v = f"past_key_values.{layer_id}.value"
present_k = f"present.{layer_id}.key"
present_v = f"present.{layer_id}.value"
if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k)
self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v)
past_k, past_v, present_k, present_v = "", "", "", ""
else:
past_k, past_v = "", ""
present_k = f"present.{layer_id}.key"
Review comment (Contributor): I think present_k and present_v should be empty strings since the KV cache inputs and outputs are not in the model.

present_v = f"present.{layer_id}.value"

# Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.)
attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}"
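If the reviewer's comment above is applied, the else branch would presumably set all four names to empty strings (a sketch only, not part of this commit):

    else:
        # With the cache excluded there are no past/present KV tensors in the graph,
        # so leave all four names empty.
        past_k, past_v = "", ""
        present_k, present_v = "", ""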
@@ -3267,6 +3275,8 @@ def get_args():
exclude_lm_head = Remove language modeling head from your ONNX model.
Use this option when you want to remove the language modeling head from within your ONNX model.
Instead of `logits`, you will have `hidden_states` as the output to your ONNX model.
exclude_cache = Remove cache inputs and outputs from your ONNX model.
Suggested change (review comment from a Contributor):
exclude_cache = Remove cache inputs and outputs from your ONNX model.
exclude_kv_cache = Remove KV cache inputs and outputs from your ONNX model.

Use this option when you want to remove the `past_key_values` inputs and `present` outputs from within your ONNX model.
Suggested addition (review comment from a Contributor), appending a note after the line above:
Note that this should be used when you want to run ONNX models with ONNX Runtime only. ONNX Runtime GenAI requires the KV cache inputs and outputs for inference.

include_hidden_states = Include hidden states as output from your ONNX model.
Use this option when you want to have the hidden states as an output from your ONNX model.
In addition to `logits`, you will have `hidden_states` as an output to your ONNX model.
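As a quick sanity check after exporting with the cache excluded, the graph inputs and outputs can be listed with the onnx package; no `past_key_values.*` inputs or `present.*` outputs should appear (the model path below is only an example):

    import onnx

    model = onnx.load("model.onnx")              # path to the exported model (example)
    print([i.name for i in model.graph.input])   # expected: no past_key_values.* entries
    print([o.name for o in model.graph.output])  # expected: no present.* entries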