
[Model] EXAONE 4.0 model support #21060


Merged · 12 commits merged into vllm-project:main · Jul 19, 2025

Conversation

@Deepfocused (Contributor) commented Jul 16, 2025

Purpose

EXAONE 4.0 has finally been released. I promptly integrated it with vLLM. I hope it works smoothly on vLLM v0.9.3.

This PR includes the modifications required to support the EXAONE 4 model, its configuration files, and other necessary code changes.

Usage

👀 For tool usage, use the hermes tool parser

# tool
vllm serve ... --enable-auto-tool-choice --tool-call-parser hermes --chat-template chat_template.jinja
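
For reference, a minimal tool-calling sketch against a server started with the flags above; the base_url, model name, and the get_weather tool are assumptions for illustration, not part of this PR.

```python
# Minimal tool-calling sketch for a vLLM server started with
# --enable-auto-tool-choice --tool-call-parser hermes.
# base_url, model name, and the get_weather tool are assumed for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Return the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="EXAONE-4.0-32B",
    messages=[{"role": "user", "content": "What's the weather in Seoul?"}],
    tools=tools,
    tool_choice="auto",
)
# With the hermes tool parser enabled, parsed calls land here instead of raw text.
print(response.choices[0].message.tool_calls)
```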

🦾 Attention, please for reasoning 🦾

  1. "Modify the first three lines at the top of the chat_template.jinja file located at https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B/tree/main for reasoning"
from
{%- if not skip_think is defined %}
   {%- set skip_think = true %}
{%- endif %}
to
{%- set skip_think = false %}
or
To skip the above action, pass the skip_think argument: {\"chat_template_kwargs\": { \"skip_think \": false }.

💥 As you can see from reading chat_template.jinja, when the input contains a </think> token, the reasoning_content value is stored, and this value becomes <think>reasoning_content</think> (the **previous think result**).
The variable skip_think determines whether or not to include the **previous think result** in the next input when it contains a </think> token.
To summarize,
the <think> token does not appear in the model's output!!!
EXAONE 4.0 has been trained to perform reasoning only when the <think> token is included in the **input**.

💥 To enable reasoning, you need to include the enable_thinking argument as follows (added on 2025-07-28):

curl http://localhost:9021/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d \
'{
    "model": "EXAONE-4.0-32B",
    "messages": [
      {
        "role": "user",
        "content": "Which one is bigger, 3.12 vs 3.9?"
      }
    ],
    "chat_template_kwargs": {
      "enable_thinking": true,
      "skip_think": false
    },
    "stream": false,
    "top_p": 0.95,
    "temperature": 0.6
 }'
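
For Python users, here is a rough equivalent of the curl call above using the OpenAI client; vLLM accepts chat_template_kwargs through extra_body, and the port and model name are carried over from the example above.

```python
# Rough Python equivalent of the curl example above; values mirror that request.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9021/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="EXAONE-4.0-32B",
    messages=[{"role": "user", "content": "Which one is bigger, 3.12 vs 3.9?"}],
    # chat_template_kwargs is forwarded to the chat template via extra_body.
    extra_body={"chat_template_kwargs": {"enable_thinking": True, "skip_think": False}},
    stream=False,
    top_p=0.95,
    temperature=0.6,
)
print(response.choices[0].message.content)
```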

2. "About Exaone4.0 reasoning mode"
```bash
To enable reasoning mode in the chat_template.jinja file, 
you need to include <think> in the prompt.
However, this only allows us to identify the end point of the reasoning output, 
making it impossible to parse out just the reasoning content separately.
(This is not possible with the reasoning parsers currently supported by vllm.)
This is why the --reasoning-parser {parser} option wasn't used.

Therefore, the reasoning response is included in the content field meaning 
all content is in 'response.choices[0].message.content'.

💥 "The reasoning parser for Exaone4 is currently under development. It will be added once completed." 😸

In "invoke" mode, you can easily put "reasoning" content into "reasoning_contents" 
(and the code is simple). However, in "streaming" mode, it is impossible to write the 
corresponding code because there is no "reasoning" start token. 
(There is no distinguisher to distinguish between "reasoning" answers and general answers.)
It would be possible if the "extract_reasoning_content_streaming" function could be changed 
to receive the "enable_thinking" argument, but it seems uncommon, 
so I will stop developing this part.😥 
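
For the non-streaming case, the split described above could look roughly like this; a sketch only, not the vLLM reasoning-parser API, assuming the model emits no opening <think> tag.

```python
# Sketch of the non-streaming ("invoke") split described above: everything before
# the closing </think> is treated as reasoning, the rest as the final answer.
def split_reasoning(text: str, end_token: str = "</think>") -> tuple[str | None, str]:
    if end_token not in text:
        return None, text  # no reasoning present in this response
    reasoning, _, content = text.partition(end_token)
    return reasoning.strip(), content.strip()

reasoning_content, content = split_reasoning(
    "3.12 vs 3.9: compare the tenths digit... </think> 3.9 is bigger.")
```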
 
If you have any other opinions, please let me know.😽

Thanks

Signed-off-by: Deepfocused <[email protected]>
… configuration externally accessible.

Signed-off-by: Deepfocused <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the new-model label Jul 16, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces support for the EXAONE 4.0 model, including its architecture implementation, configuration, and necessary registry updates. The changes are well-structured and mostly follow the existing patterns in vLLM.

My review has identified two significant issues in the Exaone4Attention implementation that would lead to incorrect model behavior. One is a critical bug where rotary embeddings are not applied correctly, and the other is a high-severity bug in the sliding window initialization logic which is the root cause of the first bug. I have provided detailed feedback and code suggestions to address these issues.

Once these issues are resolved, the implementation should be solid. Thank you for your contribution!

Comment on lines 207 to 209
if self.sliding_window:
    q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)

critical

The rotary embeddings are only applied if self.sliding_window is not None. This is incorrect as rotary embeddings should be applied to queries and keys unconditionally for all layers. The current implementation will lead to incorrect model outputs for layers that do not use sliding window attention, as their queries and keys will not be rotated.

Suggested change

- if self.sliding_window:
-     q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)

@Deepfocused Deepfocused reopened this Jul 16, 2025
@chriswritescode-dev

This looks like it's ready to go.

@DarkLight1337 (Member) left a comment


Please also add this model to the Supported Models page

@mergify mergify bot added the documentation label Jul 17, 2025
@Deepfocused (Contributor Author)

Please also add this model to the Supported Models page

Completed.

@DarkLight1337 DarkLight1337 added this to the v0.10.0 milestone Jul 17, 2025
@lkm2835 commented Jul 17, 2025

Hi, thanks for supporting EXAONE-4.0 on vLLM.
I've found a few parts that seem to need some fixes.

@DarkLight1337 (Member)

Hi, thanks for supporting EXAONE-4.0 on vLLM. I've found a few parts that seem to need some fixes.

Could you elaborate?

@lkm2835 commented Jul 17, 2025

Hi, thanks for supporting EXAONE-4.0 on vLLM. I've found a few parts that seem to need some fixes.

Could you elaborate?

The first part (self.norm) differs from the EXAONE4 modeling implementation.

For the second part (rotary_emb), since there are two main types of the EXAONE-4.0 model (32B, 1.2B), the conditional statements need to be adjusted to support both.
Specifically, the 32B model uses rotary_emb with sliding_attention, while the 1.2B model uses rotary_emb with full_attention.
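
Based on that description, the adjusted condition reduces to a simple rule. Below is a minimal, self-contained sketch of that rule; the parameter names are mine and this is not the merged vLLM code.

```python
# Sketch of the RoPE rule described above; names are illustrative, not the merged code.
def should_apply_rope(layer_sliding_window: int | None,
                      config_sliding_window: int | None) -> bool:
    """32B (hybrid): only sliding-attention layers apply RoPE.
    1.2B (full attention, no sliding window in the config): every layer applies RoPE."""
    return layer_sliding_window is not None or config_sliding_window is None

# 32B-style hybrid checkpoint: local layers rotate, global layers do not.
assert should_apply_rope(4096, 4096) is True
assert should_apply_rope(None, 4096) is False
# 1.2B-style full-attention checkpoint: RoPE on every layer.
assert should_apply_rope(None, None) is True
```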

@Deepfocused (Contributor Author)

I’ll compare it once more with the original EXAONE4 code, make the revisions, and update.

@Deepfocused (Contributor Author)

@lkm2835
RMS seems to be the same.
What exactly is different?
Refer here: https://github.com/lgai-exaone/transformers/blob/add-exaone4/src/transformers/models/exaone4/modeling_exaone4.py#L59

I plan to add exception handling for the 1.2B model.

"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)

Based on the EXAONE4 Transformers code, it seems that residual should be removed.

transformers/modeling_exaone4.py
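
For context, vLLM's RMSNorm returns an (output, residual) tuple when a residual is passed (fused add + norm) and a single tensor otherwise, so dropping the residual also changes the call shape. A plain-PyTorch stand-in sketch of the suggested change, not the PR's actual code:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             residual: torch.Tensor | None = None, eps: float = 1e-6):
    """Stand-in mimicking vLLM's RMSNorm call shape."""
    if residual is not None:
        x = x + residual        # fused add + norm path returns a tuple
        residual = x
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    return (out, residual) if residual is not None else out

h, w = torch.randn(2, 8), torch.ones(8)
# Before: hidden_states, _ = self.norm(hidden_states, residual)
# After (matching the HF EXAONE 4 modeling code): hidden_states = self.norm(hidden_states)
hidden_states = rms_norm(h, w)
```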

Contributor Author

ok i got it

Comment on lines 207 to 208
if self.sliding_window:
    q, k = self.rotary_emb(positions, q, k)

The EXAONE-4.0-1.2B config should be considered in this conditional statement.
This model uses rotary embeddings, even though sliding_window is set to None.

You can refer to this part.
transformers/exaone4_modeling.py

Contributor Author

ok i got it

@lkm2835 commented Jul 17, 2025

@lkm2835 “RMS seems to be the same. What exactly is different? Refer here: https://github.com/lgai-exaone/transformers/blob/add-exaone4/src/transformers/models/exaone4/modeling_exaone4.py#L59

I plan to add exception handling for the 1.2B model.”

@Deepfocused
Apologies, I wrote the review but forgot to open it.

…nsidering EXAONE-4.0-1.2B.

Signed-off-by: Deepfocused <[email protected]>
@Deepfocused (Contributor Author)

@DarkLight1337 @chriswritescode-dev @lkm2835 @andrew 

👨‍🌾I've completed the code modifications, verified they're working correctly, and recommitted the changes.

The latest changes are as follows:

(1) Removed residual from self.norm
(2) Changed the RoPE application code to account for EXAONE-4.0-1.2B

😸Thanks a lot!😸

@DarkLight1337 (Member)

@lkm2835 can you confirm whether it is correct now?


mergify bot commented Jul 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Deepfocused.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 18, 2025
@DarkLight1337 (Member)

You should use Git to resolve the merge conflict

@mergify mergify bot removed the needs-rebase label Jul 18, 2025
@DarkLight1337 (Member)

The row is missing again

@Deepfocused (Contributor Author)

The row is missing again

@DarkLight1337 Adding the following to docs/models/supported_models.md causes a merge conflict:
| `Exaone4ForCausalLM` | Exaone-4 | LGAI-EXAONE/EXAONE-4.0-32B, etc. | ✅︎ | ✅︎ | ✅︎ |

  • I'll try to fix it.

@@ -87,6 +88,7 @@ def _get_hf_token() -> Optional[str]:
    "medusa": MedusaConfig,
    "eagle": EAGLEConfig,
    "exaone": ExaoneConfig,
    "exaone4": Exaone4Config,

I have another question.
Is Exaone4Config necessary, or can we just use the default Exaone4 HF config?

@DarkLight1337 DarkLight1337 merged commit 3e04107 into vllm-project:main Jul 19, 2025
69 checks passed
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
Signed-off-by: Deepfocused <[email protected]>
Signed-off-by: woongsik <[email protected]>
Signed-off-by: Himanshu Jaju <[email protected]>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
@blakkd commented Jul 25, 2025

@Deepfocused
I'm sorry to bother you; I didn't really know whether I should have posted this as an issue or here, but I thought all the context was above, so I finally decided to put it here.

I can't replicate it and make it engage in reasoning.
I set skip_think = false at line 2 of the official chat_template.jinja as you suggested, but the <think> token isn't prepended to the model output. I only get regular answers.

Here is my server launch command:
vllm serve "LGAI-EXAONE/EXAONE-4.0-32B-AWQ" --gpu-memory-utilization 0.94 --enforce-eager --max_num_seq 1 --max-model-len 30000 --chat-template /home/user/vllm/templates/exaone4_vllm.jinja

Do you have any idea of what I'm doing wrong?


I even tried the chat template I use for llama.cpp (which works there), where I simply forced the <think> token at line 141 like this:

{%- if add_generation_prompt %}
    {{- role_indicators['assistant'] }}
    {{- "<think>\n" }}
{%- endif %}

But still no success.

@lkm2835 commented Jul 25, 2025

Hi @blakkd
Could you try without --enforce-eager?

@blakkd commented Jul 25, 2025

Removing --enforce-eager doesn't change the behavior. But I have to revise what I said; here are the current behaviors I'm facing:

  • with only {%- set skip_think = false %} (line 2) --> no success, no thinking response.
  • with the hardcoded
{%- if add_generation_prompt %}
    {{- role_indicators['assistant'] }}
    {{- "<think>\n" }}
{%- endif %}

--> Actually, I do get the reasoning triggering, but the opening <think> tag doesn't show up in the response.

However, this only allows us to identify the end point of the reasoning output, making it impossible to parse out just the reasoning content separately.

Why would it make it impossible to identify the starting point when we force the token ourselves to mark the starting point?

I'm so confused :/ I'm so noob

But if ever you feel this is not the place to post, please tell me, I would totally understand and open a discussion instead!

@lkm2835 commented Jul 26, 2025

@blakkd
Can you open a new issue?

@blakkd commented Jul 26, 2025

Of course :D But I'd rather open a discussion instead; I don't feel I've dug in deep enough to post an issue.

@hongseok-oh


I created an issue #21718

@Deepfocused (Contributor Author) commented Jul 28, 2025

@blakkd @hongseok-oh

Oh, I saw this too late. Sorry about that.

It seems I didn't explain skip_think in enough detail — sorry for the confusion.

Let me explain skip_think in a bit more detail:
As you can see from reading chat_template.jinja, when the input contains a </think> token, the reasoning_content value is stored, and this value becomes <think>reasoning_content</think> (the **previous think result**).

The variable skip_think determines whether or not to include the **previous think result** in the next input when it contains a </think> token.

To summarize,
the <think> token does not appear in the model's output!!!
EXAONE 4.0 has been trained to perform reasoning only when the <think> token is included in the **input**.


curl example for reasoning
```bash
curl http://localhost:9021/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d \
'{
    "model": "EXAONE-4.0-32B",
    "messages": [
      {
        "role": "user",
        "content": "Which one is bigger, 3.12 vs 3.9?"
      }
    ],
    "chat_template_kwargs": {
      "enable_thinking": true,
      "skip_think": false
    },
    "stream": false,
    "top_p": 0.95,
    "temperature": 0.6
 }'
```

@blakkd commented Jul 28, 2025

@Deepfocused Thanks for taking the time to reformulate!

So, let's say in the general case where we want to have a multiturn discussion, we should set "enable_thinking": true AND "skip_think": true in order to keep only the previous final answer as context for the next turn, excluding the previous reasoning trace, is that it?

Nevertheless, @hongseok-oh, thanks for taking the initiative and posting the issue.

@Deepfocused (Contributor Author)

@blakkd
Yes, that's right.
It's exactly what's in chat_template.jinja!
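
So for a multiturn chat, the kwargs discussed above would look roughly like this, again via extra_body in the Python client; the port and model name are assumptions carried over from the earlier examples.

```python
# Multiturn sketch: keep reasoning enabled for the current turn, but let the
# template drop previous reasoning traces from the re-sent history (skip_think=True).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9021/v1", api_key="EMPTY")
kwargs = {"chat_template_kwargs": {"enable_thinking": True, "skip_think": True}}

history = [{"role": "user", "content": "Which one is bigger, 3.12 vs 3.9?"}]
first = client.chat.completions.create(model="EXAONE-4.0-32B", messages=history,
                                        extra_body=kwargs)
history.append({"role": "assistant", "content": first.choices[0].message.content})
history.append({"role": "user", "content": "And 3.120 vs 3.12?"})
second = client.chat.completions.create(model="EXAONE-4.0-32B", messages=history,
                                         extra_body=kwargs)
print(second.choices[0].message.content)
```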

@blakkd commented Jul 29, 2025

Thanks! Sorry my functioning artificial brain was stuck because of the stop tokens

Labels: documentation, new-model, ready