Skip to content

Commit 3918f7c

Browse files
committed
[Enhance] Set max queue size for async saving hf checkpoint
`xtuner.v1.model.BaseModel.save_hf` uses `ProcessPoolExecutor` to submit multiple saving tasks in order to speed up saving. However, `ProcessPoolExecutor.submit` is non-blocking, so a large number of CPU tensors accumulates in the queue, which can cause a CPU out-of-memory (OOM) error.
1 parent 3f96a84 commit 3918f7c

File tree

5 files changed

+94
-1
lines changed

5 files changed

+94
-1
lines changed

ci/scripts/save_hf_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import argparse
2+
import time
3+
import torch
4+
import torch.distributed as dist
5+
from xtuner.v1.model import get_model_config_from_hf
6+
from xtuner.v1.config import FSDPConfig
7+
8+
from memory_profiler import profile
9+
10+
MB = 1024 ** 2  # Number of bytes in one MiB; divisor used to print byte counts in MB.
11+
12+
def get_args():
    """Parse the command-line arguments of the save-HF profiling script.

    Returns:
        argparse.Namespace: Holds ``hf_path`` (str, positional), ``out``
        (str, positional) and ``ep`` (int, ``--ep``, defaults to 1).
    """
    # The summary must be passed as ``description``. The first positional
    # parameter of ``ArgumentParser`` is ``prog``, so passing the sentence
    # positionally would replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Profile build/shard/save with @profile (RSS) and simple GPU stats"
    )
    parser.add_argument("hf_path", type=str, help="HF model path")
    parser.add_argument("out", type=str, help="Output HF path")
    parser.add_argument("--ep", type=int, default=1, help="expert parallel size")
    return parser.parse_args()
18+
19+
def set_device_for_rank():
    """Select the CUDA device for this process based on its distributed rank.

    Does nothing when CUDA is unavailable. The rank is taken as 0 when the
    default process group has not been initialized. The device index is the
    rank modulo the number of visible CUDA devices.
    """
    if not torch.cuda.is_available():
        return

    current_rank = 0
    if dist.is_initialized():
        current_rank = dist.get_rank()

    device_index = current_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
23+
24+
def gpu_mem(label):
    """Print a one-line CUDA memory report tagged with ``label``.

    When CUDA is available, the device is synchronized first and the
    allocated, reserved and peak-allocated byte counts are printed in MB
    (scaled by the module-level ``MB``). Otherwise a "no CUDA" line is
    printed instead.

    Args:
        label: Text inserted after the ``[GPU]`` prefix to identify the stage.
    """
    if torch.cuda.is_available():
        # Synchronize so pending kernels are finished before sampling.
        torch.cuda.synchronize()
        allocated_mb = torch.cuda.memory_allocated() / MB
        reserved_mb = torch.cuda.memory_reserved() / MB
        peak_mb = torch.cuda.max_memory_allocated() / MB
        message = f"[GPU] {label}: alloc={allocated_mb:.2f}MB reserved={reserved_mb:.2f}MB peak={peak_mb:.2f}MB"
    else:
        message = f"[GPU] {label}: no CUDA"
    print(message)
33+
34+
def build_model(hf_path: str):
    """Build a model from the configuration found at a HF model path.

    Args:
        hf_path: Path handed to ``get_model_config_from_hf``.

    Returns:
        The object returned by calling ``build()`` on that configuration.
    """
    return get_model_config_from_hf(hf_path).build()
38+
39+
def shard_model(model, ep: int):
    """Shard ``model`` in place with an FSDP config using the given EP size.

    Args:
        model: Object exposing ``fully_shard(fsdp_config=...)``.
        ep: Value passed as ``ep_size`` to ``FSDPConfig``.

    Returns:
        The same ``model`` object that was passed in.
    """
    sharding_config = FSDPConfig(ep_size=ep)
    model.fully_shard(fsdp_config=sharding_config)
    return model
43+
44+
@profile
def save_model(model, out: str):
    """Save ``model`` in HF format to ``out`` under ``memory_profiler``'s ``@profile``.

    Kept as a separate function so the decorator measures only the
    ``save_hf`` call.

    Args:
        model: Object exposing ``save_hf(path)``.
        out: Output path forwarded to ``model.save_hf``.
    """
    model.save_hf(out)
47+
48+
def main():
    """Profile building, sharding and HF-saving a model in a distributed run.

    Parses the CLI arguments, initializes the NCCL process group, then runs
    the build, shard and save stages in order, printing GPU memory statistics
    after each stage and the total wall time at the end.

    The process group is destroyed even if a stage raises, so a failing run
    does not skip the teardown.
    """
    args = get_args()

    dist.init_process_group(backend="nccl")
    try:
        set_device_for_rank()

        t0 = time.perf_counter()
        gpu_mem("init")

        # The peak counter is reset before each stage so that the "peak"
        # value printed by gpu_mem covers only that stage.
        torch.cuda.reset_peak_memory_stats()
        model = build_model(args.hf_path)
        gpu_mem("after_build")

        torch.cuda.reset_peak_memory_stats()
        shard_model(model, args.ep)
        gpu_mem("after_shard")

        torch.cuda.reset_peak_memory_stats()
        save_model(model, args.out)
        gpu_mem("after_save")

        print(f"[TIME] total={time.perf_counter()-t0:.3f}s")
    finally:
        # Always tear down the process group, including on failure.
        dist.destroy_process_group()
72+
73+
# Run the profiling workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()

xtuner/v1/model/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
22
import math
3-
from concurrent.futures import ProcessPoolExecutor, wait
3+
from concurrent.futures import ProcessPoolExecutor, wait, Future
44
from functools import reduce
55
from itertools import chain
66
from pathlib import Path
7+
from queue import Queue
78
from shutil import copy, copytree
89
from typing import Annotated, Generator, Literal, cast
910

@@ -73,6 +74,7 @@ class TransformerConfig(PydanticBaseModel):
7374
use_sliding_window: Annotated[bool, Parameter(group="model")] = False
7475
max_window_layers: Annotated[int | None, Parameter(group="model")] = None
7576
rope_scaling_cfg: RopeScalingConfig | None = None
77+
hf_save_worker: Annotated[int, Parameter(group="model")] = 16
7678

7779
@computed_field
7880
def num_attention_heads(self) -> int:
@@ -719,6 +721,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
719721
hf_dir / safetensor_name,
720722
)
721723
save_futures.append(future)
724+
self._wait_save_task(save_futures)
722725

723726
safetensor_index = 0
724727
for name_list, hf_tensor_list in chain(same_gen, shard_gen):
@@ -742,6 +745,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
742745
hf_dir / safetensor_name,
743746
)
744747
save_futures.append(future)
748+
self._wait_save_task(save_futures)
745749

746750
if save_executor is not None:
747751
wait(save_futures)
@@ -1103,3 +1107,15 @@ def _to_empty_meta(self):
11031107
module.to_empty(device=self.device, recurse=False)
11041108
DEVICE_MODULE.synchronize()
11051109
return
1110+
1111+
def _wait_save_task(self, tasks: list[Future]):
1112+
"Limit the number of concurrent save tasks to avoid OOM. "
1113+
if len(tasks) >= self.config.hf_save_worker:
1114+
done, pending = wait(tasks)
1115+
for future in done:
1116+
if (exception := future.exception()) is not None:
1117+
raise exception
1118+
tasks.clear()
1119+
tasks.extend(pending)
1120+
else:
1121+
return

xtuner/v1/model/compose/intern_s1/intern_s1_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class InternS1BaseConfig(BaseModel):
9393
freeze_vision: bool = False
9494
freeze_projector: bool = False
9595
freeze_language: bool = False
96+
hf_save_worker: int = 16
9697

9798
def build(self) -> "InternS1ForConditionalGeneration":
9899
from .modeling_intern_s1 import InternS1ForConditionalGeneration

xtuner/v1/model/compose/internvl/internvl_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class InternVLBaseConfig(BaseModel):
9191
freeze_vision: bool = False
9292
freeze_projector: bool = False
9393
freeze_language: bool = False
94+
hf_save_worker: int = 16
9495

9596
def build(self) -> "InternVLForConditionalGeneration":
9697
from .modeling_internvl import InternVLForConditionalGeneration

xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class Qwen3VLBaseConfig(BaseModel):
7878
freeze_vision: bool = False
7979
freeze_projector: bool = False
8080
freeze_language: bool = False
81+
hf_save_worker: int = 16
8182

8283
def build(self):
8384
from .modeling_qwen3_vl import Qwen3VLForConditionalGeneration

0 commit comments

Comments
 (0)