
[VLM] Add MLA with pure RoPE support for deepseek-vl2 models #12729

Merged: 2 commits merged into vllm-project:main on Feb 5, 2025

Conversation

Isotr0py (Collaborator) commented Feb 4, 2025

FIX #11578 (comment)

  • Deepseek-VL2 uses pure rotary embedding, while the current MLA implementation only supports YaRN RoPE.

github-actions bot commented Feb 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

Isotr0py (Collaborator, Author) commented Feb 4, 2025

deepseek-vl2-small with the deepseek-v2 backbone should now work with the MLA backend:

$ python examples/offline_inference/vision_language.py -m deepseek_vl_v2 --num-prompts 2
INFO 02-04 15:55:15 __init__.py:186] Automatically detected platform cuda.
INFO 02-04 15:55:17 config.py:306] Overriding HF config with {'architectures': ['DeepseekVLV2ForCausalLM']}
INFO 02-04 15:55:29 config.py:542] This model supports multiple tasks: {'embed', 'score', 'reward', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 02-04 15:55:29 config.py:3276] MLA is enabled; forcing chunked prefill and prefix caching to be disabled.
INFO 02-04 15:55:29 llm_engine.py:234] Initializing a V0 LLM engine (v0.6.6.post2.dev502+g9e387cf2.d20250204) with config: model='deepseek-ai/deepseek-vl2-small', speculative_config=None, tokenizer='deepseek-ai/deepseek-vl2-small', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/deepseek-vl2-small, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[2,1],"max_capture_size":2}, use_cached_outputs=False, 
INFO 02-04 15:55:33 cuda.py:161] Using Triton MLA backend.
WARNING 02-04 15:55:34 triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
INFO 02-04 15:55:39 model_runner.py:1113] Starting to load model deepseek-ai/deepseek-vl2-small...
INFO 02-04 15:55:43 config.py:2993] cudagraph sizes specified by model runner [1, 2] is overridden by config [1, 2]
INFO 02-04 15:55:43 config.py:3276] MLA is enabled; forcing chunked prefill and prefix caching to be disabled.
INFO 02-04 15:55:44 cuda.py:161] Using Triton MLA backend.
INFO 02-04 15:55:48 weight_utils.py:252] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:14,  4.86s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:08,  4.24s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:13<00:04,  4.55s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:18<00:00,  4.71s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:18<00:00,  4.64s/it]

INFO 02-04 15:56:11 model_runner.py:1118] Loading model weights took 31.9199 GB
Some kwargs in processor config are unused and will not have any effect: image_mean, image_token, normalize, patch_size, candidate_resolutions, image_std, pad_token, downsample_ratio, sft_format, mask_prompt, ignore_id, add_special_token. 
WARNING 02-04 15:56:24 fused_moe.py:806] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/zifeng/develop-projects/vllm/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_A100-SXM4-80GB.json
INFO 02-04 15:56:28 worker.py:267] Memory profiling takes 16.52 seconds
INFO 02-04 15:56:28 worker.py:267] the current vLLM instance can use total_gpu_memory (79.15GiB) x gpu_memory_utilization (0.50) = 39.58GiB
INFO 02-04 15:56:28 worker.py:267] model weights take 31.92GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 0.47GiB; the rest of the memory reserved for KV Cache is 7.09GiB.
INFO 02-04 15:56:28 executor_base.py:110] # CUDA blocks: 15299, # CPU blocks: 8630
INFO 02-04 15:56:28 executor_base.py:115] Maximum concurrency for 4096 tokens per request: 59.76x
INFO 02-04 15:56:43 model_runner.py:1437] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:11<00:00,  5.81s/it]
INFO 02-04 15:56:55 model_runner.py:1565] Graph capturing finished in 12 secs, took 0.70 GiB
INFO 02-04 15:56:55 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 44.61 seconds
Some kwargs in processor config are unused and will not have any effect: image_mean, image_token, normalize, patch_size, candidate_resolutions, image_std, pad_token, downsample_ratio, sft_format, mask_prompt, ignore_id, add_special_token. 
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.45s/it, est. speed input: 986.19 toks/s, output: 27.20 toks/s]
The image features a view of a tall tower partially obscured by blooming cherry blossom trees. The sky is clear and blue, providing a vibrant backdrop for the pink flowers.
The image features a view of the Tokyo Skytree, a prominent tower in Tokyo, Japan, surrounded by blooming cherry blossoms. The cherry blossoms are in full bloom, creating a picturesque scene with the tower in the background.

Isotr0py marked this pull request as ready for review on February 4, 2025 15:09
Comment on lines +426 to +442
def apply_pure_rope(
    self,
    input_positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    seq_len = input_positions.size(0)
    ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape

    q_pe, k_pe = self.rotary_emb(
        input_positions,
        q_pe.reshape(seq_len, -1),
        k_pe.reshape(seq_len, -1),
    )
    q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

    return q_pe, k_pe
Collaborator

Could you say a bit about why you needed to wrap rotary_embedding when using pure_rope? I'm wondering if we could clean things up by always doing this reshape, so that we could always call self.rotary_embedding without the special cases for pure RoPE vs. YaRN.

Isotr0py (Collaborator, Author) commented Feb 5, 2025

I wrapped rotary_embedding to do the reshape for pure_rope because, if q_pe and k_pe with shape [seq_len, num_heads, head_dim] are passed to pure_rope directly, it causes an illegal memory access on q_pe when applying flash_attn_varlen_func:

[rank0]:   File "/home/zifeng/develop-projects/vllm/vllm/attention/backends/mla/utils.py", line 531, in _forward_prefill_flash
[rank0]:     attn_output = flash_attn_varlen_func(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/zifeng/develop-projects/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 172, in flash_attn_varlen_func
[rank0]:     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/zifeng/miniconda3/envs/vllm/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

BTW, if we use forward_native for pure_rope without the reshape, the error is not encountered and it also works with shape [seq_len, num_heads, head_dim], so the issue is specific to forward_cuda. Perhaps we should add a shape check in RotaryEmbedding's forward_cuda?
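
A rough sketch of the kind of check I have in mind (standalone and illustrative, not code from this PR; the helper name is made up):

import torch

def _validate_rope_cuda_inputs(positions: torch.Tensor,
                               query: torch.Tensor,
                               key: torch.Tensor) -> None:
    # The CUDA op documents query/key as [..., num_heads * head_size], i.e.
    # exactly one more dimension than positions. A [seq_len, num_heads,
    # head_dim] tensor slips through unnoticed and ends up triggering the
    # illegal memory access above.
    for name, t in (("query", query), ("key", key)):
        if t.dim() != positions.dim() + 1:
            raise ValueError(
                f"{name} must have shape [..., num_heads * head_size] to match "
                f"positions {tuple(positions.shape)}; got {tuple(t.shape)}. "
                "Flatten the head dimension before calling forward_cuda.")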

Isotr0py (Collaborator, Author)

Oh, it seems this is because the calculation of num_heads in the rotary_embedding CUDA op is unsuitable for tensors with shape [seq_len, num_heads, head_dim]:

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
                               // [num_tokens, num_heads * head_size]
    torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
                               // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;

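To illustrate the mismatch, here is a standalone repro of that shape arithmetic (illustrative only, not code from this PR):

import torch

head_size, seq_len, num_heads = 64, 8, 16

q_3d = torch.empty(seq_len, num_heads, head_size)    # [seq_len, num_heads, head_size]
q_2d = q_3d.reshape(seq_len, num_heads * head_size)  # the layout the op expects

for q in (q_2d, q_3d):
    num_tokens = q.numel() // q.size(-1)   # mirrors the kernel's derivation
    heads = q.size(-1) // head_size
    print(tuple(q.shape), "num_tokens:", num_tokens, "num_heads:", heads)
# Flattened (8, 1024): num_tokens = 8,   num_heads = 16  -> matches the 8 positions
# 3D (8, 16, 64):      num_tokens = 128, num_heads = 1   -> no longer matches positions
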
Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to rebuild with compilation. :)

Collaborator

Sounds like a bug in the kernel; I'll look into it tomorrow. In the meantime, I like the idea of adding a shape check in forward_cuda if you have a good idea of which shapes are problematic.

Collaborator

Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to rebuild with compilation. :)

Nice find, sounds good to me!

tlrmchlsmth (Collaborator) left a comment

Thanks for the fix!

I had one comment from earlier today asking whether the implementation could be cleaned up a bit, and I still wonder about that. For now, I'm accepting because the functionality looks good and I think we should get this into 0.7.2.

tlrmchlsmth mentioned this pull request on Feb 5, 2025
tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 5, 2025
simon-mo merged commit 98fd089 into vllm-project:main on Feb 5, 2025
49 of 53 checks passed
Isotr0py deleted the deepseek-vl2-mla branch on February 5, 2025 04:52