GRPO Environments for custom multi-step rollouts (vLLM-only) #2810
Conversation
trl/trainer/grpo_trainer.py
Outdated
```python
    enable_prefix_caching=True,
    max_model_len=self.args.vllm_max_model_len,
)
self.sampling_params = SamplingParams(
    temperature=args.temperature,
    max_tokens=self.max_completion_length,
    skip_special_tokens=False,
```
FYI - on vLLM we have another parameter `spaces_between_special_tokens` (set to `True` by default) that adds spaces between special tokens in the output when `skip_special_tokens=False`.
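For reference, a small sketch of how those two flags combine in vLLM sampling parameters (the other values are illustrative):

```python
from vllm import SamplingParams

# Keep special tokens in the decoded output, but don't insert extra spaces
# around them, so the decoded text round-trips more cleanly back to token ids.
sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=256,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)
```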
Gotcha, thanks for the heads up. Will do some testing and make sure everything looks OK at the token level; maybe we don't need that line.
Reverted that line. If users want different SamplingParams settings they can specify that at the Environment level anyway.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Would it be possible to wrap the generate method instead? It would allow keeping GRPO as is.
I'm not quite sure what you mean by wrapping the generate method here. Many parts of the codebase touch the LLM object, so wrapping it in another object would require changing the access in each place. Overriding generate entails a high amount of complexity on the user end, as most applications will want to use the true generate (or chat, which calls generate) method. Adding an if/else Environment route like in this PR is the simplest approach I could think of that allows users to directly use the LLM object as-is in their rollouts, while still allowing the Trainer to reference the LLM normally throughout. Note that this enables many requested features--tool use, sampling strategies, agentic interactions--to be encapsulated within Environments, avoiding further complexity down the road. If you have something specific in mind that you could illustrate with a short snippet, I'm happy to try.
Actually it can be pretty straightforward and simple:

```python
def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        ...  # stuff before
        result = generate_func(*args, **kwargs)
        ...  # stuff after
        return result
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)
```
Seems like the wrapper idea could be implemented internally. The user interface could remain the same with the environments idea. Also, I think there might be a bug in the example's `completion_ids` slicing:

```python
# original
completion_ids = [all_ids[states[i]["prompt_tokens"]:] for i, output in enumerate(outputs)]
# fixed
completion_ids = [all_ids[i][states[i]["prompt_tokens"]:] for i in range(len(prompts))]
```
trl/environment/env_protocol.py
Outdated
Docstrings to make implementation easier for end users:

```python
from typing import Any, Dict, List, Protocol


class Environment(Protocol):
    """
    A protocol describing the minimal interface needed for integration
    with the trainer. Your environment can run any multi-step logic,
    but must ultimately return token sequences akin to a typical
    vllm.LLM's generate() output. https://docs.vllm.ai/en/stable/api/offline_inference/llm.html
    """

    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: Any,
        sampling_params: Any,
    ) -> List[Any]:
        ...
```
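As a reference point, a minimal pass-through implementation satisfying this protocol might look like the sketch below (the class name is illustrative); it defers entirely to a single `llm.chat()` call and returns only the completion token ids:

```python
from typing import Any, Dict, List


class SingleTurnEnv:
    """Minimal Environment: one vLLM chat call, return completion token ids."""

    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: Any,
        sampling_params: Any,
    ) -> List[Any]:
        # The trainer passes chat-style message dicts; use vLLM's chat API so the
        # chat template is applied, then return only the completion token ids.
        outputs = llm.chat(prompts, sampling_params=sampling_params)
        return [list(o.outputs[0].token_ids) for o in outputs]
```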
docs/source/grpo_trainer.md
Outdated
Examples are the best way to learn how to use libraries... I took your example and added illustrative comments to make it easier to understand what is going on and how a user might implement their own.
```python
from typing import Any, Dict, List, Sequence, Tuple

from vllm import LLM, RequestOutput, SamplingParams


class DoubleCheckEnv:
    """
    Example Environment that:
      1) Sends an initial user prompt to the LLM.
      2) Appends the assistant's reply and a follow-up user query: "Are you sure?".
      3) Sends everything again to the LLM for a final response.
      4) Returns just the completion tokens for each prompt.
    """

    def step(
        self,
        states: List[Dict[str, Any]],
        llm: LLM,
        sampling_params: SamplingParams,
    ) -> Tuple[List[Dict[str, Any]], List[RequestOutput]]:
        # First LLM call for each state's messages
        outputs = llm.chat([s["messages"] for s in states], sampling_params=sampling_params)
        for i, state in enumerate(states):
            state["messages"].append({
                "role": "assistant",
                "content": outputs[i].outputs[0].text
            })
            state["messages"].append({
                "role": "user",
                "content": "Are you sure?"
            })
            # Track prompt_tokens to later slice out the completion part
            state["prompt_tokens"] = len(outputs[i].prompt_token_ids)

        # Second LLM call after "Are you sure?" is appended
        outputs = llm.chat([s["messages"] for s in states], sampling_params=sampling_params)
        for i, state in enumerate(states):
            state["messages"].append({
                "role": "assistant",
                "content": outputs[i].outputs[0].text
            })
            state["completed"] = True
        return states, outputs

    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: LLM,
        sampling_params: SamplingParams,
    ) -> List[Sequence[int]]:
        # Set up conversation states
        states = [{"messages": p, "completed": False, "prompt_tokens": -1} for p in prompts]
        outputs = [None] * len(prompts)

        # Keep stepping until each conversation is marked complete
        while not all(s["completed"] for s in states):
            states, outputs = self.step(states, llm, sampling_params)

        # Gather prompt+completion IDs, then slice out the prompt portion
        all_ids = [
            list(o.prompt_token_ids) + list(o.outputs[0].token_ids)
            for o in outputs
        ]
        completion_ids = [
            all_ids[i][states[i]["prompt_tokens"]:]
            for i in range(len(prompts))
        ]
        return completion_ids
```
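For quick local testing, a sketch of driving this environment directly against a vLLM engine outside the trainer (the model name and sampling values are placeholders):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")            # placeholder model
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

env = DoubleCheckEnv()
prompts = [[{"role": "user", "content": "What is 2 + 2?"}]]
completion_ids = env.generate(prompts, llm, sampling_params)
print(len(completion_ids), len(completion_ids[0]))        # one completion, its token count
```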
Thank you!! Will add that + docstring to the PR
@xiangjjj One complication is that many base model tokenizers have pad_token_id = eos_token_id, so when padding a batch, the "last EOS token" will be the last token in the pad sequence. Trying out a couple of workarounds.
Ah, I see! That is tricky. Thank you for this!
The simplest solution, I think, is to move the masking logic into the respective vLLM/transformers generate routes. vLLM now masks based on completion_ids length rather than the position of the first EOS token.
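A minimal sketch of length-based masking under that approach (the tensor and function names are illustrative, not the trainer's actual code):

```python
import torch

def completion_mask_from_lengths(completion_ids: list[list[int]]) -> torch.Tensor:
    """Build a right-padded (batch, max_len) mask with 1s over real completion tokens."""
    lengths = torch.tensor([len(ids) for ids in completion_ids])
    max_len = int(lengths.max())
    # Positions before each sequence's length are real tokens; everything after is padding.
    # This works even when pad_token_id == eos_token_id, since token values are never inspected.
    return (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).long()
```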
Sure, it makes sense and should resolve the masking issue! Thanks for fixing this!
```python
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
if self.env is not None:
    completion_ids = self.env.generate(
```
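A hedged sketch of the full if/else routing described in this thread (argument names beyond those visible in the diff are assumptions):

```python
if self.env is not None:
    # Hand the raw vLLM object to the Environment for custom multi-step rollouts.
    completion_ids = self.env.generate(
        all_prompts,            # assumed: the gathered message-dict prompts
        self.llm,
        self.sampling_params,
    )
else:
    # Default single-step path: plain vLLM generate on the formatted prompt text.
    outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
    completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
```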
I was wondering if we could enhance the output structure to include additional metadata alongside the current completion_ids. In multi-step rollouts, having step-wise details available directly would be really useful for determining the final rewards. Currently, parsing completion_ids to reconstruct this information feels a bit cumbersome. Would it be possible to return a dict that encapsulates both the tokens and the extra metadata?
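Purely to illustrate the kind of richer return structure being asked about (not part of this PR; all field names are hypothetical):

```python
# Instead of returning only token ids from Environment.generate(), each rollout
# could return a dict bundling the tokens with step-wise metadata for reward functions.
rollout_result = {
    "completion_ids": [101, 2023, 2003, 102],   # tokens, as today (placeholder values)
    "metadata": {
        "num_steps": 2,
        "step_texts": ["first reply", "final reply after 'Are you sure?'"],
        "tool_calls": [],                        # e.g. parsed tool invocations per step
    },
}
```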
I'll defer to @qgallouedec on that one; it could be added with pretty minimal changes (happy to do so), but there are also easy enough workarounds for what you're describing (computing rewards at generation time and caching them in a data structure accessed by your reward functions). My first priority with this PR is getting basic support for Environments enabled.
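A sketch of that caching workaround (the cache, key scheme, and reward function below are illustrative; only the TRL reward-function signature `(prompts, completions, **kwargs)` is assumed):

```python
# Inside your Environment, record per-rollout info at generation time...
ROLLOUT_CACHE: dict[str, dict] = {}

def record_rollout(prompt_text: str, info: dict) -> None:
    ROLLOUT_CACHE[prompt_text] = info

# ...then read it back inside a reward function, which receives the same prompts.
def reward_num_steps(prompts, completions, **kwargs):
    rewards = []
    for prompt in prompts:
        key = prompt if isinstance(prompt, str) else prompt[-1]["content"]
        info = ROLLOUT_CACHE.get(key, {})
        rewards.append(float(info.get("num_steps", 0)))
    return rewards
```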
Do you have suggestions on how to cache the information? I'm new to trl and have not done that before.
@qgallouedec @willccbb Regarding logits_to_keep, what do you think is the best strategy to filter out tool observation tokens? My concern is that when we invoke a web search tool—which might return thousands of tokens—those tokens could overwhelm the policy tokens during gradient updates. If this extraneous information turns out to be noisy, it might adversely affect the policy gradient learning. We can leave this out of the current PR, but I'd appreciate your thoughts on ideas/directions to resolve this, as I'm new to trl.
For now I'd suggest having deterministic processing of tool call results to avoid excessive outputs; that'll cause problems whether or not we mask tokens. In some experiments on multi-step code tool use, I found limiting allowed printouts to 500 chars per step worked reasonably well.
It's not clear to me that naive masking of tool outputs "makes sense" for GRPO algorithmically, and for now it is probably fine to treat tool call results as just part of the LLM output. Especially if you are letting your model "reason" for many tokens per step, it should not be a major problem I think.
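For instance, a tiny sketch of that kind of deterministic cap on tool printouts (the 500-character limit mirrors the experiment mentioned above):

```python
MAX_TOOL_OUTPUT_CHARS = 500

def truncate_tool_output(text: str, limit: int = MAX_TOOL_OUTPUT_CHARS) -> str:
    """Deterministically cap tool printouts before appending them to the conversation."""
    if len(text) <= limit:
        return text
    return text[:limit] + "\n[output truncated]"
```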
I think that applying policy gradient directly to tool call outputs is not theoretically sound—especially in outcome reward methods. Note that I'm not suggesting "naive masking of tool outputs." I propose that we do not compute the log-likelihood of the tool observation tokens together with their policy gradient and KL loss functions. While one might argue this is an empirical question depending on the setup of the system, I’m interested in learning about what kind of design in grpo would enable more flexible loss masking.
I’m fine with keeping this PR as is, but I’m interested in exploring a design that allows for more flexible loss masking in the future.
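For what it's worth, a sketch of the kind of flexible loss masking being discussed, assuming an Environment could also return a per-token mask over tool/observation tokens (not implemented in this PR; names are illustrative):

```python
import torch

# per_token_loss:  (batch, completion_len) policy-gradient + KL loss per token
# completion_mask: (batch, completion_len), 0 at padding positions
# tool_mask:       (batch, completion_len), 0 at tool-observation positions, 1 elsewhere
def masked_grpo_loss(per_token_loss: torch.Tensor,
                     completion_mask: torch.Tensor,
                     tool_mask: torch.Tensor) -> torch.Tensor:
    mask = completion_mask * tool_mask
    # Average only over model-generated, non-padding tokens.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1)
```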
Yeah, I see what you mean. I think this is a very interesting research question, but it's hard to say anything definitive right now. Including them as normal response tokens essentially forces the model to "model" the tool calls directly, which feels reasonable to me.
Papers exploring this issue are pretty recent.
If interest in multi-step RL continues to grow, I would imagine having more specialized trainers in the future could make sense. For now this is just a way to get something running with GRPO, whether or not it is the most principled approach.
Does that mean the current implementation is still calculating the KLD and policy gradient on environment responses? I thought we weren't doing that, based on my understanding of DoubleCheckEnv, but now I'm confused...
Also, I would love to help out with this, as I am working on it myself.
@willccbb I reached out via the email on your profile! Let me know if you want to collab on this, as I am continuing to work on it.
I still don't understand why wrapping would limit what you can do. For example, for the double call:

```python
def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        ...  # stuff before
        result1 = generate_func(*args, **kwargs)
        result2 = generate_func(*args, **kwargs)
        ...  # stuff after
        return result2
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)
```

Taking the env paradigm, I think it should work as is with the main branch with something like:

```python
env = MyEnv(...)

def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        prompts = args[0]
        # Pass the wrapped LLM object to the environment along with the remaining args.
        return env.generate(prompts, trainer.llm, *args[1:], **kwargs)
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)
```

I might be missing something though.
@qgallouedec I've been experimenting with combining chat and completion responses during training. The idea is to score each response based on its format and content. If a mistake is detected, another LLM—one that provides the correct answer—is consulted. This secondary LLM offers a brief hint to guide the correction, and then the response is re-generated. For example, this dataset snippet:
If a mistake is found (say, the tool call should be within the
Now I go back with the hint and try to get a completion:
My goal is to see if I can get auto-correction via hints scored on top of it. I made a really hacky, overfitted solution where I trained in epoch-style runs: creating the dataset, then going back for rounds 2, 3, 4 of GRPO training... slowly guiding it to the right answer. Now I think I need to work on an actual solution, which is how I ended up here. Happy to help/code things up and test. Thanks all!
@qgallouedec biggest problem is that It also is just much nicer to be able to have access to the
Any update? The PR is falling behind the GRPO changes on main.
Working on a refactor now which I think should allow directly using the main TRL branch. The approach is to extend
Working now on the
What does this PR do?
Adds a protocol under `trl/environments` for an `Environment` object which wraps vLLM's `.generate(...)` to allow for custom rollout logic, and an optional `env` field to the Trainer for passing such an object. A simple example usage is included below; others live in this repo: willccbb/verifiers.
Given the breadth of different agentic tasks people are interested in, I think an implementation of multi-step behavior should be as open-ended and customizable as possible, rather than having everything flow through explicit tool use or a predefined format. Here, the only requirement for an Environment is that it mirrors the behavior of calling `llm.generate()` + extracting `token_ids`. I have found that it's more practical to pass message dicts to Environments rather than preformatted text, hence the addition of `gather_object(prompts)`.
I agree with Quentin's comment here that while masking of tool outputs/environment responses/messages from other users/agents is maybe desirable in some cases, it is probably not necessary for initial experimentation. Future iterations could perhaps extend the `Environment` definition to allow for masked outputs.
Usage looks like:
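A hedged sketch of what that usage could look like (the `env` argument is the new piece from this PR; the dataset, model, and reward function below are placeholders):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")        # placeholder dataset

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions.
    return [-float(len(c)) for c in completions]

training_args = GRPOConfig(output_dir="grpo-env-demo", use_vllm=True)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",                       # placeholder model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    env=DoubleCheckEnv(),                                     # the Environment object from this PR
)
trainer.train()
```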
Example implementation of such an environment: see the `DoubleCheckEnv` example quoted in the review discussion above.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.