-
Notifications
You must be signed in to change notification settings - Fork 315
feat: add OOM pre-check for vision models and fix InternVL image dime… #1253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||
| from __future__ import annotations | ||||||
| import math | ||||||
| import os | ||||||
| import torch | ||||||
| import numpy as np | ||||||
| from PIL import Image | ||||||
|
|
@@ -27,6 +28,59 @@ | |||||
| logger = init_logger(__name__) | ||||||
|
|
||||||
|
|
||||||
| def closest_factor_pair(n): | ||||||
| """Find the factor pair of n closest to sqrt(n). Returns (smaller, larger).""" | ||||||
| sqrt_n = int(math.sqrt(n)) | ||||||
| for i in range(sqrt_n, 0, -1): | ||||||
| if n % i == 0: | ||||||
| return i, n // i | ||||||
| return 1, n | ||||||
|
|
||||||
|
|
||||||
| @torch.no_grad() | ||||||
| def qwen_vl_check_max_len_infer(model, max_batch_size): | ||||||
| """OOM pre-check for Qwen-family vision models. | ||||||
|
|
||||||
| Constructs worst-case dummy images at max_pixels resolution, | ||||||
| replicates for max_batch_size, and runs a forward pass to validate | ||||||
| GPU memory is sufficient. | ||||||
| """ | ||||||
| disable_check = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None | ||||||
| if disable_check: | ||||||
| return | ||||||
|
|
||||||
| unit = model.patch_size * model.spatial_merge_size | ||||||
| max_pixels = model.processor.max_pixels | ||||||
| max_patches = max_pixels // (unit * unit) | ||||||
| if max_patches < 1: | ||||||
| max_patches = 1 | ||||||
| h_factor, w_factor = closest_factor_pair(max_patches) | ||||||
| worst_h = unit * h_factor | ||||||
| worst_w = unit * w_factor | ||||||
|
|
||||||
| try: | ||||||
| dummy_image = Image.new("RGB", (worst_w, worst_h), color=(128, 128, 128)) | ||||||
| pixel_values, grid_thw = model.processor.preprocess(dummy_image) | ||||||
|
|
||||||
| pixel_values = pixel_values.repeat(max_batch_size, 1, 1) | ||||||
| grid_thw = grid_thw.repeat(max_batch_size, 1) | ||||||
|
|
||||||
| pixel_values = pixel_values.to("cuda", dtype=model.data_type, non_blocking=True) | ||||||
| grid_thw = grid_thw.to("cuda", non_blocking=True) | ||||||
|
|
||||||
| result = model.forward(pixel_values, grid_thw=grid_thw) | ||||||
| del result | ||||||
| logger.info(f"vit check max_len {max_batch_size} infer ok") | ||||||
| except (RuntimeError, torch.OutOfMemoryError, ValueError) as e: | ||||||
| logger.exception(str(e)) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| exception_str = ( | ||||||
| "Vit check max len infer fail, you can try: " | ||||||
| "1.Set the --visual_infer_batch_size to a smaller value." | ||||||
| ) | ||||||
| logger.error(exception_str) | ||||||
| raise Exception(exception_str) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Raising a base
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| IMAGE_FACTOR = 28 | ||||||
| MIN_PIXELS = 4 * 28 * 28 | ||||||
| MAX_PIXELS = 16384 * 28 * 28 | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After the forward pass, the
pixel_valuesandgrid_thwtensors remain in GPU memory until the function returns. In an OOM pre-check context, it is best practice to explicitly delete these large tensors and calltorch.cuda.empty_cache()to ensure that the memory is immediately available for the subsequent model initialization and KV cache allocation.