[Model builder] Add option to exclude cache in inputs and outputs #1162
Changes from all commits
@@ -111,6 +111,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         elif self.include_hidden_states:
             self.output_names = ["hidden_states"] + self.output_names

+        self.exclude_cache = "exclude_cache" in extra_options
+
         # Store names of nodes already created
         self.node_names = set()
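Note that the new attribute is set by key presence alone, so, going only by this hunk, any `exclude_cache=<value>` pair reaching `extra_options` enables the behavior regardless of the value. A minimal sketch of that pattern (the dict literal is hypothetical, not the builder's actual option parsing):

# Minimal sketch of the presence-based check above; `extra_options` here is a
# hypothetical dict standing in for the builder's parsed extra options.
extra_options = {"exclude_cache": "false"}

exclude_cache = "exclude_cache" in extra_options
print(exclude_cache)  # True -- only the key's presence matters, not its value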
@@ -557,19 +559,20 @@ def make_inputs_and_outputs(self):
             shape = self.output_shapes[name]
             outputs.append(helper.make_tensor_value_info(name, dtype, shape=shape))

-        # Add KV cache to inputs and outputs
-        for i in range(self.num_layers):
-            # Add KV cache to inputs
-            key_name = f"past_key_values.{i}.key"
-            inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
-            value_name = f"past_key_values.{i}.value"
-            inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
-
-            # Add KV cache to outputs
-            key_name = f"present.{i}.key"
-            outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
-            value_name = f"present.{i}.value"
-            outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))
+        if not self.exclude_cache:
+            # Add KV cache to inputs and outputs
+            for i in range(self.num_layers):
+                # Add KV cache to inputs
+                key_name = f"past_key_values.{i}.key"
+                inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
+                value_name = f"past_key_values.{i}.value"
+                inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
+
+                # Add KV cache to outputs
+                key_name = f"present.{i}.key"
+                outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
+                value_name = f"present.{i}.value"
+                outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))

         self.inputs = inputs
         self.outputs = outputs
@@ -1536,15 +1539,20 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
             self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=self.attention_attrs["k_path"], position_ids=kwargs.get("position_ids", "position_ids"))
             self.attention_attrs["k_path"] = f"{k_rotary_name}/output_0"

-        # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
-        past_k = f"past_key_values.{layer_id}.key"
-        past_v = f"past_key_values.{layer_id}.value"
-        present_k = f"present.{layer_id}.key"
-        present_v = f"present.{layer_id}.value"
-        if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
-            self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k)
-            self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v)
-            past_k, past_v, present_k, present_v = "", "", "", ""
+        if not self.exclude_cache:
+            # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
+            past_k = f"past_key_values.{layer_id}.key"
+            past_v = f"past_key_values.{layer_id}.value"
+            present_k = f"present.{layer_id}.key"
+            present_v = f"present.{layer_id}.value"
+            if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
+                self.attention_attrs["k_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["k_path"], past_kv=past_k, present_kv=present_k)
+                self.attention_attrs["v_path"] = self.make_repeat_kv(layer_id, root_input=self.attention_attrs["v_path"], past_kv=past_v, present_kv=present_v)
+                past_k, past_v, present_k, present_v = "", "", "", ""
+        else:
+            past_k, past_v = "", ""
+            present_k = f"present.{layer_id}.key"
+            present_v = f"present.{layer_id}.value"

         # Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.)
         attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}"
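The empty strings assigned above lean on a general ONNX convention: an empty name in a node's input or output list means that optional tensor is absent, which is presumably how the attention node ends up without past-state inputs when the cache is excluded. A tiny illustration of the convention using `Clip` (whose `min`/`max` inputs are optional), not the attention op the builder emits:

# Illustration of the "" convention only; Clip is used because its min/max inputs
# are optional. This is not the attention node the builder actually creates.
from onnx import helper

node = helper.make_node("Clip", inputs=["x", "", "max_val"], outputs=["y"])
print(node.input)  # ['x', '', 'max_val'] -- the empty string means "min not provided"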
@@ -3267,6 +3275,8 @@ def get_args():
         exclude_lm_head = Remove language modeling head from your ONNX model.
             Use this option when you want to remove the language modeling head from within your ONNX model.
             Instead of `logits`, you will have `hidden_states` as the output to your ONNX model.
+        exclude_cache = Remove cache inputs and outputs from your ONNX model.
+            Use this option when you want to remove the `past_key_values` inputs and `present` outputs from within your ONNX model.
         include_hidden_states = Include hidden states as output from your ONNX model.
             Use this option when you want to have the hidden states as an output from your ONNX model.
             In addition to `logits`, you will have `hidden_states` as an output to your ONNX model.
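Per the option descriptions above, `exclude_cache` is passed the same way as the other extra options (presumably `--extra_options exclude_cache=true`, though any value should work given the presence-based check). One way to sanity-check a model built with it, sketched against an assumed output filename:

# Post-build sanity check sketch; "model.onnx" is an assumed output path.
import onnx

model = onnx.load("model.onnx")
cache_inputs = [i.name for i in model.graph.input if i.name.startswith("past_key_values.")]
cache_outputs = [o.name for o in model.graph.output if o.name.startswith("present.")]
assert not cache_inputs and not cache_outputs, "KV cache I/O still present"
print([i.name for i in model.graph.input], [o.name for o in model.graph.output])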