
Commit 57dd84b

[Enhance] Set max queue size for async saving hf checkpoint
`xtuner.v1.model.BaseModel.save_hf` uses a `ProcessPoolExecutor` to submit multiple saving tasks and speed up checkpoint saving. However, `ProcessPoolExecutor.submit` is non-blocking, so unbounded submission accumulates many CPU tensors and can cause a CPU OOM.
1 parent 3f96a84 commit 57dd84b
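
For illustration, here is a minimal, self-contained sketch of the bounded-submission pattern this change adopts: once a fixed number of save futures are in flight, wait for them before submitting more. The save_shard task, the wait_if_full helper, and the limit of 16 (mirroring the hf_save_worker default added below) are illustrative stand-ins, not xtuner APIs.

from concurrent.futures import Future, ProcessPoolExecutor, wait


def save_shard(shard_id: int) -> int:
    # Stand-in for writing one safetensors shard; the real task ships large
    # CPU tensors to a worker process.
    return shard_id


def wait_if_full(tasks: list[Future], limit: int) -> None:
    # Same idea as the _wait_save_task helper added in this commit: once
    # `limit` futures are queued, block until they finish, re-raise any
    # worker-side error, and reset the list so queued arguments can be freed.
    if len(tasks) < limit:
        return
    done, pending = wait(tasks)
    for future in done:
        if (exc := future.exception()) is not None:
            raise exc
    tasks.clear()
    tasks.extend(pending)


def main() -> None:
    futures: list[Future] = []
    with ProcessPoolExecutor(max_workers=4) as pool:
        for shard_id in range(100):
            futures.append(pool.submit(save_shard, shard_id))
            wait_if_full(futures, limit=16)
        wait(futures)  # drain whatever is still in flight


if __name__ == "__main__":
    main()

Because wait() defaults to return_when=ALL_COMPLETED, the queue is drained in batches: once the limit is reached the submitter blocks until every queued save has finished, so at most `limit` serialized shard payloads are pending at any time.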

File tree (5 files changed, +93 -1 lines):

ci/scripts/save_hf_test.py
xtuner/v1/model/base.py
xtuner/v1/model/compose/intern_s1/intern_s1_config.py
xtuner/v1/model/compose/internvl/internvl_config.py
xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py

ci/scripts/save_hf_test.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import argparse
+import time
+import torch
+import torch.distributed as dist
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.config import FSDPConfig
+
+from memory_profiler import profile
+
+MB = 1024 ** 2
+
+def get_args():
+    p = argparse.ArgumentParser("Profile build/shard/save with @profile (RSS) and simple GPU stats")
+    p.add_argument("hf_path", type=str, help="HF model path")
+    p.add_argument("out", type=str, help="Output HF path")
+    p.add_argument("--ep", type=int, default=1, help="expert parallel size")
+    return p.parse_args()
+
+def set_device_for_rank():
+    if torch.cuda.is_available():
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        torch.cuda.set_device(rank % torch.cuda.device_count())
+
+def gpu_mem(label):
+    if not torch.cuda.is_available():
+        print(f"[GPU] {label}: no CUDA")
+        return
+    torch.cuda.synchronize()
+    alloc = torch.cuda.memory_allocated() / MB
+    reserved = torch.cuda.memory_reserved() / MB
+    peak = torch.cuda.max_memory_allocated() / MB
+    print(f"[GPU] {label}: alloc={alloc:.2f}MB reserved={reserved:.2f}MB peak={peak:.2f}MB")
+
+def build_model(hf_path: str):
+    cfg = get_model_config_from_hf(hf_path)
+    model = cfg.build()
+    return model
+
+def shard_model(model, ep: int):
+    fsdp_cfg = FSDPConfig(ep_size=ep)
+    model.fully_shard(fsdp_config=fsdp_cfg)
+    return model
+
+@profile
+def save_model(model, out: str):
+    model.save_hf(out)
+
+def main():
+    args = get_args()
+
+    dist.init_process_group(backend="nccl")
+    set_device_for_rank()
+
+    t0 = time.perf_counter()
+    gpu_mem("init")
+
+    torch.cuda.reset_peak_memory_stats()
+    model = build_model(args.hf_path)
+    gpu_mem("after_build")
+
+    torch.cuda.reset_peak_memory_stats()
+    shard_model(model, args.ep)
+    gpu_mem("after_shard")
+
+    torch.cuda.reset_peak_memory_stats()
+    save_model(model, args.out)
+    gpu_mem("after_save")
+
+    print(f"[TIME] total={time.perf_counter()-t0:.3f}s")
+
+    dist.destroy_process_group()
+
+if __name__ == "__main__":
+    main()
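
The script is presumably meant to be launched through a distributed launcher such as torchrun, since it initializes a process group with the NCCL backend; it takes the source HF checkpoint path and the output path as positional arguments, plus an optional --ep expert-parallel size. The exact invocation is not part of this commit.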

xtuner/v1/model/base.py

Lines changed: 16 additions & 1 deletion
@@ -1,6 +1,6 @@
 import json
 import math
-from concurrent.futures import ProcessPoolExecutor, wait
+from concurrent.futures import Future, ProcessPoolExecutor, wait
 from functools import reduce
 from itertools import chain
 from pathlib import Path
@@ -73,6 +73,7 @@ class TransformerConfig(PydanticBaseModel):
     use_sliding_window: Annotated[bool, Parameter(group="model")] = False
     max_window_layers: Annotated[int | None, Parameter(group="model")] = None
     rope_scaling_cfg: RopeScalingConfig | None = None
+    hf_save_worker: Annotated[int, Parameter(group="model")] = 16
 
     @computed_field
     def num_attention_heads(self) -> int:
@@ -719,6 +720,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
                 hf_dir / safetensor_name,
             )
             save_futures.append(future)
+            self._wait_save_task(save_futures)
 
         safetensor_index = 0
         for name_list, hf_tensor_list in chain(same_gen, shard_gen):
@@ -742,6 +744,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
                 hf_dir / safetensor_name,
             )
             save_futures.append(future)
+            self._wait_save_task(save_futures)
 
         if save_executor is not None:
             wait(save_futures)
@@ -1103,3 +1106,15 @@ def _to_empty_meta(self):
             module.to_empty(device=self.device, recurse=False)
         DEVICE_MODULE.synchronize()
         return
+
+    def _wait_save_task(self, tasks: list[Future]):
+        "Limit the number of concurrent save tasks to avoid OOM."
+        if len(tasks) >= self.config.hf_save_worker:
+            done, pending = wait(tasks)
+            for future in done:
+                if (exception := future.exception()) is not None:
+                    raise exception
+            tasks.clear()
+            tasks.extend(pending)
+        else:
+            return

xtuner/v1/model/compose/intern_s1/intern_s1_config.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ class InternS1BaseConfig(BaseModel):
     freeze_vision: bool = False
     freeze_projector: bool = False
     freeze_language: bool = False
+    hf_save_worker: int = 16
 
     def build(self) -> "InternS1ForConditionalGeneration":
         from .modeling_intern_s1 import InternS1ForConditionalGeneration

xtuner/v1/model/compose/internvl/internvl_config.py

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ class InternVLBaseConfig(BaseModel):
     freeze_vision: bool = False
     freeze_projector: bool = False
     freeze_language: bool = False
+    hf_save_worker: int = 16
 
     def build(self) -> "InternVLForConditionalGeneration":
         from .modeling_internvl import InternVLForConditionalGeneration

xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ class Qwen3VLBaseConfig(BaseModel):
     freeze_vision: bool = False
     freeze_projector: bool = False
     freeze_language: bool = False
+    hf_save_worker: int = 16
 
     def build(self):
         from .modeling_qwen3_vl import Qwen3VLForConditionalGeneration

0 commit comments