74 changes: 74 additions & 0 deletions ci/scripts/save_hf_test.py
@@ -0,0 +1,74 @@
import argparse
import time
import torch
import torch.distributed as dist
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.config import FSDPConfig

from memory_profiler import profile

MB = 1024 ** 2

def get_args():
    p = argparse.ArgumentParser(description="Profile build/shard/save with @profile (RSS) and simple GPU stats")
    p.add_argument("hf_path", type=str, help="HF model path")
    p.add_argument("out", type=str, help="Output HF path")
    p.add_argument("--ep", type=int, default=1, help="expert parallel size")
    return p.parse_args()

def set_device_for_rank():
    if torch.cuda.is_available():
        rank = dist.get_rank() if dist.is_initialized() else 0
        torch.cuda.set_device(rank % torch.cuda.device_count())

def gpu_mem(label):
    if not torch.cuda.is_available():
        print(f"[GPU] {label}: no CUDA")
        return
    torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated() / MB
    reserved = torch.cuda.memory_reserved() / MB
    peak = torch.cuda.max_memory_allocated() / MB
    print(f"[GPU] {label}: alloc={alloc:.2f}MB reserved={reserved:.2f}MB peak={peak:.2f}MB")

def build_model(hf_path: str):
    cfg = get_model_config_from_hf(hf_path)
    model = cfg.build()
    return model

def shard_model(model, ep: int):
    fsdp_cfg = FSDPConfig(ep_size=ep)
    model.fully_shard(fsdp_config=fsdp_cfg)
    return model

# memory_profiler reports line-by-line host RSS for the save path only;
# GPU memory is tracked separately via gpu_mem().
@profile
def save_model(model, out: str):
    model.save_hf(out)

def main():
    args = get_args()

    dist.init_process_group(backend="nccl")
    set_device_for_rank()

    t0 = time.perf_counter()
    gpu_mem("init")

    # Reset the CUDA peak counter before each stage so peaks are reported per stage.
    torch.cuda.reset_peak_memory_stats()
    model = build_model(args.hf_path)
    gpu_mem("after_build")

    torch.cuda.reset_peak_memory_stats()
    shard_model(model, args.ep)
    gpu_mem("after_shard")

    torch.cuda.reset_peak_memory_stats()
    save_model(model, args.out)
    gpu_mem("after_save")

    print(f"[TIME] total={time.perf_counter()-t0:.3f}s")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
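
A minimal launch sketch for this script, assuming one process per GPU on a single node (the GPU count and paths below are placeholders, not part of the PR):

torchrun --nproc_per_node 8 ci/scripts/save_hf_test.py <hf_model_path> <out_dir> --ep 2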
19 changes: 18 additions & 1 deletion xtuner/v1/model/base.py
@@ -1,6 +1,6 @@
import json
import math
from concurrent.futures import ProcessPoolExecutor, wait
from concurrent.futures import Future, ProcessPoolExecutor, wait
from functools import reduce
from itertools import chain
from pathlib import Path
@@ -73,6 +73,7 @@ class TransformerConfig(PydanticBaseModel):
    use_sliding_window: Annotated[bool, Parameter(group="model")] = False
    max_window_layers: Annotated[int | None, Parameter(group="model")] = None
    rope_scaling_cfg: RopeScalingConfig | None = None
    hf_save_worker: Annotated[int, Parameter(group="model")] = 16

    @computed_field
    def num_attention_heads(self) -> int:
@@ -722,6 +723,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
                    hf_dir / safetensor_name,
                )
                save_futures.append(future)
                self._wait_save_task(save_futures)

        safetensor_index = 0
        for name_list, hf_tensor_list in chain(same_gen, shard_gen):
@@ -745,6 +747,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
                    hf_dir / safetensor_name,
                )
                save_futures.append(future)
                self._wait_save_task(save_futures)

        if save_executor is not None:
            wait(save_futures)
@@ -1115,3 +1118,17 @@ def _to_empty_meta(self):
            module.to_empty(device=self.device, recurse=False)
        DEVICE_MODULE.synchronize()
        return

    def _wait_save_task(self, tasks: list[Future]):
        """Limit the number of concurrent save tasks to avoid OOM."""
        # Configs serialized by older xtuner versions do not have the `hf_save_worker`
        # attribute; `getattr` with a default keeps such configs working after
        # unpickling (backward compatibility).
        if len(tasks) >= getattr(self.config, "hf_save_worker", 16):
            done, pending = wait(tasks)
            for future in done:
                if (exception := future.exception()) is not None:
                    raise exception
            tasks.clear()
            tasks.extend(pending)
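
The throttling pattern in `_wait_save_task` can be read in isolation; below is a minimal, self-contained sketch of the same idea built only on `concurrent.futures` (all names here are illustrative, not xtuner APIs):

from concurrent.futures import Future, ProcessPoolExecutor, wait

MAX_IN_FLIGHT = 4  # plays the role of hf_save_worker

def drain_if_full(tasks: list[Future], limit: int = MAX_IN_FLIGHT) -> None:
    # Once `limit` futures are queued, block until the batch completes,
    # re-raise any worker exception, and keep only still-pending futures.
    if len(tasks) >= limit:
        done, pending = wait(tasks)
        for fut in done:
            if (exc := fut.exception()) is not None:
                raise exc
        tasks[:] = list(pending)

def save_chunk(idx: int) -> int:
    # Stand-in for writing one safetensors shard to disk.
    return idx

if __name__ == "__main__":
    futures: list[Future] = []
    with ProcessPoolExecutor(max_workers=MAX_IN_FLIGHT) as pool:
        for i in range(16):
            futures.append(pool.submit(save_chunk, i))
            drain_if_full(futures)
        wait(futures)  # flush whatever is still in flight at the end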
1 change: 1 addition & 0 deletions xtuner/v1/model/compose/intern_s1/intern_s1_config.py
@@ -96,6 +96,7 @@ class InternS1BaseConfig(BaseModel):
    freeze_vision: bool = False
    freeze_projector: bool = False
    freeze_language: bool = False
    hf_save_worker: int = 16

    def build(self) -> "InternS1ForConditionalGeneration":
        from .modeling_intern_s1 import InternS1ForConditionalGeneration
1 change: 1 addition & 0 deletions xtuner/v1/model/compose/internvl/internvl_config.py
@@ -92,6 +92,7 @@ class InternVLBaseConfig(BaseModel):
    freeze_vision: bool = False
    freeze_projector: bool = False
    freeze_language: bool = False
    hf_save_worker: int = 16

    def build(self) -> "InternVLForConditionalGeneration":
        from .modeling_internvl import InternVLForConditionalGeneration
1 change: 1 addition & 0 deletions xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py
@@ -79,6 +79,7 @@ class Qwen3VLBaseConfig(BaseModel):
    freeze_vision: bool = False
    freeze_projector: bool = False
    freeze_language: bool = False
    hf_save_worker: int = 16

    def build(self):
        from .modeling_qwen3_vl import Qwen3VLForConditionalGeneration
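
Since all of these configs now expose `hf_save_worker`, the cap on in-flight shard writers can be tuned per run. A sketch of how that might look (assumption: the configs are pydantic v2 models, so `model_copy(update=...)` is available; the paths are placeholders):

from xtuner.v1.model import get_model_config_from_hf

cfg = get_model_config_from_hf("<hf_model_path>")
# Lower the cap to trade save throughput for host memory during save_hf.
cfg = cfg.model_copy(update={"hf_save_worker": 4})
model = cfg.build()
model.save_hf("<out_dir>")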