
Commit 0b9f1e8

update "rollout_max_batch_size" to replace "max_concurrent" for user settings (#1225)
* use rollout_max_batch_size to max_concurrent in concurrency settings
* add comments
* add comments
* fix
* add comments
* fix
* fix
* fix
* fix
* fix
* fix
* fix comments
1 parent 5c9d3a8 commit 0b9f1e8

File tree

14 files changed: +191 -49 lines

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@

# How to Configure Xtuner Concurrency to Improve Rollout Efficiency

During the Rollout phase of Xtuner, properly configuring the concurrency-related parameters keeps the inference engine under high load, makes full use of hardware resources, and improves overall inference efficiency. This document introduces the main concurrency-related configuration options in Xtuner, explains how they relate to each other, and gives best-practice recommendations.

## Main Concurrency-Related Configuration Parameters

1. **RolloutConfig.rollout_max_batch_size_per_instance**
   - Controls the maximum batch size that a single inference instance (such as a single model process) can handle at one time.
   - Larger batch sizes improve GPU utilization, but excessively large values may cause out-of-memory errors or increased latency.
   - Adjust this parameter according to the model's context_length and the GPU memory actually available.
   - Xtuner will provide recommended values for common models and context lengths in future releases.

2. **RolloutConfig.allow_over_concurrency_ratio**
   - Controls the over-concurrency ratio for HTTP requests so that the inference engine stays fully loaded.

3. **DataflowConfig.max_concurrent**
   - Controls the maximum number of concurrent tasks in Dataflow. Dataflow acts as a single controller that distributes data to all rollout workers.
   - Dataflow sends one group of data per task, so the number of samples in flight at the same time is `max_concurrent * prompt_repeat_k`.
   - Set it slightly higher than the actual processing capacity of the inference engine so that the inference queue always has pending tasks.

4. **RAY_MAX_CONCURRENCY**
   - The maximum concurrency of the Ray backend, configured via an environment variable. The default is 1024.

5. **httpx max connections**
   - Controls the maximum number of concurrent connections that the HTTP client (such as RolloutWorker) can open to the inference service; a sketch follows this list.
   - Set it equal to or slightly higher than `rollout_max_batch_size_per_instance`.
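For illustration, here is a minimal sketch of how a client-side connection cap could be aligned with the per-instance batch size, assuming httpx is used directly; the variable values are made up and the real client setup lives inside Xtuner's rollout worker:

```python
import httpx

# Made-up values: cap client connections at batch size * over-concurrency ratio.
rollout_max_batch_size_per_instance = 512
allow_over_concurrency_ratio = 1.2
max_connections = int(rollout_max_batch_size_per_instance * allow_over_concurrency_ratio)

# httpx limits the number of concurrent connections through httpx.Limits.
limits = httpx.Limits(
    max_connections=max_connections,
    max_keepalive_connections=max_connections,
)
client = httpx.AsyncClient(limits=limits, timeout=3600.0)
```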
## Configuration Relationships and Recommendations

- **Recommended configuration process**:
  1. Determine a reasonable `rollout_max_batch_size_per_instance` from the model and hardware resources (e.g., 128, 256, 512, 1024). This parameter is optional; if it is not provided, Xtuner uses preset values based on `context_length`: 1024 for `context_length` ≤ 4K, 512 for ≤ 8K, and 128 otherwise.
  2. Set `DataflowConfig.max_concurrent`. The recommended value is `rollout_max_batch_size_per_instance * num_of_infer_instance / prompt_repeat_k * allow_over_concurrency_ratio`, where `num_of_infer_instance` is the number of inference engine instances started (usually the total number of GPUs, i.e. nodes × gpus_per_node, divided by `tensor_parallel_size`); a sketch of this calculation follows this list.
  3. Set the `RAY_MAX_CONCURRENCY` environment variable equal to or slightly higher than `rollout_max_batch_size_per_instance * num_of_infer_instance`.
  4. The httpx max connections defaults to `rollout_max_batch_size_per_instance * allow_over_concurrency_ratio`.

- **Dynamic adjustment**: Monitor the inference queue length, GPU utilization, and response latency, and adjust the parameters above dynamically to find the optimal concurrency configuration.
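A minimal sketch of the calculation recommended in step 2, assuming the cluster size and `prompt_repeat_k` are known; the helper name is hypothetical and not part of Xtuner's API:

```python
def recommended_max_concurrent(
    rollout_max_batch_size_per_instance: int,
    num_nodes: int,
    gpus_per_node: int,
    tensor_parallel_size: int,
    prompt_repeat_k: int,
    allow_over_concurrency_ratio: float = 1.2,
) -> int:
    # Number of inference engine instances: total GPUs divided by the tensor parallel size.
    num_of_infer_instance = num_nodes * gpus_per_node // tensor_parallel_size
    # Dataflow concurrency as recommended in step 2 above.
    return int(
        rollout_max_batch_size_per_instance
        * num_of_infer_instance
        / prompt_repeat_k
        * allow_over_concurrency_ratio
    )


# With the numbers from the example configuration below: 1024 * 8 / 16 * 1.2 ≈ 614.
print(recommended_max_concurrent(1024, num_nodes=1, gpus_per_node=8,
                                 tensor_parallel_size=1, prompt_repeat_k=16))
```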
## Example Configuration

```python
resource = AcceleratorResourcesConfig(
    num_workers=8
)
rollout = RolloutConfig(
    rollout_max_batch_size_per_instance=1024,
    tensor_parallel_size=1,
    ...
)

dataflow = DataflowConfig(
    max_concurrent=600,  # ~ 1024 * (8 / 1) / 16 * 1.2 ≈ 614
    prompt_repeat_k=16,
    ...
)

# Environment variable setting
export RAY_MAX_CONCURRENCY=1024
```
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@

# How to Configure Xtuner Concurrency to Improve Rollout Efficiency

In the Rollout phase of Xtuner, properly configuring the concurrency-related parameters keeps the inference engine under high load at all times, makes full use of hardware resources, and improves overall inference efficiency. This document describes the main concurrency-related configuration options in Xtuner and their relationships, and gives best-practice recommendations.

## Main Concurrency-Related Configuration Parameters

1. **RolloutConfig.rollout_max_batch_size_per_instance**
   - Controls the maximum batch size that a single inference instance (such as a single model process) can handle at one time.
   - A larger batch size improves GPU utilization, but an excessively large one may cause out-of-memory errors or higher latency.
   - Adjust this parameter according to the model's context_length and the actual GPU memory.
   - Xtuner will later provide recommended settings for common models and context lengths.

2. **RolloutConfig.allow_over_concurrency_ratio**
   - Controls the over-concurrency ratio of HTTP requests, used to keep the inference engine saturated.

3. **DataflowConfig.max_concurrent**
   - Controls the maximum number of concurrent tasks in Dataflow. Dataflow is a single controller responsible for distributing data to all rollout workers.
   - Dataflow sends one group of data at a time, so the number of samples sent at the same time is `max_concurrent * prompt_repeat_k` (see the sketch after this list).
   - Set it slightly higher than the actual processing capacity of the inference engine so that the inference queue always has tasks.

4. **RAY_MAX_CONCURRENCY**
   - The maximum number of concurrent tasks of the Ray backend, configured via an environment variable; the default is 1024.

5. **httpx max connections**
   - Controls the maximum number of concurrent connections that the HTTP client (such as RolloutWorker) opens to the inference service.
   - Keep it equal to or slightly higher than `rollout_max_batch_size_per_instance`.
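To make items 3 and 4 concrete, a small sketch with made-up numbers of how many samples are in flight for a given `max_concurrent`, and of the Ray cap being read from the environment with its documented default:

```python
import os

# Made-up Dataflow settings.
max_concurrent = 600
prompt_repeat_k = 16

# Item 3: Dataflow keeps max_concurrent groups in flight, each repeated prompt_repeat_k times.
samples_in_flight = max_concurrent * prompt_repeat_k
print(samples_in_flight)  # 9600

# Item 4: the Ray backend concurrency is taken from the environment, defaulting to 1024.
ray_max_concurrency = int(os.environ.get("RAY_MAX_CONCURRENCY", "1024"))
print(ray_max_concurrency)
```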
## Configuration Relationships and Recommendations

- **Recommended configuration process**
  1. Determine a reasonable `rollout_max_batch_size_per_instance` based on the model and hardware resources (e.g., 128, 256, 512, 1024). This parameter is optional; if the user does not provide it, Xtuner uses preset values based on `context_length`: 1024 for `context_length` ≤ 4K, 512 for ≤ 8K, and 128 otherwise.
  2. Set `DataflowConfig.max_concurrent`. The recommended value is `rollout_max_batch_size_per_instance * num_of_infer_instance / prompt_repeat_k * allow_over_concurrency_ratio`, where `num_of_infer_instance` is the number of inference engine instances started, usually the total number of GPUs (nodes × gpus_per_node) divided by `tensor_parallel_size`; a worked example follows this list.
  3. Set the `RAY_MAX_CONCURRENCY` environment variable equal to or slightly higher than `rollout_max_batch_size_per_instance * num_of_infer_instance`.
  4. The httpx max connections is configured by default as `rollout_max_batch_size_per_instance * allow_over_concurrency_ratio`.

- **Dynamic adjustment**: Monitor the inference queue length, GPU utilization, and response latency, and adjust the parameters above dynamically to find the optimal concurrency configuration.
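A worked illustration of step 2 for a multi-node setup; all the numbers are made up:

```python
# Made-up multi-node example for step 2.
num_nodes = 2
gpus_per_node = 8
tensor_parallel_size = 2
rollout_max_batch_size_per_instance = 512
prompt_repeat_k = 8
allow_over_concurrency_ratio = 1.2

# Total GPUs divided by the tensor parallel size gives the number of engine instances.
num_of_infer_instance = num_nodes * gpus_per_node // tensor_parallel_size  # 8
max_concurrent = int(
    rollout_max_batch_size_per_instance * num_of_infer_instance
    / prompt_repeat_k * allow_over_concurrency_ratio
)
print(num_of_infer_instance, max_concurrent)  # 8 614
```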
## Example Configuration

```python
resource = AcceleratorResourcesConfig(
    num_workers=8
)
rollout = RolloutConfig(
    rollout_max_batch_size_per_instance=1024,
    tensor_parallel_size=1,
    ...
)

dataflow = DataflowConfig(
    max_concurrent=600,  # ~ 1024 * (8 / 1) / 16 * 1.2 ≈ 614
    prompt_repeat_k=16,
    ...
)

# Environment variable setting
export RAY_MAX_CONCURRENCY=1024
```

examples/v1/config/rl_interns1_mini_grpo.py

Lines changed: 3 additions & 1 deletion
@@ -74,9 +74,11 @@
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.75,
+   context_length=max_prompt_length+max_response_length,
    extra_rollout_config={
        "sglang_grammar_backend": 'none',
    }
+   # rollout_max_batch_size_per_instance=16, # optional
)

# sampling params
@@ -135,7 +137,7 @@
    prompt_repeat_k=prompt_repeat_k,
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
-   max_concurrent=64,
+   # max_concurrent=64, # optional
)

evaluator_cfg = EvaluatorConfig(

examples/v1/config/rl_qwen25_7B_dapo.py

Lines changed: 2 additions & 4 deletions
@@ -28,8 +28,6 @@
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_evaluate = True if eval_data_path != "" else False
enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0"))
- # TODO(@duanyanhui): change xtuner_max_concurrency to single rollout engine max concurrency
- max_concurrent = int(os.environ.get("XTUNER_MAX_CONCURRENCY", 512))

# basic settings
experimental_name = "dapo_math"
@@ -64,8 +62,8 @@
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    context_length = max_response_length + max_prompt_length,
-   rollout_max_batch_size=max_concurrent,
    prompt_repeat_k=prompt_repeat_k,
+   # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set
)

# sampling params
@@ -111,7 +109,7 @@
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
    enable_partial_rollout=enbale_partial_rollout,
-   max_concurrent=max_concurrent
+   # max_concurrent=64, # optional, will be determined automatically if not set
)

examples/v1/config/rl_qwen3_8B_grpo.py

Lines changed: 7 additions & 3 deletions
@@ -28,7 +28,6 @@
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_evaluate = True if eval_data_path != "" else False
enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0"))
- max_concurrent = int(os.environ.get("XTUNER_MAX_CONCURRENCY", 512))

# basic settings
experimental_name = "grpo_gsm8k"
@@ -44,6 +43,11 @@
hf_interval = 15
enable_initial_evaluate = True
evaluate_step = 10
+ # TODO: provide recommended rollout_max_batch_size_per_instance values for different models and input/output lengths
+ # NOTE: Xtuner's dataflow concurrency is currently controlled by rollout_max_batch_size_per_instance, and allow_over_concurrency_ratio keeps the dataflow concurrency slightly above the inference engine concurrency;
+ # see how max_concurrent is computed in xtuner/v1/ray/dataflow/flow.py for the details.
+ # You can also adjust the max_concurrent parameter in dataflow_config manually to control the dataflow concurrency.
+ rollout_max_batch_size_per_instance = 128

# grpo quick test settings for rapid accuracy validation within ~30 minutes:
# - Initial eval accuracy: ~25%
@@ -79,8 +83,8 @@
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.75,
    context_length = max_response_length + max_prompt_length,
-   rollout_max_batch_size=max_concurrent,
    prompt_repeat_k=prompt_repeat_k,
+   # rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance, # optional, will be determined automatically if not set
)

# sampling params
@@ -114,7 +118,7 @@
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
    enable_partial_rollout=enbale_partial_rollout,
-   max_concurrent=max_concurrent
+   # max_concurrent=64, # optional, will be determined automatically if not set
)

evaluator_cfg = EvaluatorConfig(

examples/v1/config/rl_qwen3_8B_grpo_tiny.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.75,
+   context_length=max_prompt_length+max_response_length,
+   # rollout_max_batch_size_per_instance=1024, # optional
)

# sampling params

examples/v1/config/rl_qwen3_vl_8B_grpo.py

Lines changed: 3 additions & 0 deletions
@@ -74,6 +74,8 @@
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.75,
+   context_length = max_response_length + max_prompt_length,
+   # rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set
)

# sampling params
@@ -132,6 +134,7 @@
    prompt_repeat_k=prompt_repeat_k,
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
+   # max_concurrent=64, # optional, will be determined automatically if not set
)

evaluator_cfg = EvaluatorConfig(

tests/ray/test_evaluator.py

Lines changed: 3 additions & 4 deletions
@@ -35,15 +35,14 @@ def init_config(self):
        cpu_memory_per_worker=16 * 1024**3, # 16 GB
    )
    self.max_prompt_length = 512
+   self.max_response_length = 1024
    self.rollout_cfg = RolloutConfig(
        env="test_rollout",
        model_path=MODEL_PATH,
        model_name=os.path.basename(MODEL_PATH).lower(),
        tokenizer_path=MODEL_PATH,
        tensor_parallel_size=8,
-       extra_rollout_config={
-           "lmdeploy_log_level": "CRITICAL",
-       }
+       context_length=self.max_prompt_length + self.max_response_length,
    )
    from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
    gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
@@ -70,7 +69,7 @@ def init_config(self):
    self.sample_params = SampleParams(
        top_p=1.0,
        temperature=0.0,
-       max_tokens=1024,
+       max_tokens=self.max_response_length,
        top_k=1
    )

tests/ray/test_rollout.py

Lines changed: 2 additions & 2 deletions
@@ -58,6 +58,7 @@ def init_config(self):
        cpu_memory_per_worker=16 * 1024**3, # 16 GB
    )
    self.max_prompt_length = 512
+   self.max_response_length = 1024
    self.rollout_cfg = RolloutConfig(
        env="test_rollout",
        model_path=MODEL_PATH,
@@ -69,6 +70,7 @@ def init_config(self):
        gpus_per_node=8, # gpu: 8, npu: 16
        dtype="bfloat16",
        launch_server_method="ray",
+       context_length=self.max_prompt_length + self.max_response_length,
    )
    from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
    gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
@@ -77,7 +79,6 @@ def init_config(self):
    )
    self.dataflow_cfg = DataFlowConfig(
        env="test",
-       max_concurrent=32,
        prompt_repeat_k=2,
        global_batch_size=2,
        enable_partial_rollout=0,
@@ -92,7 +93,6 @@ def init_config(self):
        },
    ]
    self.dataloader_cfg = DataloaderConfig(
-       pack_max_length=self.max_prompt_length,
        collator='fake_collator',
        pack_level='none',
        group_by_length=False,

xtuner/v1/ray/config/worker.py

Lines changed: 25 additions & 20 deletions
@@ -34,23 +34,22 @@ class RolloutConfig(BaseModel):
        model_path (str | Path): Path to the inference model.
        model_name (str): Model name for the backend engine.
        tokenizer_path (str): Path to the model tokenizer. Defaults to "".
-       api_key (Optional[Union[List[str], str]]): API keys for rollout service.
-           Supports single key or list of keys. Defaults to None.
-
+       api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or list of keys. Defaults to None.
+       api_port (Optional[int]): Port number for the rollout API server. If not set, it will find an available port starting from 8000. Defaults to 8000.
        gpus_per_node (int): Number of GPUs per node. Defaults to 8.
        dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16".
        gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85.
        random_seed (int): Random seed for reproducible generation. Defaults to 1024.
-
        rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False.
+       rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it will be determined automatically based on `context_length`. Defaults to 512.
+       allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the rollout worker to improve GPU utilization. Defaults to 1.2.
        tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1.
        expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1.
-
        enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False.
        chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128.
        skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False.
        rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 3600.0.
-
+       context_length (int): Context length for the rollout worker.
        launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray".
        system_prompt (Optional[str]): System prompt to guide generation behavior. Defaults to None.
        extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes
@@ -114,20 +113,20 @@ class RolloutConfig(BaseModel):
            help="Whether to enable cross-node communication for the rollout worker.",
        ),
    ] = False
-   rollout_max_batch_size: Annotated[
+   rollout_max_batch_size_per_instance: Annotated[
        int,
        Parameter(
            group=infer_group,
            help="Maximum batch size for the rollout worker. If not set, it will be determined automatically based on the model and GPU memory.",
        ),
    ] = 512
-   prompt_repeat_k: Annotated[
-       int,
+   allow_over_concurrency_ratio: Annotated[
+       float,
        Parameter(
            group=infer_group,
-           help="Number of times to repeat the prompt for each request in the rollout worker.",
+           help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.",
        ),
-   ] = 8
+   ] = 1.2
    tensor_parallel_size: Annotated[
        int,
        Parameter(
@@ -247,15 +246,21 @@ def __init__(self, **kwargs):
        kwargs["launch_server_method"] = "ray"
        kwargs["rollout_cross_node_comm"] = True

-       # `rollout_max_batch_size` is the max batch size for each inference engine.
-       # In Xtuner, It is derived from `max_concurrent` in `DataflowConfig`. `max_concurrent` represents the concurrency level for group data batch.
-       # The total data received by all inference workers is `max_concurrent * prompt_repeat_k`.
-       # This is then divided by the number of inference engines (i.e., workers with TP_RANK=0) to determine the max batch size per engine.
-       kwargs["rollout_max_batch_size"] = (
-           kwargs.get("rollout_max_batch_size", 512)
-           * kwargs.get("prompt_repeat_k", 1)
-           / (int(os.environ.get("NODE_COUNT", 1)) * kwargs["gpus_per_node"] / kwargs.get("tensor_parallel_size", 1))
-       )
+       if "rollout_max_batch_size_per_instance" not in kwargs:
+           assert "context_length" in kwargs, (
+               "`context_length` must be provided to determine `rollout_max_batch_size_per_instance`."
+           )
+
+           context_length = kwargs["context_length"]
+
+           # TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths
+           if context_length <= 4096:
+               kwargs["rollout_max_batch_size_per_instance"] = 1024
+           elif context_length <= 8192:
+               kwargs["rollout_max_batch_size_per_instance"] = 512
+           else:
+               kwargs["rollout_max_batch_size_per_instance"] = 128
+
        super().__init__(**kwargs)
        self.worker_log_dir.mkdir(parents=True, exist_ok=True)