feat: add OOM pre-check for vision models and fix InternVL image dimension handling #1253

Open

sufubao wants to merge 1 commit into main from fix_mm_check

Conversation

Collaborator

@sufubao commented Apr 2, 2026

feat: add OOM pre-check for vision models and fix InternVL image dimension handling

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces an OOM (Out of Memory) pre-check mechanism for Qwen-family vision models by performing a dummy forward pass with worst-case image dimensions during initialization. It updates several model implementations to support this check and refactors the ViT model to derive inference parameters from the configuration instead of environment variables. Feedback focuses on improving memory management within the pre-check function by explicitly deleting tensors and clearing the CUDA cache, as well as refining exception handling and logging practices.
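The pre-check described above can be sketched generically. This is a minimal illustration of the pattern, not the PR's actual code: `run_worst_case` is a hypothetical callable standing in for the dummy forward pass with worst-case image dimensions, and `MemoryError` stands in for a CUDA OOM.

```python
import logging

logger = logging.getLogger("vit_precheck")

def oom_precheck(run_worst_case, max_batch_size):
    """Run a worst-case dummy inference once at startup so an OOM
    surfaces immediately, instead of failing later at serving time."""
    try:
        run_worst_case(max_batch_size)
        logger.info("vit check max_len %d infer ok", max_batch_size)
        return True
    except (MemoryError, RuntimeError, ValueError):
        # logger.exception records the traceback automatically.
        logger.exception("worst-case vision pre-check failed")
        return False

# Usage: a stand-in forward pass that "OOMs" above batch size 4.
def fake_forward(batch_size):
    if batch_size > 4:
        raise MemoryError("simulated OOM")

assert oom_precheck(fake_forward, 4) is True
assert oom_precheck(fake_forward, 8) is False
```

The point of the pattern is fail-fast behavior: a server that cannot serve its configured worst case should refuse to start rather than crash under load.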

```python
grid_thw = grid_thw.to("cuda", non_blocking=True)
result = model.forward(pixel_values, grid_thw=grid_thw)
del result
```

Severity: medium

After the forward pass, the pixel_values and grid_thw tensors remain in GPU memory until the function returns. In an OOM pre-check context, it is best practice to explicitly delete these large tensors and call torch.cuda.empty_cache() to ensure that the memory is immediately available for the subsequent model initialization and KV cache allocation.

Suggested change

```diff
- del result
+ del result, pixel_values, grid_thw
+ torch.cuda.empty_cache()
```

```python
    del result
    logger.info(f"vit check max_len {max_batch_size} infer ok")
except (RuntimeError, torch.OutOfMemoryError, ValueError) as e:
    logger.exception(str(e))
```

Severity: medium

Using logger.exception(str(e)) is redundant because logger.exception automatically captures the exception object and its traceback. It is better to provide a descriptive message about what failed.

Suggested change

```diff
- logger.exception(str(e))
+ logger.exception("Qwen VL check max len infer failed")
```
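The reviewer's point can be verified with the standard-library `logging` module alone: `Logger.exception` logs at ERROR level and appends the active exception's traceback automatically, so passing `str(e)` as the message only duplicates information.

```python
import io
import logging

# Capture log output in a buffer so we can inspect it.
buf = io.StringIO()
logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler(buf))
logger.propagate = False

try:
    raise ValueError("bad grid_thw shape")
except ValueError:
    # A descriptive message; the traceback is attached for free.
    logger.exception("Qwen VL check max len infer failed")

out = buf.getvalue()
assert "Qwen VL check max len infer failed" in out
assert "Traceback" in out
assert "ValueError: bad grid_thw shape" in out
```

Both the message and the full `ValueError` traceback appear in the log record, which is why repeating `str(e)` in the message is redundant.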

```python
        "1.Set the --visual_infer_batch_size to a smaller value."
    )
    logger.error(exception_str)
    raise Exception(exception_str)
```

Severity: medium

Raising a base Exception is generally discouraged as it makes it harder for calling code to distinguish between different types of failures. Using RuntimeError is more appropriate for an execution failure during a model check.

Suggested change

```diff
- raise Exception(exception_str)
+ raise RuntimeError(exception_str)
```
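The benefit of raising `RuntimeError` over a bare `Exception` is that callers can catch this specific failure without a blanket `except Exception` that would also swallow unrelated errors. A minimal sketch (`check_vit` is a hypothetical stand-in for the PR's check function):

```python
def check_vit(ok: bool) -> None:
    # Raising a concrete exception type lets callers distinguish this
    # failure from, say, a TypeError caused by a programming bug.
    if not ok:
        raise RuntimeError(
            "vit check failed. "
            "1.Set the --visual_infer_batch_size to a smaller value."
        )

# Usage: the caller handles only the expected failure mode.
try:
    check_vit(False)
except RuntimeError as e:
    msg = str(e)

assert "visual_infer_batch_size" in msg
```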
