29 changes: 4 additions & 25 deletions swift/llm/template/base.py
@@ -127,7 +127,7 @@ def __init__(
         if self.is_encoder_decoder:
             self.skip_prompt = False
         self.mode: Literal['pt', 'vllm', 'lmdeploy', 'sglang',  # infer
-                           'train', 'rlhf', 'kto', 'gkd'] = 'pt'  # train
+                           'train', 'rlhf', 'kto'] = 'pt'  # train
         self.task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'prm', 'reranker',
                                 'generative_reranker'] = 'causal_lm'
         self.use_megatron = False
@@ -383,14 +383,6 @@ def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         encoded['label'] = bool(inputs.chosen.label)
         return encoded
 
-    def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        encoded = self._encode_truncated(inputs)
-        encoded['prompts'] = encoded['input_ids'][:-len(encoded.pop('answer_input_ids'))]
-        for k in list(encoded.keys()):
-            if k.startswith('prompt_') or k.endswith('answer_'):
-                encoded.pop(k, None)
-        return encoded
-
     def _embedding_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         _encoded = {}
         labels = []
@@ -521,8 +513,6 @@ def encode(self,
             encoded = self._rlhf_encode(inputs)
         elif self.mode == 'kto':
             encoded = self._kto_encode(inputs)
-        elif self.mode == 'gkd':
-            encoded = self._gkd_encode(chosen)
         elif self.task_type == 'seq_cls':
             if self.mode == 'rlhf':
                 encoded = self._rlhf_encode(inputs)
@@ -637,7 +627,7 @@ def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
     @contextmanager
     def generate_context(self):
         origin_mode = self.mode
-        if self.mode in {'train', 'rlhf', 'kto', 'gkd'}:
+        if self.mode in {'train', 'rlhf', 'kto'}:
             self.set_mode('pt')
         is_multimodal = self.model_meta.is_multimodal
         if is_multimodal:
@@ -1291,7 +1281,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         res_context_list, loss_scale_list, answer_len = (
             self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
         encoded = {}
-        if self.is_encoder_decoder or self.mode == 'gkd':
+        if self.is_encoder_decoder:
             total_len = len(res_context_list)
             for key, _slice in zip(['prompt', 'answer'],
                                    [slice(0, total_len - answer_len),
@@ -1414,7 +1404,7 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
     def is_training(self):
         return self.mode not in {'pt', 'vllm', 'lmdeploy', 'sglang'}
 
-    def set_mode(self, mode: Literal['pt', 'vllm', 'lmdeploy', 'sglang', 'train', 'rlhf', 'kto', 'gkd']) -> None:
+    def set_mode(self, mode: Literal['pt', 'vllm', 'lmdeploy', 'sglang', 'train', 'rlhf', 'kto']) -> None:
         self.mode = mode
 
     def register_post_encode_hook(self, models: List[nn.Module]) -> None:
@@ -1467,8 +1457,6 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
             res = self._rlhf_data_collator(batch, padding_to=padding_to)
         elif self.mode == 'kto':
             res = self._kto_data_collator(batch, padding_to=padding_to)
-        elif self.mode == 'gkd':
-            res = self._gkd_data_collator(batch, padding_to=padding_to)
         elif self.task_type == 'prm':
             res = self._data_collator(batch, padding_to=padding_to)
         elif self.task_type == 'seq_cls':
@@ -1563,15 +1551,6 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
         res['label'] = label
         return res
 
-    def _gkd_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = self._data_collator(batch, padding_to=padding_to)
-        prompts_batch = [{'input_ids': b['prompts']} for b in batch if b.get('prompts') is not None]
-        if prompts_batch:
-            prompts_res = self._data_collator(prompts_batch, padding_to=padding_to)
-            res['prompts'] = prompts_res.pop('input_ids')
-            res.update({f'prompt_{k}': v for k, v in prompts_res.items()})
-        return res
-
     def _embedding_data_collator(self,
                                  batch: List[Dict[str, Any]],
                                  *,
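Note: with _gkd_encode and _gkd_data_collator deleted, nothing in the template produces the old 'prompts'/'prompt_*' keys anymore; prompt-only encoding becomes the trainer's job. A minimal sketch of the replacement pattern, assuming a swift Template instance with the mode/set_mode/encode API shown above (the helper name is hypothetical):

# Hedged sketch, not library code: prompt-only encoding without the 'gkd' mode.
def encode_prompt_only(template, data):
    # Blank the assistant turn so only the prompt is tokenized, mirroring
    # what GKDTrainer._prepare_batch_inputs(encode_prompt_only=True) does below.
    messages = data.get('messages', [])
    if messages and messages[-1].get('role') == 'assistant':
        messages[-1]['content'] = None  # keep the turn, drop the response text
    original_mode = template.mode
    template.set_mode('pt')  # infer mode: encode() builds no training labels
    try:
        return template.encode(data)  # encoded['input_ids'] is the prompt only
    finally:
        template.set_mode(original_mode)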
2 changes: 1 addition & 1 deletion swift/llm/train/rlhf.py
@@ -177,7 +177,7 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t
     def _prepare_template(self) -> None:
         args = self.args
         super()._prepare_template()
-        mode_mapping = {'kto': 'kto', 'gkd': 'gkd', 'ppo': 'pt', 'grpo': 'train'}
+        mode_mapping = {'kto': 'kto', 'gkd': 'train', 'ppo': 'pt', 'grpo': 'train'}
         self.template.set_mode(mode_mapping.get(args.rlhf_type, 'rlhf'))
 
         if args.rlhf_type == 'ppo':
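This one-line change is the crux of the PR: 'gkd' now resolves to the ordinary 'train' template mode. A quick illustration of the lookup, with values taken from the new mapping:

mode_mapping = {'kto': 'kto', 'gkd': 'train', 'ppo': 'pt', 'grpo': 'train'}
assert mode_mapping.get('gkd', 'rlhf') == 'train'  # previously resolved to 'gkd'
assert mode_mapping.get('dpo', 'rlhf') == 'rlhf'   # unmapped types fall back to 'rlhf'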
98 changes: 64 additions & 34 deletions swift/trainers/rlhf_trainer/gkd_trainer.py
@@ -108,12 +108,15 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
 
     # Code borrowed from huggingface/trl
     def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None):
+        """Generate on-policy outputs using the model.
+
+        When encode_prompt_only=True, inputs['input_ids'] already contains only the prompt part.
+        """
         assert not self.template.padding_free, 'generate not support padding_free/packing.'
-        # Generate output with respect to the prompt only
-        model_inputs = {k: v for k, v in inputs.items() if not k.startswith('prompt') and k != 'labels'}
-        model_inputs['input_ids'] = inputs['prompts']
-        model_inputs.update({k[len('prompt_'):]: v for k, v in inputs.items() if k.startswith('prompt_')})
+        prompt_input_ids = inputs['input_ids']
+        model_inputs = {k: v for k, v in inputs.items() if k != 'labels'}
+        model_inputs.pop('position_ids', None)
+        model_inputs.pop('text_position_ids', None)
         kwargs = {}
         base_model = self.template.get_base_model(model)
         parameters = inspect.signature(base_model.generate).parameters
@@ -127,11 +130,11 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token
         # Get the generated token IDs
         generated_tokens = generated_outputs.sequences
         if not self.template.skip_prompt:
-            generated_tokens = torch.concat([inputs['prompts'], generated_tokens], dim=1)
+            generated_tokens = torch.concat([prompt_input_ids, generated_tokens], dim=1)
         # Calculate new attention mask
         new_attention_mask = torch.ones_like(generated_tokens)
         new_labels = generated_tokens.clone()
-        new_labels[:, :inputs['prompts'].shape[1]] = -100
+        new_labels[:, :prompt_input_ids.shape[1]] = -100
 
         # If there's pad_token_id, set attention mask to 0 for padding tokens
         if pad_token_id is not None:
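Since inputs['input_ids'] now holds only the prompt, the prompt tokens are re-attached to the generation and masked out of the loss. A toy, self-contained illustration of that masking (values invented for the example):

import torch

prompt_input_ids = torch.tensor([[101, 102, 103]])  # 3 prompt tokens (made up)
generated = torch.tensor([[7, 8]])                  # 2 generated tokens
generated_tokens = torch.concat([prompt_input_ids, generated], dim=1)
new_attention_mask = torch.ones_like(generated_tokens)
new_labels = generated_tokens.clone()
new_labels[:, :prompt_input_ids.shape[1]] = -100    # prompt excluded from the loss
assert new_labels.tolist() == [[-100, -100, -100, 7, 8]]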
@@ -263,20 +266,37 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         else:
             return loss
 
-    def _prepare_batch_inputs(self, inputs: list) -> Dict[str, torch.Tensor]:
+    def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]:
+        """Prepare batch inputs for training.
+
+        Args:
+            inputs: List of input data dictionaries
+            encode_prompt_only: If True, only encode the prompt part (for on-policy/seq_kd generation).
+                If False, encode the full messages including response (for offline dataset).
+        """
+        from swift.llm import to_device
+        from .utils import replace_assistant_response_with_ids
+
         template = self.template
         batch_encoded_inputs = []
 
-        for data in inputs:
-            if 'response_token_ids' in data and data['response_token_ids']:
-                from .utils import replace_assistant_response_with_ids
-                data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
+        # Use 'pt' mode for prompt-only encoding, 'train' mode for full encoding
+        mode = 'pt' if encode_prompt_only else 'train'
+        with self._template_context(template, mode=mode):
+            for data in inputs:
+                if 'response_token_ids' in data and data['response_token_ids']:
+                    data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])
 
-            encoded = template.encode(data, return_length=True)
-            batch_encoded_inputs.append(encoded)
+                if encode_prompt_only:
+                    # Remove response content for prompt-only encoding
+                    messages = data.get('messages', [])
+                    if messages and messages[-1].get('role') == 'assistant':
+                        messages[-1]['content'] = None
 
-        from swift.llm import to_device
-        batch_encoded = to_device(template.data_collator(batch_encoded_inputs), self.model.device)
+                encoded = template.encode(data, return_length=True)
+                batch_encoded_inputs.append(encoded)
+
+            batch_encoded = to_device(template.data_collator(batch_encoded_inputs), self.model.device)
 
         return batch_encoded
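A hedged usage sketch of the two encoding paths (trainer stands in for a GKDTrainer instance; deep copies avoid mutating the shared batch, since the prompt-only path blanks the assistant content in place):

import copy

raw = [{'messages': [{'role': 'user', 'content': 'Hi'},
                     {'role': 'assistant', 'content': 'Hello!'}]}]

# On-policy / seq_kd: 'pt' mode, assistant turn blanked, input_ids = prompt only.
prompt_batch = trainer._prepare_batch_inputs(copy.deepcopy(raw), encode_prompt_only=True)

# Off-policy / prediction: 'train' mode, response tokens encoded with labels.
full_batch = trainer._prepare_batch_inputs(copy.deepcopy(raw), encode_prompt_only=False)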

@@ -314,54 +334,64 @@ def training_step(self,
                 self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(valid_messages))
                 self._logs['completion'].extend(valid_completions)
                 with self._template_context(self.template):
-                    inputs = self._prepare_batch_inputs(generated_inputs)
+                    # vLLM already generated response, encode full messages
+                    encoded_inputs = self._prepare_batch_inputs(generated_inputs, encode_prompt_only=False)
             else:
-                inputs = self._prepare_batch_inputs(inputs)
+                # Need prompt-only encoding for on-policy generation
+                encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=True)
                 with unwrap_model_for_generation(
                         model, self.accelerator,
                         gather_deepspeed3_params=args.ds3_gather_for_generation) as unwrapped_model:
                     unwrapped_model.eval()
                     new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
-                        unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)
+                        unwrapped_model, encoded_inputs, self.generation_config, self.processing_class.pad_token_id)
                     unwrapped_model.train()
-                inputs['input_ids'] = new_input_ids
-                inputs['attention_mask'] = new_attention_mask
-                inputs['labels'] = new_labels
+                # override with generated inputs
+                encoded_inputs['input_ids'] = new_input_ids
+                encoded_inputs['attention_mask'] = new_attention_mask
+                encoded_inputs['labels'] = new_labels
 
         elif self.seq_kd:
             # Sequential KD: teacher model generates responses
             data_source = DataSource.TEACHER
 
             # Resample inputs that fail encoding when truncation_strategy is 'raise'('delete')
             if self.template.truncation_strategy == 'raise':
                 inputs = self.resample_encode_failed_inputs(inputs)
-            inputs = self._prepare_batch_inputs(inputs)
+            # Need prompt-only encoding for teacher generation
+            encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=True)
             load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext()
             with load_context, unwrap_model_for_generation(
                     self.teacher_model,
                     self.accelerator,
                     gather_deepspeed3_params=self.teacher_ds3_gather_for_generation) as unwrapped_model:
                 unwrapped_model.eval()
                 new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
-                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)
-            inputs['input_ids'] = new_input_ids
-            inputs['attention_mask'] = new_attention_mask
-            inputs['labels'] = new_labels
+                    unwrapped_model, encoded_inputs, self.generation_config, self.processing_class.pad_token_id)
+            # override with generated inputs
+            encoded_inputs['input_ids'] = new_input_ids
+            encoded_inputs['attention_mask'] = new_attention_mask
+            encoded_inputs['labels'] = new_labels
 
         else:
-            # Off-policy: use dataset responses
+            # Off-policy: use dataset responses, encode full messages
             data_source = DataSource.DATASET
-            inputs = self._prepare_batch_inputs(inputs)
+            total_length = self.template.max_length + self.max_completion_length
+            with self._template_context(self.template, max_length=total_length):
+                encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False)
 
         # Mark data source for downstream processing (e.g., conditional SFT loss)
-        inputs['_data_source'] = data_source
+        encoded_inputs['_data_source'] = data_source
 
-        with self.template.forward_context(self.model, inputs):
-            loss = HFSFTTrainer.training_step(self, model, inputs, num_items_in_batch)
+        with self.template.forward_context(self.model, encoded_inputs):
+            loss = HFSFTTrainer.training_step(self, model, encoded_inputs, num_items_in_batch)
         return loss
 
     def prediction_step(self, model, inputs, *args, **kwargs):
-        inputs = self._prepare_batch_inputs(inputs)
-        with self.template.forward_context(self.model, inputs):
-            return super().prediction_step(model, inputs, *args, **kwargs)
+        # Prediction uses full messages
+        encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False)
+        with self.template.forward_context(self.model, encoded_inputs):
+            return super().prediction_step(model, encoded_inputs, *args, **kwargs)
 
     @contextmanager
     def offload_context(self):
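Reading the three branches of training_step together, the rule for the new flag can be condensed as follows (hypothetical helper for illustration only; the trainer does this inline):

def pick_encode_prompt_only(student_turn: bool, used_vllm: bool, seq_kd: bool) -> bool:
    """Which encode_prompt_only value each training_step branch passes."""
    if student_turn:          # on-policy: the student generates the response
        return not used_vllm  # vLLM rollout already returned full messages
    if seq_kd:                # sequential KD: the teacher generates
        return True
    return False              # off-policy: dataset responses, full encoding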
18 changes: 13 additions & 5 deletions swift/trainers/rlhf_trainer/rollout_mixin.py
@@ -1268,16 +1268,24 @@ def _disable_sp_context(self, template: Optional[Template] = None):
             template.sequence_parallel_size = original_sequence_parallel_size
 
     @contextmanager
-    def _template_context(self, template: Template, inputs: Optional['DataType'] = None):
-        # The max_length for prompt and completion has already been restricted, so there is no need for max_length here.
-        max_length = template.max_length
-        template.max_length = None
+    def _template_context(self,
+                          template: Template,
+                          inputs: Optional['DataType'] = None,
+                          max_length: Optional[int] = None,
+                          mode: Optional[str] = None):
+        original_max_length = template.max_length
+        original_mode = template.mode
+        template.max_length = max_length
+        if mode is not None:
+            template.set_mode(mode)
         forward_ctx = template.forward_context(self.model, inputs) if inputs is not None else nullcontext()
         try:
             with forward_ctx:
                 yield
         finally:
-            template.max_length = max_length
+            template.max_length = original_max_length
+            if mode is not None:
+                template.set_mode(original_mode)
 
     def _prepare_resample_data_iterator(self):
         """Initialize resample data iterator for truncation_strategy 'raise'('delete').
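The extended _template_context saves and restores both max_length and mode around the yield. A hedged usage sketch from the caller's side (self stands in for a trainer using this mixin, template for its swift Template, data for a prepared example):

total_length = template.max_length + self.max_completion_length
with self._template_context(template, max_length=total_length, mode='train'):
    encoded = template.encode(data)  # encoded under the temporary settings
# template.max_length and template.mode are restored here, even on exceptions.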