Unable to export GLM models to ONNX #35021

Open · 1 of 4 tasks
xenova opened this issue Nov 29, 2024 · 6 comments
Comments

xenova (Contributor) commented Nov 29, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-6.5.0-1025-azure-x86_64-with-glibc2.31
  • Python version: 3.12.1
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the following script:

import os
import torch
from transformers import (
    AutoProcessor,
    GlmForCausalLM,
    DynamicCache,
)

class PatchedGlmForCausalLM(GlmForCausalLM):
    def forward(self, *args):
        input_ids, attention_mask, position_ids, *past_key_values_args = args

        # Convert past_key_values list to DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache(self.config.num_hidden_layers)
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value

        return flattened_past_key_values_outputs


# Constants
OUTPUT_FOLDER = "output"
TEXT_MODEL_NAME = "model.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")


# Load model and processor
model_id = "hf-internal-testing/tiny-random-GlmForCausalLM"
model = PatchedGlmForCausalLM.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)


# Save model configs and processor
model.config.save_pretrained(OUTPUT_FOLDER)
model.generation_config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)


# Configuration values
## Text model
text_config = model.config
num_heads = text_config.num_attention_heads
num_key_value_heads = text_config.num_key_value_heads
head_dim = text_config.head_dim
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size


# Dummy input sizes
batch_size = 2
sequence_length = 16
past_sequence_length = 0

## Text inputs
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        num_key_value_heads,
        past_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for key in ["key", "value"]
}
input_ids = torch.randint(
    0, text_config.vocab_size,
    (batch_size, sequence_length),
)
attention_mask = torch.ones(batch_size, sequence_length + past_sequence_length, dtype=torch.int64)
position_ids = torch.ones(batch_size, sequence_length, dtype=torch.int64)

text_inputs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **dummy_past_key_values_kwargs,
)
text_inputs_positional = tuple(text_inputs.values())
text_outputs = model.forward(*text_inputs_positional)  # Test forward pass


# ONNX Exports
## Text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(text_inputs.keys()),
    output_names=["logits"]
    + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
    },
)

It produces this error:

Traceback (most recent call last):
  File "/workspaces/glm.py", line 110, in <module>
    torch.onnx.export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 369, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset11.py", line 519, in cat
    return opset9.cat(g, tensor_list, dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 281, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py", line 575, in cat
    assert all(a)
AssertionError

Expected behavior

The model should export correctly. This may in fact be an ONNX exporter bug rather than a transformers one, but I'm not 100% sure. Models like Gemma export correctly with the same approach, so the failure appears to be GLM-specific.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 7, 2025
@ArthurZucker ArthurZucker reopened this Feb 11, 2025
ArthurZucker (Collaborator) commented:

cc @Cyrilvallez

xenova (Contributor, Author) commented Feb 11, 2025

I did a long deep dive on this recently, and it appears to have been a bug in PyTorch: pytorch/pytorch#145100 👀

I was able to apply a fix in Optimum by overriding the repeat_interleave op (huggingface/optimum#2162), so it may not necessarily be something we can (or should) fix in transformers.
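For context, the essence of the workaround is that repeat_interleave(x, 2, dim=-1) can be rewritten with a stack + flatten, which the exporter traces without hitting the failing symbolic. A rough, simplified illustration of the idea (my own sketch with a hypothetical helper name, not the actual Optimum patch):

import torch

def repeat_interleave_2_last_dim(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: equivalent to torch.repeat_interleave(x, 2, dim=-1),
    # built only from ops that export cleanly to ONNX.
    return torch.stack((x, x), dim=-1).flatten(-2)

x = torch.arange(6.0).reshape(1, 2, 3)
assert torch.equal(repeat_interleave_2_last_dim(x), torch.repeat_interleave(x, 2, dim=-1))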

Cyrilvallez (Member) commented:

Hey @xenova! Thanks for looking deep into this. To be fair, repeat_interleave should be used directly in the RotaryEmbedding class instead of in the apply_rotary_pos_emb function. I initially implemented it the way it currently is because I wanted to use modular, and this was a very simple and efficient workaround. I did not know at the time that we already had an existing model using interleave instead of cat in the RotaryEmbedding class, along with a manual interleave in rotate_half, but it turns out Cohere does!
I have been thinking about simplifying the implementation by inheriting from Cohere instead. Would that solve the issue if the repeat_interleave is located directly in the RotaryEmbedding class?

xenova (Contributor, Author) commented Feb 17, 2025

> Would that solve the issue if the repeat_interleave is located directly in the RotaryEmbedding class?

Possibly... 👀 To be sure, though, if you could post code snippets, I can test them out with the torch ONNX exporter.

Cyrilvallez (Member) commented:

Sure, here are the class and functions with the modification:

import torch
from torch import nn

from transformers import GlmConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS


class GlmRotaryEmbedding(nn.Module):
    def __init__(self, config: GlmConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Keep half or full tensor for later concatenation
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    # Apply rotary embeddings on the first half or full tensor
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed
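
As a quick way to exercise just this path through the exporter, something along these lines could work (an untested sketch of mine, assuming the rotate_half/apply_rotary_pos_emb above are defined in the same file; RopeProbe and the dummy shapes are made up):

import torch
from torch import nn


class RopeProbe(nn.Module):
    # Tiny wrapper that only runs the proposed rotary application.
    def forward(self, q, k, cos, sin):
        return apply_rotary_pos_emb(q, k, cos, sin)


batch, heads, seq, head_dim = 1, 2, 4, 8
rotary_dim = head_dim // 2  # GLM uses a partial rotary factor, so rotary_dim < head_dim

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, rotary_dim)
sin = torch.randn(batch, seq, rotary_dim)

# Should trace without the cat() AssertionError from the original report.
torch.onnx.export(RopeProbe(), (q, k, cos, sin), "rope_probe.onnx", opset_version=14)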
