Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lightllm/models/qwen2_5_vl/qwen2_5_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class Qwen2RMSNorm(nn.Module):
Expand Down Expand Up @@ -157,6 +160,7 @@ def __init__(
super().__init__()
self.weight_dir = kvargs["weight_dir"]
self.data_type = kvargs.get("data_type", "bfloat16")
self.max_batch_size = kvargs.get("max_batch_size", 1)

self.depth = depth
self.hidden_size = hidden_size
Expand Down Expand Up @@ -224,6 +228,12 @@ def _init_datatype(self):
raise ValueError(f"Unsupport datatype {self.data_type}!")
return

@torch.no_grad()
def _check_max_len_infer(self):
    """Startup OOM pre-check: delegate to the shared Qwen-VL helper, which
    runs a worst-case dummy forward pass for up to ``self.max_batch_size``
    images and raises if GPU memory is insufficient.
    """
    # Imported lazily to avoid a module-level import cycle with vision_process.
    from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer

    qwen_vl_check_max_len_infer(self, self.max_batch_size)

def rot_pos_emb(self, grid_thw):
pos_ids = []
s = self.spatial_merge_size
Expand Down
7 changes: 7 additions & 0 deletions lightllm/models/qwen2_vl/qwen2_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def __init__(
):
super().__init__()
self.data_type = kvargs.get("data_type", "bfloat16")
self.max_batch_size = kvargs.get("max_batch_size", 1)

self.depth = depth
self.embed_dim = embed_dim
Expand Down Expand Up @@ -238,6 +239,12 @@ def _init_datatype(self):
raise ValueError(f"Unsupport datatype {self.data_type}!")
return

@torch.no_grad()
def _check_max_len_infer(self):
    """Startup OOM pre-check: delegate to the shared Qwen-VL helper, which
    runs a worst-case dummy forward pass for up to ``self.max_batch_size``
    images and raises if GPU memory is insufficient.
    """
    # Imported lazily to avoid a module-level import cycle with vision_process.
    from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer

    qwen_vl_check_max_len_infer(self, self.max_batch_size)

def load_model(self, weight_dir):

processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
Expand Down
54 changes: 54 additions & 0 deletions lightllm/models/qwen2_vl/vision_process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import math
import os
import torch
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -27,6 +28,59 @@
logger = init_logger(__name__)


def closest_factor_pair(n: int) -> tuple[int, int]:
    """Find the factor pair of n closest to sqrt(n). Returns (smaller, larger).

    Scans divisors downward from floor(sqrt(n)); the first divisor found
    yields the most "square" pair, since divisors pair up around sqrt(n).

    Args:
        n: positive integer to factor.

    Returns:
        Tuple (a, b) with a * b == n and a <= b, closest to a square.

    Raises:
        ValueError: if n < 1 (0 and negatives have no positive factor pair).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # math.isqrt is exact for arbitrarily large ints, unlike int(math.sqrt(n))
    # which can be off by one due to float rounding.
    for i in range(math.isqrt(n), 0, -1):
        if n % i == 0:
            return i, n // i
    # Unreachable: i == 1 always divides n, so the loop always returns.
    return 1, n


@torch.no_grad()
def qwen_vl_check_max_len_infer(model, max_batch_size):
    """OOM pre-check for Qwen-family vision models.

    Constructs a worst-case dummy image at the processor's max_pixels
    resolution, replicates it for max_batch_size, and runs one forward
    pass so that insufficient GPU memory fails fast at startup instead
    of during serving.

    Args:
        model: vision model exposing ``patch_size``, ``spatial_merge_size``,
            ``processor`` (with ``max_pixels`` / ``preprocess``),
            ``data_type`` and ``forward(pixel_values, grid_thw=...)``.
        max_batch_size: number of worst-case images batched together.

    Raises:
        RuntimeError: if the probe forward pass fails (e.g. CUDA OOM).
    """
    # Escape hatch: allow operators to skip the pre-check entirely.
    if os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None:
        return

    # One "unit" is the pixel span covered by one merged patch group.
    unit = model.patch_size * model.spatial_merge_size
    max_pixels = model.processor.max_pixels
    max_patches = max(max_pixels // (unit * unit), 1)
    # A near-square image at max_patches is the worst realistic case.
    h_factor, w_factor = closest_factor_pair(max_patches)
    worst_h = unit * h_factor
    worst_w = unit * w_factor

    try:
        dummy_image = Image.new("RGB", (worst_w, worst_h), color=(128, 128, 128))
        pixel_values, grid_thw = model.processor.preprocess(dummy_image)

        pixel_values = pixel_values.repeat(max_batch_size, 1, 1)
        grid_thw = grid_thw.repeat(max_batch_size, 1)

        pixel_values = pixel_values.to("cuda", dtype=model.data_type, non_blocking=True)
        grid_thw = grid_thw.to("cuda", non_blocking=True)

        result = model.forward(pixel_values, grid_thw=grid_thw)
        # Free the probe tensors immediately so their memory is available
        # for the subsequent model initialization / KV cache allocation.
        del result, pixel_values, grid_thw
        torch.cuda.empty_cache()

        logger.info(f"vit check max_len {max_batch_size} infer ok")
    except (RuntimeError, torch.OutOfMemoryError, ValueError) as e:
        # logger.exception already records the active exception and its
        # traceback, so pass a descriptive message rather than str(e).
        logger.exception("Qwen VL check max len infer failed")
        exception_str = (
            "Vit check max len infer fail, you can try: "
            "1.Set the --visual_infer_batch_size to a smaller value."
        )
        logger.error(exception_str)
        # RuntimeError (chained to the cause) lets callers distinguish this
        # execution failure instead of catching a bare Exception.
        raise RuntimeError(exception_str) from e
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Raising a base Exception is generally discouraged as it makes it harder for calling code to distinguish between different types of failures. Using RuntimeError is more appropriate for an execution failure during a model check.

Suggested change
raise Exception(exception_str)
raise RuntimeError(exception_str)



IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
Expand Down
10 changes: 10 additions & 0 deletions lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class Qwen3OmniMoeVisionMLP(nn.Module):
Expand Down Expand Up @@ -140,6 +143,7 @@ def __init__(
):
super().__init__()
self.data_type = kvargs.get("data_type", "bfloat16")
self.max_batch_size = kvargs.get("max_batch_size", 1)

self.depth = depth
self.out_hidden_size = out_hidden_size
Expand Down Expand Up @@ -207,6 +211,12 @@ def _init_datatype(self):
raise ValueError(f"Unsupport datatype {self.data_type}!")
return

@torch.no_grad()
def _check_max_len_infer(self):
    """Startup OOM pre-check: delegate to the shared Qwen-VL helper, which
    runs a worst-case dummy forward pass for up to ``self.max_batch_size``
    images and raises if GPU memory is insufficient.
    """
    # Imported lazily to avoid a module-level import cycle with vision_process.
    from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer

    qwen_vl_check_max_len_infer(self, self.max_batch_size)

def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids):
all_chunks = []

Expand Down
7 changes: 7 additions & 0 deletions lightllm/models/qwen3_vl/qwen3_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
):
super().__init__()
self.data_type = kvargs.get("data_type", "bfloat16")
self.max_batch_size = kvargs.get("max_batch_size", 1)

self.depth = depth
self.out_hidden_size = out_hidden_size
Expand Down Expand Up @@ -202,6 +203,12 @@ def _init_datatype(self):
raise ValueError(f"Unsupport datatype {self.data_type}!")
return

@torch.no_grad()
def _check_max_len_infer(self):
    """Startup OOM pre-check: delegate to the shared Qwen-VL helper, which
    runs a worst-case dummy forward pass for up to ``self.max_batch_size``
    images and raises if GPU memory is insufficient.
    """
    # Imported lazily to avoid a module-level import cycle with vision_process.
    from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer

    qwen_vl_check_max_len_infer(self, self.max_batch_size)

def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids):
all_chunks = []

Expand Down
23 changes: 17 additions & 6 deletions lightllm/models/vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(self, kvargs):
self._init_quant()
self._init_weights()
self._init_infer_layer()
self._check_max_len_infer()
return

@final
Expand All @@ -73,7 +72,8 @@ def _check_max_len_infer(self):
except (RuntimeError, torch.OutOfMemoryError) as e:
logger.exception(str(e))
exception_str = (
"Vit check max len infer fail, you can try:" "1.Set the --visual_infer_batch_size to a smaller value."
"Vit check max len infer fail, you can try: "
"1.Set the --visual_infer_batch_size to a smaller value."
)
logger.error(exception_str)
raise Exception(exception_str)
Expand All @@ -85,16 +85,27 @@ def _init_config(self):
self.select_layer = self.config["select_layer"]
self.config["vision_config"]["llm_hidden_size"] = self.config["llm_config"]["hidden_size"]
self.config["vision_config"]["downsample_ratio"] = self.config["downsample_ratio"]

# Derive worst-case image dimensions from model config
image_size = self.config.get("force_image_size", self.config["vision_config"]["image_size"])
max_dynamic_patch = self.config.get("max_dynamic_patch", 12)
use_thumbnail = self.config.get("use_thumbnail", True)
dynamic_image_size = self.config.get("dynamic_image_size", True)

self.config = self.config["vision_config"]

repair_config(self.config, same_names=["num_attention_heads", "n_head"])
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
self.layers_num = self.config["num_hidden_layers"]

# infer info
self.IMAGE_H = int(os.getenv("IMAGE_H", 448))
self.IMAGE_W = int(os.getenv("IMAGE_W", 448))
self.MAX_PATH_NUM = os.getenv("MAX_PATH_NUM", 13)
# infer info — computed from config, not env vars
self.IMAGE_H = image_size
self.IMAGE_W = image_size
max_num = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_num != 1:
max_num += 1
self.MAX_PATH_NUM = max_num
return

def _padding_hidden_size(self):
Expand Down
2 changes: 2 additions & 0 deletions lightllm/server/visualserver/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def exposed_init_model(self, kvargs):

self.model.load_model(weight_dir)
self.model = self.model.cuda()
if hasattr(self.model, "_check_max_len_infer"):
self.model._check_max_len_infer()
if not self.is_visual_only_mode:
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
Expand Down
Loading