58 changes: 58 additions & 0 deletions docs/en/rl/advanced_tutorial/efficiency.md
@@ -0,0 +1,58 @@
# How to Configure Xtuner Concurrency to Improve Rollout Efficiency

During the Rollout phase in Xtuner, properly configuring the concurrency-related parameters keeps the inference engine under high load, makes full use of hardware resources, and improves overall inference efficiency. This document introduces the main concurrency-related configuration options in Xtuner, explains how they relate to one another, and provides best-practice recommendations.

## Main Concurrency-Related Configuration Parameters

1. **RolloutConfig.rollout_max_batch_size_per_instance**
- Controls the maximum batch size that a single inference instance (such as a model process) can handle at one time.
- Larger batch sizes can improve GPU utilization, but excessively large values may cause out-of-memory errors or increased latency.
- It is recommended to adjust this parameter based on the model's context_length and actual GPU memory.
- Xtuner will provide recommended configurations for common models and context lengths in future releases.

2. **RolloutConfig.allow_over_concurrency_ratio**
- Controls the over-concurrency ratio for HTTP requests to ensure the inference engine is fully loaded.

3. **DataflowConfig.max_concurrent**
- Controls the maximum number of concurrent tasks in Dataflow. Dataflow acts as a single controller, distributing data to all rollout workers.
- Dataflow sends a batch of data each time; the actual number of data items sent at the same time is `max_concurrent * prompt_repeat_k`.
- It is recommended to set this slightly higher than the actual processing capability of the inference engine to ensure the inference queue always has tasks.

4. **RAY_MAX_CONCURRENCY**
- The maximum concurrency for the Ray backend, configured via environment variable. The default is 1024.

5. **httpx max connections**
- Controls the maximum number of concurrent connections that the HTTP client (such as RolloutWorker) can initiate to the inference service.
   - It is recommended to set this equal to or slightly higher than `rollout_max_batch_size_per_instance` (see the illustrative sketch after this list).
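
Xtuner's RolloutWorker configures its HTTP client internally, so this value usually does not need to be set by hand. As a rough illustration of what the setting corresponds to, the sketch below builds an `httpx` client with an equivalent connection limit; the client construction and the concrete numbers are assumptions for illustration, not Xtuner's actual internals.

```python
import httpx

# Illustrative sketch only: Xtuner's RolloutWorker creates its own HTTP client internally.
# This shows the httpx-level knob that "httpx max connections" corresponds to, using the
# values from the example configuration at the end of this document.
rollout_max_batch_size_per_instance = 1024
allow_over_concurrency_ratio = 1.2  # RolloutConfig default

limits = httpx.Limits(
    max_connections=int(rollout_max_batch_size_per_instance * allow_over_concurrency_ratio)
)
client = httpx.AsyncClient(limits=limits, timeout=httpx.Timeout(3600.0))  # 3600 s matches the default rollout_timeout
```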

## Configuration Relationships and Recommendations

- **Recommended Configuration Process**:
  1. Determine a reasonable value for `rollout_max_batch_size_per_instance` based on the model and hardware resources (e.g., 128, 256, 512, 1024). This parameter is optional; if not provided, Xtuner uses preset values based on `context_length`: 1024 for `context_length` ≤ 4K, 512 for ≤ 8K, and 128 otherwise.
  2. Set `DataflowConfig.max_concurrent`. The recommended value is `rollout_max_batch_size_per_instance * num_of_infer_instance / prompt_repeat_k * allow_over_concurrency_ratio`, where `num_of_infer_instance` is the number of inference engine instances started (usually the total number of GPUs / `tensor_parallel_size`); see the sketch after this list.
  3. Set the `RAY_MAX_CONCURRENCY` environment variable. It is recommended to set this equal to or slightly higher than `rollout_max_batch_size_per_instance * num_of_infer_instance`.
  4. The httpx maximum connection count defaults to `rollout_max_batch_size_per_instance * allow_over_concurrency_ratio`, so it normally does not need to be set manually.

- **Dynamic Adjustment**: You can dynamically adjust these parameters by monitoring the inference queue length, GPU utilization, and response latency to find the optimal concurrency configuration.
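
To make step 2 concrete, the following minimal sketch reproduces the recommended `max_concurrent` arithmetic. The helper function and its parameter names are illustrative only and not part of Xtuner's API; the built-in derivation lives in `xtuner/v1/ray/dataflow/flow.py`.

```python
# Minimal sketch of the step-2 formula; the helper name and its arguments are
# illustrative, not part of Xtuner's API.
def recommended_max_concurrent(
    rollout_max_batch_size_per_instance: int,   # per-engine max batch size (step 1)
    num_gpus: int,                              # total number of GPUs used for rollout
    tensor_parallel_size: int,                  # GPUs per inference engine
    prompt_repeat_k: int,                       # responses sampled per prompt
    allow_over_concurrency_ratio: float = 1.2,  # over-subscription factor
) -> int:
    # Number of inference engine instances: total GPUs / tensor_parallel_size.
    num_of_infer_instance = num_gpus // tensor_parallel_size
    # Each Dataflow task fans out into prompt_repeat_k rollout requests, so the total
    # per-engine capacity is divided by prompt_repeat_k and then slightly over-subscribed.
    return int(
        rollout_max_batch_size_per_instance
        * num_of_infer_instance
        / prompt_repeat_k
        * allow_over_concurrency_ratio
    )

# Matches the example configuration below: 8 GPUs, TP=1, prompt_repeat_k=16.
print(recommended_max_concurrent(1024, num_gpus=8, tensor_parallel_size=1, prompt_repeat_k=16))  # 614
```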

## Example Configuration

```python
resource = AcceleratorResourcesConfig(
    num_workers=8
)
rollout = RolloutConfig(
    rollout_max_batch_size_per_instance=1024,
    tensor_parallel_size=1,
    ...
)

dataflow = DataflowConfig(
    max_concurrent=614,  # int(1024 * (8 / 1) / 16 * 1.2)
    prompt_repeat_k=16,
    ...
)
```

```bash
# Environment variable setting
export RAY_MAX_CONCURRENCY=1024
```
58 changes: 58 additions & 0 deletions docs/zh_cn/rl/advanced_tutorial/efficiency.md
@@ -0,0 +1,58 @@
# How to Configure Xtuner Concurrency to Improve Rollout Efficiency

During the Rollout phase in Xtuner, properly configuring the concurrency-related parameters keeps the inference engine under high load, makes full use of hardware resources, and improves overall inference efficiency. This document introduces the main concurrency-related configuration options in Xtuner, explains how they relate to one another, and provides best-practice recommendations.

## Main Concurrency-Related Configuration Parameters

1. **RolloutConfig.rollout_max_batch_size_per_instance**
   - Controls the maximum batch size that a single inference instance (such as a single model process) can handle at one time.
   - Larger batch sizes can improve GPU utilization, but excessively large values may cause out-of-memory errors or increased latency.
   - It is recommended to adjust this parameter based on the model's context_length and the actual GPU memory.
   - Xtuner will provide recommended configurations for common models and context lengths in future releases.

2. **RolloutConfig.allow_over_concurrency_ratio**
   - Controls the over-concurrency ratio for HTTP requests, used to keep the inference engine fully loaded.

3. **DataflowConfig.max_concurrent**
   - Controls the maximum number of concurrent tasks in Dataflow. Dataflow acts as a single controller responsible for distributing data to all rollout workers.
   - Dataflow sends a group of data each time; the actual number of data items sent at the same time is `max_concurrent * prompt_repeat_k`.
   - It is recommended to set this slightly higher than the actual processing capability of the inference engine, so that the inference queue always has tasks.

4. **RAY_MAX_CONCURRENCY**
   - The maximum concurrency of the Ray backend, configured via environment variable. The default is 1024.

5. **httpx max connections**
   - Controls the maximum number of concurrent connections that the HTTP client (such as RolloutWorker) can initiate to the inference service.
   - It is recommended to keep this equal to or slightly higher than `rollout_max_batch_size_per_instance`.

## Configuration Relationships and Recommendations

- **Recommended Configuration Process**:
  1. Determine a reasonable value for `rollout_max_batch_size_per_instance` based on the model and hardware resources (e.g., 128, 256, 512, 1024). This parameter is optional; if not provided, Xtuner uses preset values based on `context_length`: 1024 for `context_length` ≤ 4K, 512 for ≤ 8K, and 128 otherwise.
  2. Set `DataflowConfig.max_concurrent`. The recommended value is `rollout_max_batch_size_per_instance * num_of_infer_instance / prompt_repeat_k * allow_over_concurrency_ratio`, where `num_of_infer_instance` is the number of inference engine instances started, usually the total number of GPUs / `tensor_parallel_size`.
  3. Set the `RAY_MAX_CONCURRENCY` environment variable. It is recommended to keep this equal to or slightly higher than `rollout_max_batch_size_per_instance * num_of_infer_instance`.
  4. The httpx maximum connection count defaults to `rollout_max_batch_size_per_instance * allow_over_concurrency_ratio`.

- **Dynamic Adjustment**: You can dynamically adjust the parameters above by monitoring the inference queue length, GPU utilization, and response latency to find the optimal concurrency configuration.

## Example Configuration

```python
resource = AcceleratorResourcesConfig(
    num_workers=8
)
rollout = RolloutConfig(
    rollout_max_batch_size_per_instance=1024,
    tensor_parallel_size=1,
    ...
)

dataflow = DataflowConfig(
    max_concurrent=614,  # int(1024 * (8 / 1) / 16 * 1.2)
    prompt_repeat_k=16,
    ...
)
```

```bash
# Environment variable setting
export RAY_MAX_CONCURRENCY=1024
```
4 changes: 3 additions & 1 deletion examples/v1/config/rl_interns1_mini_grpo.py
@@ -74,9 +74,11 @@
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.75,
context_length=max_prompt_length+max_response_length,
extra_rollout_config={
"sglang_grammar_backend": 'none',
}
# rollout_max_batch_size_per_instance=16, # optional
)

# sampling params
@@ -135,7 +137,7 @@
prompt_repeat_k=prompt_repeat_k,
global_batch_size=global_batch_size,
sample_params=training_sample_params,
max_concurrent=64,
# max_concurrent=64, # optional
)

evaluator_cfg = EvaluatorConfig(
6 changes: 2 additions & 4 deletions examples/v1/config/rl_qwen25_7B_dapo.py
@@ -28,8 +28,6 @@
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_evaluate = True if eval_data_path != "" else False
enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0"))
# TODO(@duanyanhui): change xtuner_max_concurrency to single rollout engine max concurrency
max_concurrent = int(os.environ.get("XTUNER_MAX_CONCURRENCY", 512))

# basic settings
experimental_name = "dapo_math"
@@ -64,8 +62,8 @@
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.8,
context_length = max_response_length + max_prompt_length,
rollout_max_batch_size=max_concurrent,
prompt_repeat_k=prompt_repeat_k,
# rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set
)

# sampling params
@@ -111,7 +109,7 @@
global_batch_size=global_batch_size,
sample_params=training_sample_params,
enable_partial_rollout=enbale_partial_rollout,
max_concurrent=max_concurrent
# max_concurrent=64, # optional, will be determined automatically if not set
)


10 changes: 7 additions & 3 deletions examples/v1/config/rl_qwen3_8B_grpo.py
@@ -28,7 +28,6 @@
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_evaluate = True if eval_data_path != "" else False
enbale_partial_rollout = int(os.environ.get("ENBALE_PARTIAL_ROLLOUT", "0"))
max_concurrent = int(os.environ.get("XTUNER_MAX_CONCURRENCY", 512))

# basic settings
experimental_name = "grpo_gsm8k"
@@ -44,6 +43,11 @@
hf_interval = 15
enable_initial_evaluate = True
evaluate_step = 10
# TODO: Provide recommended rollout_max_batch_size_per_instance values for different models and input/output lengths
# NOTE: Currently, Xtuner's dataflow concurrency is driven by rollout_max_batch_size_per_instance, and allow_over_concurrency_ratio keeps the dataflow concurrency slightly above the inference engine's concurrency.
# See how max_concurrent is computed in xtuner/v1/ray/dataflow/flow.py for the exact logic.
# You can also tune the max_concurrent parameter in dataflow_config manually to control the dataflow concurrency.
rollout_max_batch_size_per_instance = 128

# grpo quick test settings for rapid accuracy validation within ~30 minutes:
# - Initial eval accuracy: ~25%
@@ -79,8 +83,8 @@
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.75,
context_length = max_response_length + max_prompt_length,
rollout_max_batch_size=max_concurrent,
prompt_repeat_k=prompt_repeat_k,
# rollout_max_batch_size_per_instance=rollout_max_batch_size_per_instance, # optional, will be determined automatically if not set
)

# sampling params
@@ -114,7 +118,7 @@
global_batch_size=global_batch_size,
sample_params=training_sample_params,
enable_partial_rollout=enbale_partial_rollout,
max_concurrent=max_concurrent
# max_concurrent=64, # optional, will be determined automatically if not set
)

evaluator_cfg = EvaluatorConfig(
2 changes: 2 additions & 0 deletions examples/v1/config/rl_qwen3_8B_grpo_tiny.py
@@ -55,6 +55,8 @@
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.75,
context_length=max_prompt_length+max_response_length,
# rollout_max_batch_size_per_instance=1024, # optional
)

# sampling params
3 changes: 3 additions & 0 deletions examples/v1/config/rl_qwen3_vl_8B_grpo.py
@@ -74,6 +74,8 @@
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.75,
context_length = max_response_length + max_prompt_length,
# rollout_max_batch_size_per_instance=64, # optional, will be determined automatically if not set
)

# sampling params
@@ -132,6 +134,7 @@
prompt_repeat_k=prompt_repeat_k,
global_batch_size=global_batch_size,
sample_params=training_sample_params,
# max_concurrent=64, # optional, will be determined automatically if not set
)

evaluator_cfg = EvaluatorConfig(
7 changes: 3 additions & 4 deletions tests/ray/test_evaluator.py
@@ -35,15 +35,14 @@ def init_config(self):
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)
self.max_prompt_length = 512
self.max_response_length = 1024
self.rollout_cfg = RolloutConfig(
env="test_rollout",
model_path=MODEL_PATH,
model_name=os.path.basename(MODEL_PATH).lower(),
tokenizer_path=MODEL_PATH,
tensor_parallel_size=8,
extra_rollout_config={
"lmdeploy_log_level": "CRITICAL",
}
context_length=self.max_prompt_length + self.max_response_length,
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
@@ -70,7 +69,7 @@ def init_config(self):
self.sample_params = SampleParams(
top_p=1.0,
temperature=0.0,
max_tokens=1024,
max_tokens=self.max_response_length,
top_k=1
)

4 changes: 2 additions & 2 deletions tests/ray/test_rollout.py
@@ -58,6 +58,7 @@ def init_config(self):
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)
self.max_prompt_length = 512
self.max_response_length = 1024
self.rollout_cfg = RolloutConfig(
env="test_rollout",
model_path=MODEL_PATH,
@@ -69,6 +70,7 @@ def init_config(self):
gpus_per_node=8, # gpu: 8, npu: 16
dtype="bfloat16",
launch_server_method="ray",
context_length=self.max_prompt_length + self.max_response_length,
)
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
@@ -77,7 +79,6 @@
)
self.dataflow_cfg = DataFlowConfig(
env="test",
max_concurrent=32,
prompt_repeat_k=2,
global_batch_size=2,
enable_partial_rollout=0,
@@ -92,7 +93,6 @@
},
]
self.dataloader_cfg = DataloaderConfig(
pack_max_length=self.max_prompt_length,
collator='fake_collator',
pack_level='none',
group_by_length=False,
45 changes: 25 additions & 20 deletions xtuner/v1/ray/config/worker.py
@@ -33,23 +33,22 @@ class RolloutConfig(BaseModel):
model_path (str | Path): Path to the inference model.
model_name (str): Model name for the backend engine.
tokenizer_path (str): Path to the model tokenizer. Defaults to "".
api_key (Optional[Union[List[str], str]]): API keys for rollout service.
Supports single key or list of keys. Defaults to None.

api_key (Optional[Union[List[str], str]]): API keys for rollout service. Supports single key or list of keys. Defaults to None.
api_port (Optional[int]): Port number for the rollout API server. If not set, it will find an available port starting from 8000. Defaults to 8000.
gpus_per_node (int): Number of GPUs per node. Defaults to 8.
dtype (str): Model data type ('bfloat16', 'float16', 'int8'). Defaults to "bfloat16".
gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85.
random_seed (int): Random seed for reproducible generation. Defaults to 1024.

rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False.
rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it will be determined automatically based on `context_length`. Defaults to 512.
allow_over_concurrency_ratio (float): Factor to allow over-concurrency in HTTP requests for the rollout worker to improve GPU utilization. Defaults to 1.2.
tensor_parallel_size (int): GPUs per inference engine (tensor parallelism). Defaults to 1.
expert_parallel_size (int): Experts per inference engine (expert parallelism). Defaults to 1.

enable_chunked_prefill (bool): Enable chunked prefill for memory efficiency. Defaults to False.
chunked_prefill_size (int): Chunk size for prefill operations. Defaults to 128.
skip_load_weights (bool): Skip weight loading for rollout worker. Defaults to False.
rollout_timeout (float): Timeout duration in seconds for rollout requests. Defaults to 3600.0.

context_length (int): Context length for the rollout worker.
launch_server_method (Literal["ray", "multiprocessing"]): Server launch method. Defaults to "ray".
system_prompt (Optional[str]): System prompt to guide generation behavior. Defaults to None.
extra_rollout_config (Optional[dict]): Backend-specific configurations using engine prefixes
@@ -111,20 +110,20 @@ class RolloutConfig(BaseModel):
help="Whether to enable cross-node communication for the rollout worker.",
),
] = False
rollout_max_batch_size: Annotated[
rollout_max_batch_size_per_instance: Annotated[
int,
Parameter(
group=infer_group,
help="Maximum batch size for the rollout worker. If not set, it will be determined automatically based on the model and GPU memory.",
),
] = 512
prompt_repeat_k: Annotated[
int,
allow_over_concurrency_ratio: Annotated[
float,
Parameter(
group=infer_group,
help="Number of times to repeat the prompt for each request in the rollout worker.",
help="Factor to allow over concurrency in the http request for rollout worker to improve GPU utilization.",
),
] = 8
] = 1.2
tensor_parallel_size: Annotated[
int,
Parameter(
@@ -244,15 +243,21 @@ def __init__(self, **kwargs):
kwargs["launch_server_method"] = "ray"
kwargs["rollout_cross_node_comm"] = True

# `rollout_max_batch_size` is the max batch size for each inference engine.
# In Xtuner, It is derived from `max_concurrent` in `DataflowConfig`. `max_concurrent` represents the concurrency level for group data batch.
# The total data received by all inference workers is `max_concurrent * prompt_repeat_k`.
# This is then divided by the number of inference engines (i.e., workers with TP_RANK=0) to determine the max batch size per engine.
kwargs["rollout_max_batch_size"] = (
kwargs.get("rollout_max_batch_size", 512)
* kwargs.get("prompt_repeat_k", 1)
/ (int(os.environ.get("NODE_COUNT", 1)) * kwargs["gpus_per_node"] / kwargs.get("tensor_parallel_size", 1))
)
if "rollout_max_batch_size_per_instance" not in kwargs:
assert "context_length" in kwargs, (
"`context_length` must be provided to determine `rollout_max_batch_size_per_instance`."
)

context_length = kwargs["context_length"]

# TODO(@duanyanhui): Provide better suggestions for different models/input-output lengths
if context_length <= 4096:
kwargs["rollout_max_batch_size_per_instance"] = 1024
elif context_length <= 8192:
kwargs["rollout_max_batch_size_per_instance"] = 512
else:
kwargs["rollout_max_batch_size_per_instance"] = 128

super().__init__(**kwargs)
self.worker_log_dir.mkdir(parents=True, exist_ok=True)
