GRPO Environments for custom multi-step rollouts (vLLM-only) #2810

Closed
willccbb wants to merge 33 commits

Conversation


@willccbb willccbb commented Feb 9, 2025

What does this PR do?

Adds a protocol under trl/environments for an Environment object which wraps vLLM's .generate(...) to allow for custom rollout logic, and an optional env field to the Trainer for passing such an object.

A simple example usage is included below, others in this repo: willccbb/verifiers

Given the breadth of different agentic tasks people are interested in, I think an implementation of multi-step behavior should be as open-ended and customizable as possible, rather than having everything flow through explicit tool use or a predefined format. Here, the only requirement for an Environment is that it mirrors the behavior of calling llm.generate() + extracting token_ids. I have found that it's more practical to pass message dicts to Environments rather than preformatted text, hence the addition of gather_object(prompts).

I agree with Quentin's comment here that while masking of tool outputs/environment responses/messages from other users/agents may be desirable in some cases, it is probably not necessary for initial experimentation. Future iterations could perhaps extend the Environment definition to allow masking such outputs.
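
As a purely hypothetical sketch of what that extension might look like (not part of this PR), generate() could return a per-token mask alongside the completion ids:

from typing import Any, Dict, List, Protocol, Sequence, Tuple

class MaskedEnvironment(Protocol):
    # Hypothetical extension: mask entries of 0 mark tokens (e.g. tool or
    # environment outputs) that the trainer should exclude from the loss.
    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: Any,
        sampling_params: Any,
    ) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
        ...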

Usage looks like:

doublecheck_env = DoubleCheckEnv() # assistant message -> hardcoded user "Are you sure?" message -> second assistant message
...
trainer = GRPOTrainer(
    model=model_name,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    env=doublecheck_env, # Optional, defaults to `None`
    args=training_args,
    train_dataset=dataset
)
trainer.train()

Example implementation of such an environment:

from typing import List, Callable, Dict, Any, Sequence, Tuple
from vllm import LLM, SamplingParams, RequestOutput

class DoubleCheckEnv:

    def step(self,
             states: List[Dict[str, Any]],
             llm: LLM,
             sampling_params: SamplingParams) -> Tuple[List[Dict[str, Any]], List[RequestOutput]]:
        
        outputs = llm.chat([state["messages"] for state in states], sampling_params=sampling_params) # type: ignore
        for i, state in enumerate(states):
            state["messages"].append({'role': 'assistant', 'content': outputs[i].outputs[0].text})
            state["messages"].append({'role': 'user', 'content': 'Are you sure?'})
            state["prompt_tokens"] = len(outputs[i].prompt_token_ids)

        outputs = llm.chat([state["messages"] for state in states], sampling_params=sampling_params) # type: ignore

        for i, state in enumerate(states):
            state["messages"].append({'role': 'assistant', 'content': outputs[i].outputs[0].text})
            state["completed"] = True
        return states, outputs

    def generate(self, prompts: List[List[Dict[str, Any]]], llm: LLM, sampling_params: SamplingParams) -> List[Sequence[int]]:
        all_completed = False
        states = [{"messages": m, "completed": False, "prompt_tokens": -1} for m in prompts]
        outputs = [None] * len(prompts)
        while not all_completed:
            states, outputs = self.step(states, llm, sampling_params)
            all_completed = all(state["completed"] for state in states)
        all_ids = [list(output.prompt_token_ids) + list(output.outputs[0].token_ids) for output in outputs]
        completion_ids = [all_ids[i][states[i]["prompt_tokens"]:] for i in range(len(outputs))]
        return completion_ids

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

    enable_prefix_caching=True,
    max_model_len=self.args.vllm_max_model_len,
)
self.sampling_params = SamplingParams(
    temperature=args.temperature,
    max_tokens=self.max_completion_length,
    skip_special_tokens=False,

FYI - on vLLM we have another parameter, spaces_between_special_tokens (set to True by default), that adds spaces between special tokens in the output when skip_special_tokens=False.
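
For reference, a minimal sketch of how those two flags combine on the vLLM side (values here are arbitrary, not the trainer's defaults):

from vllm import SamplingParams

# When skip_special_tokens=False, vLLM also honors
# spaces_between_special_tokens (True by default), which inserts spaces
# between special tokens in the decoded text; set it to False if those
# extra spaces are unwanted.
sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=512,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)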

@willccbb (Author)

Gotcha, thanks for the heads up. Will do some testing and make sure everything looks OK at the token level; maybe we don't need that line.

@willccbb (Author)

Reverted that line. If users want different SamplingParams settings they can specify that at the Environment level anyway.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) left a comment

Would it be possible to wrap the generate method instead? It would allow keeping GRPO as is.

@willccbb (Author)

I'm not quite sure what you mean by wrapping the generate method here. Many parts of the codebase touch the LLM object, and wrapping it in another object would require changing the access in each place. Overriding generate entails a high amount of complexity on the user end, as most applications will want to use the true generate (or chat, which calls generate) method. Adding an if/else Environment route like in this PR is the simplest approach I could think of that allows users to directly use the LLM object as-is in their rollouts, while still allowing the Trainer to reference the LLM normally throughout. Note that this enables many requested features (tool use, sampling strategies, agentic interactions) to be encapsulated within Environments, avoiding further complexity down the road.

If you have something specific in mind that you could illustrate with a short snippet I'm happy to try.

@qgallouedec (Member)

Actually it can be pretty straightforward and simple:

def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        ...  # stuff before
        result = generate_func(*args, **kwargs)
        ...  # stuff after
        return result
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)

@accupham

Seems like the wrapper idea could be implemented internally. The user interface could remain the same with the environments idea.

Also, I think there might be a bug in the DoubleCheckEnv.generate example. I think this is the correct implementation:

# original
completion_ids = [all_ids[states[i]["prompt_tokens"]:] for i, output in enumerate(outputs)]
# fixed
completion_ids = [all_ids[i][states[i]["prompt_tokens"]:] for i in range(len(prompts))]


Docstrings to make implementation easier for end users:

from typing import Any, Dict, List, Protocol

class Environment(Protocol):
    """
    A protocol describing the minimal interface needed for integration 
    with the trainer. Your environment can run any multi-step logic, 
    but must ultimately return token sequences akin to a typical 
    vllm.LLM's generate() output. https://docs.vllm.ai/en/stable/api/offline_inference/llm.html
    """
    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: Any,
        sampling_params: Any
    ) -> List[Any]:
        ...


Examples are the best way to learn how to use libraries... I took your example and added illustrative comments to make it easier to understand what is going on and how a user might implement their own.

from typing import Any, Dict, List, Sequence, Tuple
from vllm import LLM, RequestOutput, SamplingParams

class DoubleCheckEnv:
    """
    Example Environment that:
      1) Sends an initial user prompt to the LLM.
      2) Appends the assistant's reply and a follow-up user query: "Are you sure?".
      3) Sends everything again to the LLM for a final response.
      4) Returns just the completion tokens for each prompt.
    """

    def step(
        self,
        states: List[Dict[str, Any]],
        llm: LLM,
        sampling_params: SamplingParams
    ) -> Tuple[List[Dict[str, Any]], List[RequestOutput]]:
        # First LLM call for each state's messages
        outputs = llm.chat([s["messages"] for s in states], sampling_params=sampling_params)
        for i, state in enumerate(states):
            state["messages"].append({
                "role": "assistant", 
                "content": outputs[i].outputs[0].text
            })
            state["messages"].append({
                "role": "user", 
                "content": "Are you sure?"
            })
            # Track prompt_tokens to later slice out the completion part
            state["prompt_tokens"] = len(outputs[i].prompt_token_ids)

        # Second LLM call after "Are you sure?" is appended
        outputs = llm.chat([s["messages"] for s in states], sampling_params=sampling_params)
        for i, state in enumerate(states):
            state["messages"].append({
                "role": "assistant", 
                "content": outputs[i].outputs[0].text
            })
            state["completed"] = True

        return states, outputs

    def generate(
        self,
        prompts: List[List[Dict[str, Any]]],
        llm: LLM,
        sampling_params: SamplingParams
    ) -> List[Sequence[int]]:
        # Setup conversation states
        states = [{"messages": p, "completed": False, "prompt_tokens": -1} for p in prompts]
        outputs = [None] * len(prompts)

        # Keep stepping until each conversation is marked complete
        while not all(s["completed"] for s in states):
            states, outputs = self.step(states, llm, sampling_params)

        # Gather prompt+completion IDs, then slice out the prompt portion
        all_ids = [
            list(o.prompt_token_ids) + list(o.outputs[0].token_ids) 
            for o in outputs
        ]
        completion_ids = [
            all_ids[i][states[i]["prompt_tokens"]:] 
            for i in range(len(prompts))
        ]
        return completion_ids

@willccbb (Author)

Thank you!! Will add that + docstring to the PR

@willccbb (Author)

@xiangjjj One complication is that many base model tokenizers have pad_token_id = eos_token_id, so when padding a batch, the "last EOS token" will be the last token in the pad sequence. Trying out a couple workarounds.

@xiangjjj

Ah, I see! That is tricky. Thanks for this!

@willccbb (Author)

Simplest solution I think is to move the masking logic into the respective vllm/transformers generate routes. vLLM now masks based on completion_ids length rather than the position of the first EOS token.
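
As a rough illustration of the length-based approach (a sketch only, not the PR's actual code; names and values are assumptions):

import torch

# Build the completion mask from each completion's true (pre-padding) length
# instead of searching for the first EOS token, which breaks when
# pad_token_id == eos_token_id.
completion_ids = [[512, 88, 2], [90, 41, 77, 13, 2]]  # per-prompt token ids from the env
pad_token_id = 2  # same id as EOS for many base-model tokenizers

lengths = [len(ids) for ids in completion_ids]
max_len = max(lengths)
padded_ids = torch.full((len(completion_ids), max_len), pad_token_id, dtype=torch.long)
completion_mask = torch.zeros((len(completion_ids), max_len), dtype=torch.long)
for i, ids in enumerate(completion_ids):
    padded_ids[i, : len(ids)] = torch.tensor(ids)
    completion_mask[i, : len(ids)] = 1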

@xiangjjj

Simplest solution I think is to move the masking logic into the respective vllm/transformers generate routes. vLLM now masks based on completion_ids length rather than the position of the first EOS token.

Sure, it makes sense and should resolve the masking issue! Thanks for fixing this!

outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
if self.env is not None:
    completion_ids = self.env.generate(


I was wondering if we could enhance the output structure to include additional metadata alongside the current completion_ids. In multi-step rollouts, having step-wise details available directly would be really useful for determining the final rewards. Currently, parsing completion_ids to reconstruct this information feels a bit cumbersome. Would it be possible to return a dict that encapsulates both the tokens and the extra metadata?

@willccbb (Author)

I'll defer to @qgallouedec on that one, could be added with pretty minimal changes (happy to do so), but there are also easy enough workarounds for what you're describing (computing rewards at generation time and caching them in a data structure accessed by your reward functions). My first priority with this PR is getting basic support for Environments enabled.
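
For what it's worth, a minimal sketch of that caching workaround (all names here are made up for illustration; this is not code from the PR or from verifiers):

# An environment that caches per-rollout metadata at generation time, plus a
# reward function (a closure over the env) that reads it back.
class CachingEnv:
    def __init__(self):
        self.metadata = []  # one entry per prompt in the last generated batch

    def generate(self, prompts, llm, sampling_params):
        self.metadata = []
        completion_ids = []
        for messages in prompts:
            output = llm.chat([messages], sampling_params=sampling_params)[0]
            self.metadata.append({"num_steps": 1})  # stash whatever the rewards need
            completion_ids.append(list(output.outputs[0].token_ids))
        return completion_ids

env = CachingEnv()

def cached_reward(completions, **kwargs):
    # Reads the metadata cached by env.generate() for the same batch.
    return [float(env.metadata[i]["num_steps"]) for i in range(len(completions))]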


Do you have suggestions on how to cache the information? I'm new to trl and have not done that before.


@qgallouedec @willccbb Regarding logits_to_keep, what do you think is the best strategy to filter out tool observation tokens? My concern is that when we invoke a web search tool—which might return thousands of tokens—those tokens could overwhelm the policy tokens during gradient updates. If this extraneous information turns out to be noisy, it might adversely affect the policy gradient learning. We can leave this out of the current PR, but I’d appreciate your thoughts on ideas/directions to resolve this as I'm new to trl.

@willccbb (Author)

For now I'd suggest having deterministic processing of tool call results to avoid excessive outputs; those will cause problems whether or not we mask tokens. In some experiments on multi-step code tool use, I found limiting allowed printouts to 500 chars per step worked reasonably well.

It's not clear to me that naive masking of tool outputs "makes sense" for GRPO algorithmically, and for now it is probably fine to treat tool call results as just part of the LLM output. Especially if you are letting your model "reason" for many tokens per step, it should not be a major problem, I think.
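
For example, a tiny sketch of that kind of deterministic capping (the 500-character limit is the number quoted above; everything else is illustrative):

MAX_TOOL_OUTPUT_CHARS = 500  # per-step printout cap

def format_tool_output(raw_output: str) -> str:
    # Deterministically truncate tool/interpreter printouts before appending
    # them to the conversation, so a single tool call can't flood the context.
    if len(raw_output) > MAX_TOOL_OUTPUT_CHARS:
        return raw_output[:MAX_TOOL_OUTPUT_CHARS] + "... (truncated)"
    return raw_output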


@xiangjjj xiangjjj Feb 15, 2025


I think that applying policy gradient directly to tool call outputs is not theoretically sound—especially in outcome reward methods. Note that I'm not suggesting "naive masking of tool outputs." I propose that we do not compute the log-likelihood of the tool observation tokens together with their policy gradient and KL loss functions. While one might argue this is an empirical question depending on the setup of the system, I’m interested in learning about what kind of design in grpo would enable more flexible loss masking.


I’m fine with keeping this PR as is, but I’m interested in exploring a design that allows for more flexible loss masking in the future.

@willccbb (Author)

Yeah, I see what you mean. I think this is a very interesting research question, but hard to say definitively right now. Including them as normal response tokens is essentially forcing the model to "model" the tool calls directly, which feels reasonable to me.

Papers exploring this issue are pretty recent.

If interest in multi-step RL continues to grow, I would imagine having more specialized trainers in the future could make sense. For now this is just a way to get something that runs using GRPO, whether or not it is the most principled.


Does that mean in the current implementation we're still calculating KLD and policy gradient on environment responses? I thought we're not doing that based on my understanding of DoubleCheckEnv but now I'm confused...

@vladrad

vladrad commented Feb 17, 2025

All, I would love to help out with this as I am working on it myself.

@vladrad

vladrad commented Feb 17, 2025

@willccbb I reached out via the email on your profile! Let me know if you want to collab on this, as I am continuing to work on it.

@qgallouedec (Member)

I still don't understand why wrapping would limit what you can do. For example for the double call:

def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        ...  # stuff before
        result = generate_func(*args, **kwargs)
        result = generate_func(*args, **kwargs)
        ...  # stuff after
        return result
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)

Taking the env paradigm, I think it should work as is with the main branch with something like:

env = MyEnv(...)

def wrapper_decorator(generate_func):
    def generate_wrapper(*args, **kwargs):
        prompts = args[0]
        return env.generate(prompts, self, *args, **kwargs)
    return generate_wrapper

trainer.llm.model.generate = wrapper_decorator(trainer.llm.model.generate)

I might be missing something though

@vladrad

vladrad commented Feb 17, 2025

@qgallouedec
I tried something similar before without success, but I'll give it another go.

I've been experimenting with combining chat and completion responses during training. The idea is to score each response based on its format and content. If a mistake is detected, another LLM—one that provides the correct answer—is consulted. This secondary LLM offers a brief hint to guide the correction, and then the response is re-generated.

For example, this dataset snippet:

<think>
Oh, the user wants me to call a tool.
</think>
<answer>
I am going to call X...
</answer>
<tool>
<function name=read_file>...</function>
</tool>

If a mistake is found (say, the tool call should be within the <answer> tags), I go back with the hint and try to get a completion; the correction would look like this:

<think>
Oh, I forgot the format requires tool calls to be within the answer tags. This will provide the correct format the user requested.
</think>
<answer>
I am going to call X...
<tool>
<function name=read_file>...</function>
</tool>
</answer>

My goal is to see if I can get auto-correction via hints, scored on top of it.

I have made a really hacky, overfitted solution where I was training in epoch runs like this, creating the dataset and then going back for rounds 2, 3, 4 of GRPO training... slowly guiding it to the right answer. Now I think I need to work on an actual solution, which is how I ended up here. Happy to help/code things up and test.

Thanks all!

@willccbb (Author)

@qgallouedec The biggest problem is that LLM.chat relies on LLM.generate, and many (most?) multi-step interaction protocols will want to use LLM.chat. If we override generate as you propose, any call to chat inside of our wrapper will result in recursive blowup. We also have to keep all of our logic contained within a single wrapper function, and we can't easily maintain global state within/across rollouts (for things like precomputing/caching rewards to be retrieved by reward functions, which can be objects with access to the Env state).

It also is just much nicer to be able to have access to the SamplingParams and LLM objects directly, as this is how people typically develop agent applications on top of vLLM. The added complexity to the trainer by allowing an Env object is pretty minor, but it unlocks quite a bit from the user perspective. Other libraries which have already built these kinds of environments (TextArena, reasoning-gym, etc.) are way easier to adapt if we can just "use the model like a normal LLM" rather than having to rewrite all of the chat parsing logic again for every application.
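
To make the recursion point concrete, a purely illustrative sketch (not proposed code):

# LLM.chat formats the conversation and then calls LLM.generate internally,
# so a wrapper installed over generate whose rollout logic itself calls
# llm.chat(...) re-enters the wrapper on every step:
# chat -> generate (wrapped) -> rollout -> chat -> ...
def wrapper_decorator(generate_func, llm):
    def generate_wrapper(*args, **kwargs):
        # any llm.chat(...) call placed here would land back in this wrapper
        return generate_func(*args, **kwargs)
    return generate_wrapper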

@amitlevy

Any update? The PR is falling behind the GRPO changes on main

@willccbb (Author)

Working on a refactor now which I think should allow directly using the main TRL branch. Approach is to extend GRPOTrainer to a class GRPOEnvTrainer which only needs to override the _generate_and_score_completions function, which I think is already a reasonable encapsulation of the minimum logic needed to implement custom rollout strategies. Once that's tested + pushed to the verifiers repo I'll probably just close this PR.
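
A rough sketch of what that subclass could look like (the method name comes from the comment above; the signature and everything else are assumptions, not TRL's or verifiers' actual code):

from trl import GRPOTrainer

class GRPOEnvTrainer(GRPOTrainer):
    def __init__(self, *args, env=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.env = env

    def _generate_and_score_completions(self, inputs):
        # Custom multi-step rollout logic via self.env would go here; fall
        # back to the stock GRPO path when no environment is provided.
        if self.env is None:
            return super()._generate_and_score_completions(inputs)
        raise NotImplementedError("environment rollout sketch goes here")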

@willccbb (Author)

Working now on the dev branch of verifiers, will clean up some things + merge to main shortly. Closing this PR, will maybe revisit later, but overloading the trainer seems to be the best method for supporting these kinds of features for now.

@willccbb willccbb closed this Feb 26, 2025