Getting gradient of loss during inference #3371

Open
LalchandPandia opened this issue Jan 28, 2025 · 12 comments
@LalchandPandia

I am fine-tuning Llama 2 using accelerate + DeepSpeed ZeRO-3. During evaluation, which is run after every checkpoint step, I need to calculate the gradient of the loss w.r.t. certain input ids. As per my understanding the embedding matrix is sharded, and when I try to get the gradient, I get an error saying that grad is set to None. Is there a cleaner way to do it using accelerate APIs?
My code:

def token_gradients(model, input_ids, targets):

    """
    Computes gradients of the loss with respect to the coordinates.
    
    Parameters
    ----------
    model : Transformer Model
        The transformer model to be used.
    input_ids : torch.Tensor
        The input sequence in the form of token ids.
    input_slice : slice
        The slice of the input sequence for which gradients need to be computed.
    targets :torch.Tensor
        The target sequence in the form of token ids .
    loss_slice : slice
        The slice of the logits to be used for computing the loss.

    Returns
    -------
    torch.Tensor
        The gradients of each token in the input_slice with respect to the loss.
    """
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()
    print('Luke input_slice valid_positions ',input_slice, ' end_input_slice ',end_input_slice)


    #embed_weights = get_embedding_matrix(model)
    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_weights = embeddings.weight
        embedding_size = embedding_weights.shape[0]
        print('embed_weights ',embedding_weights.shape)
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    print('input_ids[input_slice].shape[0] ',input_ids[input_slice].shape[0])
    #print('embed_weights.shape  ',' embedding_size ',embedding_size)
    print('one_hot.shape ',one_hot.shape)
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    print('one_hot.shape ',one_hot.shape)
    #print('embeddings.weight ',embeddings.weight)
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        #this contains embeddings for all ids present in input_slice
        input_embeds = (one_hot @ embeddings.weight)
        print('input_embeds grad ',input_embeds.grad)
        input_ids = input_ids.cpu().tolist()
        embeds = embeddings.weight[input_ids[:end_input_slice+1],:].detach()

        #Now we stitch the input_embeddings for all input ids before the input slice starts
        #embedings for all ids in input slice
        #embeddings for all ids after input slice:
        full_embeds = torch.cat(
            [
                embeds[:input_slice.start,:],
                input_embeds,
                embeds[input_slice.stop:,:]
            ],
            dim=0)
        full_embeds = full_embeds.unsqueeze(0)
        print('Luke full_embeds ',full_embeds.shape)


        logits = model(inputs_embeds=full_embeds).logits
        print('Luke logits ',logits.shape)
        #calculate loss for logit positions corresponding to every token in the vocabulary
        loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
        print('Luke loss ',loss.shape)

        loss.backward()
        print(one_hot.grad.shape)
    return one_hot.grad.clone()
@LalchandPandia
Author

@muellerzr @ArthurZucker

@BenjaminBossan
Member

For my understanding, did you confirm that the gradient being None is due to the usage of DeepSpeed and accelerate? That is, if you run the same code on a single device, the gradient is present? The reason why I'm asking is that at evaluation time, we typically don't calculate the gradient, as it's not needed. So for instance, your call might be inside a torch.inference_mode() context. Since you don't show the whole code, it's not possible to tell.
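
One quick way to rule that out is to check the autograd state right before the call and, if needed, re-enable it explicitly. A minimal sketch, assuming a token_gradients function like the one above (the wrapper name is hypothetical):

import torch

def token_gradients_with_grad(model, input_ids, targets):
    # If these print False / True respectively, the surrounding eval loop is
    # disabling autograd, so neither grad_fn nor grad can be produced.
    print("grad enabled:", torch.is_grad_enabled())
    print("inference mode:", torch.is_inference_mode_enabled())
    # enable_grad() overrides an outer torch.no_grad(), but it cannot undo
    # torch.inference_mode(); that context would have to be removed instead.
    with torch.enable_grad():
        return token_gradients(model, input_ids, targets)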

@LalchandPandia
Author

@BenjaminBossan
Below is code for a single-GPU setup that works perfectly:

import torch
from transformers import ( AutoModelForCausalLM, AutoTokenizer)
def token_gradients(model, input_ids, targets):

    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()

    embeddings = model.get_input_embeddings()
    embedding_weights = embeddings.weight
    embedding_size = embedding_weights.shape[0]
    print('embed_weights ',embedding_weights.shape)
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()

    input_embeds = (one_hot @ embeddings.weight)
    input_embeds.requires_grad_()
    input_embeds.retain_grad()
    input_ids = input_ids.cpu().tolist()
    embeds = embeddings.weight[input_ids[:end_input_slice+1],:].detach()
    print('embeds shape ',embeds.shape)

    full_embeds = torch.cat(
        [
            embeds[:input_slice.start,:],
            input_embeds,
            embeds[input_slice.stop:,:]
        ],
        dim=0)
    full_embeds = full_embeds.unsqueeze(0)
    logits = model(inputs_embeds=full_embeds).logits
    #calculate loss for logit positions corresponding to every token in the vocabulary
    loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])

    loss.backward()
    print(one_hot.grad.shape)
    print('one hot grad ',one_hot.grad)
    print('input embeds grad ',input_embeds.grad)
    return one_hot.grad.clone()
device = "cuda"

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)
print('model.device ',model.device)
model.eval()
input = torch.tensor([[    1,   894, 29901,  5122, 10753,   304, 14294,   670,  6567,  9098,491, 14051, 10549,   963, 29889,  8449, 19309,  7101,   674,  7738, 278,  1556, 12871, 29973,    13, 22550, 29901, 15589,  5112,  1516]])
target = torch.tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,-100,  -100,  -100,  -100,  -100, 22550, 29901, 15589,  5112,  1516])
token_gradients(model, input, target)

But with the changes for the distributed setup, it fails:

def token_gradients(model, input_ids, targets):
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()
    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_size = embeddings.weight.shape
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size[0],
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        input_embeds = (one_hot @ embeddings.weight)
        input_embeds.requires_grad_()
        input_embeds.retain_grad()
        input_ids = input_ids.cpu().tolist()
        embeds = embeddings.weight[input_ids[:end_input_slice+1],:].detach()
        full_embeds = torch.cat(
            [
                embeds[:input_slice.start,:],
                input_embeds,
                embeds[input_slice.stop:,:]
            ],
            dim=0)
        full_embeds = full_embeds.unsqueeze(0)
        logits = model(inputs_embeds=full_embeds).logits
        loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
        print('Luke loss ',loss.shape)

        loss.backward()
    return one_hot.grad.clone()

The loss in this case is an empty tensor.

@LalchandPandia
Author

Logits in the single-GPU setup (without accelerate + deepspeed) is a tensor with a grad_fn. But with multiple GPUs there is no grad_fn.
Single GPU:
logits tensor([[[-12.9832, -7.4134, -0.4327, ..., -6.8297, -8.0879, -7.5863],
[ -6.5046, -3.2412, 4.3043, ..., -1.0471, -4.2429, -2.2271]]],
device='cuda:0', grad_fn=<...>)
Multiple GPUs with accelerate:
logits tensor([[[-12.9832, -7.4134, -0.4327, ..., -6.8297, -8.0879, -7.5863],
[ -6.5046, -3.2412, 4.3043, ..., -1.0471, -4.2429, -2.2271]]],
device='cuda:0')
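
A few prints, dropped into token_gradients right after the forward pass, could narrow down where the graph is lost (names as in the snippets above; this is just a diagnostic sketch):

print("model.training:", model.training)
print("grad enabled:", torch.is_grad_enabled())
print("one_hot.requires_grad:", one_hot.requires_grad)
print("input_embeds.grad_fn:", input_embeds.grad_fn)
print("logits.grad_fn:", logits.grad_fn)  # None means the forward ran without autograd recording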

@BenjaminBossan
Member

Thanks for providing some code. Unfortunately, I could not get it to run. I always run into errors, like wrong device, dimension error, etc. Each time I fix an error, the next one occurs. Could you please double check that the scripts run the way that you show them?

Moreover, I don't see where accelerate enters the picture; could you explain? Finally, please show how you call the scripts and what your accelerate env is.

@LalchandPandia
Author

Hi,
Can you please check with the below code for a single-GPU setup?

import torch

from transformers import ( AutoModelForCausalLM, AutoTokenizer)
def token_gradients(model, input_ids, targets):
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()
    
    embeddings = model.get_input_embeddings()
    #with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    embedding_weights = embeddings.weight
    embedding_size = embedding_weights.shape[0]
    print('embed_weights ',embedding_weights.shape)
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    
    input_embeds = (one_hot @ embeddings.weight)
    input_embeds.requires_grad_()
    input_embeds.retain_grad()
    print('input_embeds grad ',input_embeds.grad, ' input_embeds ',input_embeds.shape)
    input_ids = input_ids.cpu().tolist()
    #embeddings corresponding to only input ids
    embeds = embeddings.weight[input_ids[:end_input_slice+1],:]
    full_embeds = torch.cat(
        [
            embeds[:input_slice.start,:],
            input_embeds,
            embeds[input_slice.stop:,:]
        ],
        dim=0)
    full_embeds = full_embeds.unsqueeze(0)
    print('Luke full_embeds ',full_embeds.shape)
    logits = model(inputs_embeds=full_embeds).logits
    loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
    loss.backward()
    return one_hot.grad.clone(), input_embeds.grad.clone()

device = "cuda"
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)
print('model.device ',model.device)
model.eval()

input = torch.tensor([    1,   894, 29901,  5122, 10753,   304, 14294,   670,  6567,  9098,491, 14051, 10549,   963, 29889,  8449, 19309,  7101,   674,  7738, 278,  1556, 12871, 29973,    13, 22550, 29901, 15589,  5112,  1516]).to(device)


target = torch.tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,-100,  -100,  -100,  -100,  -100, 22550, 29901, 15589,  5112,  1516]).to(device)


onehot_grad, inputembed_grad = token_gradients(model, input, target)
print(onehot_grad.shape, ' ',inputembed_grad.shape)

My transformers version is 4.43.4 and torch is 2.4.0+cu118.
Please let me know if it works for you on a single GPU.

@LalchandPandia
Author

Also, this is the minimal code using accelerate which does not work.
Output of accelerate env:

  • Accelerate version: 0.31.0
  • Platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.35
  • accelerate bash location: /net/scratch/lcpandia/python_3_11/bin/accelerate
  • Python version: 3.11.9
  • Numpy version: 1.26.3
  • PyTorch version (GPU?): 2.4.0+cu118 (True)
  • PyTorch XPU available: False
  • PyTorch NPU available: False
  • PyTorch MLU available: False
  • System RAM: 503.56 GB
  • GPU type: NVIDIA A100 80GB PCIe
  • Accelerate default config:
    Not found
  • deepspeed 0.15.0
import json
import logging
import math
import os
import random
import subprocess
import time
from dataclasses import dataclass, field
from datetime import timedelta
from functools import partial
from typing import List, Optional, Union
import numpy as np
import datasets
import deepspeed
import torch, pickle, json, csv
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader

from collections import defaultdict
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_scheduler,
    DataCollatorForSeq2Seq,
)

accelerator_log_kwargs = {}




timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
gradient_accumulation_steps=4 
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    use_seedable_sampler=True,
    **accelerator_log_kwargs,
    kwargs_handlers=[timeout_kwargs],
)
use_flash_attn = True
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf",revision=None,trust_remote_code=False,use_fast=False,)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                revision=None,
                trust_remote_code=False,
                low_cpu_mem_usage=False,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2" if use_flash_attn else "eager",)
num_added_tokens = tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
            }
        )
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    embedding_size = embeddings.weight.shape[0]
# resize does its own gather
if len(tokenizer) > embedding_size:
    # pad to multiple for tensor cores.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
# update embedding size after resizing for sum loss
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    embedding_size = embeddings.weight.shape[0]
per_device_train_batch_size = 1
raw_datasets = load_dataset("allenai/ai2_arc", "ARC-Challenge")
train_dataset = raw_datasets["train"]

def encode_sft_example(example, tokenizer, max_seq_length, dataset):
    """
    This function encodes a single example into a format that can be used for sft training.
    Here, we assume each example has a 'messages' field. Each message in it is a dict with 'role' and 'content' fields.
    We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
    """
    print('example ',example)
    ids = 0
    question = example['question']
    answer = example['choices']['text'][example['choices']['label'].index(example['answerKey'])]
    
    
    text = f'Question: {question}\nAnswer: {answer}'
    answer_start = text.find("Answer: ")

    input_ids = tokenizer(
            text,
            return_tensors="pt",
            padding=False,
            max_length=max_seq_length,
            truncation=True,
    )['input_ids'][0]
    #print('Luke input_ids ',input_ids)
    labels = input_ids.clone()
    if answer_start != -1:
        # Tokenize the portion that comes after 'Answer:'
        answer_tokens = tokenizer(
            text[answer_start:],
            return_tensors="pt"
        )['input_ids'][0][1:]  # Skip the first token ('Answer:')

        # Set the labels before the answer part to -100 (only defined when an answer was found)
        labels[:len(labels) - len(answer_tokens)] = -100
    attention_mask = torch.ones_like(input_ids)

    
    return {
        "input_ids": input_ids.flatten(),
        "labels": labels.flatten(),
        "attention_mask": attention_mask.flatten(),
        "ids": ids
    }
with accelerator.main_process_first():
    train_dataset = train_dataset.map(
        partial(encode_sft_example, tokenizer=tokenizer, max_seq_length=512, dataset=train_dataset),
        batched=False,
        num_proc=128,
        load_from_cache_file=True,
        remove_columns=[
            name for name in train_dataset.column_names if name not in ["input_ids", "labels", "attention_mask", "ids"]
        ],
        desc="Tokenizing and reformatting instruction data",
    )
    train_dataset.set_format(type="pt")
    train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any())

train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
        batch_size=per_device_train_batch_size,
    )
#eval Dataloader:
eval_dataloader = DataLoader(
    train_dataset,
    batch_size=per_device_train_batch_size,
    shuffle=False,
    collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
)


num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = 3 * num_update_steps_per_epoch
model.gradient_checkpointing_enable()
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5, fused=True)
num_training_steps_for_scheduler = max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_training_steps=num_training_steps_for_scheduler,
        num_warmup_steps=int(num_training_steps_for_scheduler * 0.03),)


# Prepare everything with `accelerator`.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
# Prepare evaldataloader with `accelerator`.
eval_dataloader = accelerator.prepare(eval_dataloader)

def token_gradients(model, input_ids, targets):
    print('input_ids ',input_ids)
    print('targets ',targets)
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()

    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_weights = embeddings.weight
        embedding_size = embedding_weights.shape[0]
        embed_dtype=embeddings.weight.dtype
    
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embed_dtype)
    )
    one_hot.requires_grad_()
    print('one_hot.shape ',one_hot.shape)
    #print('embeddings.weight ',embeddings.weight)
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        #this contains embeddings for all ids present in input_slice
        input_embeds = (one_hot @ embeddings.weight)
        input_embeds.requires_grad_()
        input_embeds.retain_grad()
        print('input_embeds grad ',input_embeds.grad, ' input_embeds ',input_embeds.shape)
        input_ids = input_ids.cpu().tolist()
        #embeddings corresponding to only input ids
        embeds = embeddings.weight[input_ids[:end_input_slice+1],:]#.detach()
        print('embeds shape ',embeds.shape)
        #Now we stitch the input_embeddings for all input ids before the input slice starts
        #embedings for all ids in input slice
        #embeddings for all ids after input slice:
    full_embeds = torch.cat(
        [
            embeds[:input_slice.start,:],
            input_embeds,
            embeds[input_slice.stop:,:]
        ],
        dim=0)
    full_embeds = full_embeds.unsqueeze(0)
    print('full_embeds ',full_embeds.shape)


    outputs = model(inputs_embeds=full_embeds)
    logits = outputs.logits
    print('logits ',logits)
    loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
    print('loss ',loss)

    accelerator.backward(loss)
    print(one_hot.grad.shape)
    print('grad ',one_hot.grad)
    print('input embeds grad ',input_embeds.grad)

    return one_hot.grad.clone(), input_embeds.grad.clone()
with torch.no_grad():
    for step, batch  in enumerate(eval_dataloader):
        for input, target in zip(batch['input_ids'], batch['labels']):
            print(token_gradients(model, input, target))

@BenjaminBossan
Member

Thanks for the updates. I could make some progress, but could not fully replicate yet. The single GPU script worked for me. For the DS script, I wanted to trim it down to be as close as possible to the single GPU script. This is what I came up with:

import torch
import deepspeed
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from transformers import AutoModelForCausalLM, AutoTokenizer


def token_gradients(model, input_ids, targets):
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()

    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_weights = embeddings.weight
        embedding_size = embedding_weights.shape[0]

    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        input_embeds = (one_hot @ embeddings.weight)
        input_embeds.requires_grad_()
        input_embeds.retain_grad()
        print('input_embeds grad ',input_embeds.grad, ' input_embeds ',input_embeds.shape)
        input_ids = input_ids.cpu().tolist()
        #embeddings corresponding to only input ids
        embeds = embeddings.weight[input_ids[:end_input_slice+1],:]
    full_embeds = torch.cat(
        [
            embeds[:input_slice.start,:],
            input_embeds,
            embeds[input_slice.stop:,:]
        ],
        dim=0)
    full_embeds = full_embeds.unsqueeze(0)
    print('full_embeds ',full_embeds.shape)
    logits = model(inputs_embeds=full_embeds).logits
    loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
    accelerator.backward(loss)
    return one_hot.grad.clone(), input_embeds.grad.clone()


model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True)
accelerator = Accelerator()
# this line is only necessary because we don't prepare a dataset
AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 8
model, optimizer = accelerator.prepare(model, optimizer)
model.train()

input = torch.tensor([ 1, 894, 29901, 5122, 10753, 304, 14294, 670, 6567,
                       9098,491, 14051, 10549, 963, 29889, 8449, 19309, 7101, 674, 7738, 278, 1556,
                       12871, 29973, 13, 22550, 29901, 15589, 5112, 1516]).to(model.device)

target = torch.tensor([ -100, -100, -100, -100, -100, -100, -100, -100, -100,
                        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,-100, -100,
                        -100, -100, -100, 22550, 29901, 15589, 5112, 1516]).to(model.device)

onehot_grad, inputembed_grad = token_gradients(model, input, target)
print(onehot_grad.shape, ' ',inputembed_grad.shape)

As you can see, this is almost identical to the single-GPU script. When calling this with accelerate launch script.py, I get an error though:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/name/work/forks/accelerate/bar.py", line 72, in <module>
[rank1]:     onehot_grad, inputembed_grad = token_gradients(model, input, target)
[rank1]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/name/work/forks/accelerate/bar.py", line 51, in token_gradients
[rank1]:     accelerator.backward(loss)
[rank1]:   File "/home/name/work/forks/accelerate/src/accelerate/accelerator.py", line 2238, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank1]:   File "/home/name/work/forks/accelerate/src/accelerate/utils/deepspeed.py", line 261, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 2011, in backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/deepspeed/runtime/zero/stage3.py", line 2214, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

No idea how this comes about. Can you reproduce?

Regarding your config, it says:

  • Accelerate default config:
    Not found

Do you have a specific config that you pass to accelerate launch? Otherwise, I'd suggest recreating the whole config.

@LalchandPandia
Author

LalchandPandia commented Feb 6, 2025

Hi,
I pass the following config present in stage3_no_offloading_accelerate.conf:
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 1e5,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
My script:
accelerate launch \
    --mixed_precision bf16 \
    --num_machines 1 \
    --num_processes $NUM_GPUS \
    --use_deepspeed \
    --deepspeed_config_file stage3_no_offloading_accelerate.conf \
    script.py
Also, in your script please change model.train() to model.eval(), because we are looking for ways to take gradients in eval mode. I understand that in the standard scenario backward is not allowed in eval mode, but in my single-GPU script I am able to do it, as I do not set torch.no_grad() explicitly. I am hoping accelerate can also support such a scenario.

@BenjaminBossan
Member

Thanks for the additional info. Using this DS config and setting model.eval(), I get the following error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/name/work/forks/accelerate/bar.py", line 73, in <module>
[rank1]:     onehot_grad, inputembed_grad = token_gradients(model, input, target)
[rank1]:   File "/home/name/work/forks/accelerate/bar.py", line 52, in token_gradients
[rank1]:     accelerator.backward(loss)
[rank1]:   File "/home/name/work/forks/accelerate/src/accelerate/accelerator.py", line 2238, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank1]:   File "/home/name/work/forks/accelerate/src/accelerate/utils/deepspeed.py", line 261, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1895, in backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2041, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 169, in backward
[rank1]:     ctx.pre_backward_function(ctx.module)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 435, in _run_before_backward_function
[rank1]:     self.pre_sub_module_backward_function(sub_module)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/name/anaconda3/envs/accelerate/lib/python3.10/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 511, in pre_sub_module_backward_function
[rank1]:     assert sub_module.training, "backward pass is invalid for module in evaluation mode"
[rank1]: AssertionError: backward pass is invalid for module in evaluation mode

So it's not the exact same error you reported initially (gradients being None), but it is similar: Gradients can be calculated on single GPU but not on multi-GPU with DeepSpeed.

I'm not knowledgeable about the inner workings of DeepSpeed. But from the stacktrace, the error comes from DeepSpeed. I don't know if accelerate could do anything differently to prevent this error.

What I still don't understand is why this needs to run in eval mode. If you don't want the gradients to update the parameters, could you not switch to train mode, calculate the gradients, and once you're finished, zero out the gradients?

@LalchandPandia
Author

Hi,
Please find my responses below.

> So it's not the exact same error you reported initially (gradients being None), but it is similar: Gradients can be calculated on single GPU but not on multi-GPU with DeepSpeed.

In my code I was using loss.backward(). I agree the error is similar.

> I'm not knowledgeable about the inner workings of DeepSpeed. But from the stacktrace, the error comes from DeepSpeed. I don't know if accelerate could do anything differently to prevent this error.

OK, I will raise an issue in the DeepSpeed repo and update.

> What I still don't understand is why this needs to run in eval mode. If you don't want the gradients to update the parameters, could you not switch to train mode, calculate the gradients, and once you're finished, zero out the gradients?

This is a common use case in adversarial attack experiments. So, are you suggesting the following sequence: model.train() --> accelerator.backward(), and avoiding optimizer.step() in order to prevent updates to the parameters?

@BenjaminBossan
Member

> OK, I will raise an issue in the DeepSpeed repo and update.

Let's see how that goes.

> So, are you suggesting the following sequence: model.train() --> accelerator.backward(), and avoiding optimizer.step() in order to prevent updates to the parameters?

Yes, that should work; then optimizer.zero_grad() and model.eval() to remove the gradients and put the model back into eval mode.
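
A minimal sketch of that sequence, assuming a token_gradients function like the ones above that calls accelerator.backward(loss) internally (the wrapper name is hypothetical):

def token_gradients_eval_safe(model, optimizer, input_ids, targets):
    model.train()  # DeepSpeed ZeRO-3 asserts module.training before running backward
    try:
        onehot_grad, inputembed_grad = token_gradients(model, input_ids, targets)
    finally:
        optimizer.zero_grad()  # discard parameter gradients so no update can happen
        model.eval()           # restore eval mode for the rest of evaluation
    return onehot_grad, inputembed_grad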
