feat: add OOM pre-check for vision models and fix InternVL image dime…#1253
feat: add OOM pre-check for vision models and fix InternVL image dime…#1253
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces an OOM (Out of Memory) pre-check mechanism for Qwen-family vision models by performing a dummy forward pass with worst-case image dimensions during initialization. It updates several model implementations to support this check and refactors the ViT model to derive inference parameters from the configuration instead of environment variables. Feedback focuses on improving memory management within the pre-check function by explicitly deleting tensors and clearing the CUDA cache, as well as refining exception handling and logging practices.
| grid_thw = grid_thw.to("cuda", non_blocking=True) | ||
|
|
||
| result = model.forward(pixel_values, grid_thw=grid_thw) | ||
| del result |
There was a problem hiding this comment.
After the forward pass, the pixel_values and grid_thw tensors remain in GPU memory until the function returns. In an OOM pre-check context, it is best practice to explicitly delete these large tensors and call torch.cuda.empty_cache() to ensure that the memory is immediately available for the subsequent model initialization and KV cache allocation.
| del result | |
| del result, pixel_values, grid_thw | |
| torch.cuda.empty_cache() |
| del result | ||
| logger.info(f"vit check max_len {max_batch_size} infer ok") | ||
| except (RuntimeError, torch.OutOfMemoryError, ValueError) as e: | ||
| logger.exception(str(e)) |
There was a problem hiding this comment.
| "1.Set the --visual_infer_batch_size to a smaller value." | ||
| ) | ||
| logger.error(exception_str) | ||
| raise Exception(exception_str) |
There was a problem hiding this comment.
Raising a base Exception is generally discouraged as it makes it harder for calling code to distinguish between different types of failures. Using RuntimeError is more appropriate for an execution failure during a model check.
| raise Exception(exception_str) | |
| raise RuntimeError(exception_str) |
…nsion handling