Add chunked implementation of siglip sigmoid loss #130

Open

cdeepakroy opened this issue Sep 12, 2024 · 0 comments

cdeepakroy commented Sep 12, 2024

Thank you for the great work on the SigLIP paper (Sigmoid Loss for Language Image Pre-Training). In particular, the gains in metrics at small batch sizes are impressive.

I tried this PyTorch-based distributed chunked implementation in the open_clip repo, which uses torch.distributed.P2POp:
https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/loss.py#L307

I took a ViT-B vision encoder + XLM-RoBERTa text encoder and trained it with both the CLIP softmax loss and the SigLIP sigmoid loss on an in-house dataset of 10M image-text pairs at an effective batch size of 9k (on V100 GPUs), and observed that CLIP softmax still performs better than the SigLIP sigmoid loss on the nDCG metric.

I was wondering if there is any error in the above implementation using P2POp. I also tried using an all_gather to get negative text_features from the other GPUs, but the behavior still seems to be the same:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: `global_manager` and `all_gather_objects` are project-specific
# distributed helpers (rank/world-size bookkeeping and an object all_gather).

class SigLipLossAllGather(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
    """
    def __init__(
            self,
            logit_scale=np.log(10),
            logit_bias=-10
    ):
        super().__init__()
        self.logit_scale = logit_scale
        self.logit_bias = logit_bias

        self.labels = {}

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss

    def forward(self, image_features, text_features):
        loss = self._loss(image_features, text_features, self.logit_scale, self.logit_bias)

        if global_manager.world_size > 1:

            # Gather text features from all ranks
            text_features_dict = {'text_features': text_features, 'global_rank': global_manager.rank}
            all_text_features_dict = all_gather_objects(text_features_dict)

            # Compute loss against negative text features from all other ranks
            for i in range(len(all_text_features_dict)):
                neigh_rank = all_text_features_dict[i]['global_rank']
                neigh_text_features = all_text_features_dict[i]['text_features']
                if neigh_rank != global_manager.rank:
                    loss += self._loss(image_features, neigh_text_features, self.logit_scale, self.logit_bias, negative_only=True)
                    
        return {"loss": loss}

The implementation in this repo seems to be the non-chunked version:

def loss_fn(params):
  zimg, ztxt, extras = model.apply(
      {"params": params}, images, labels,
      train=True, rngs={"dropout": rng_model})
  logits = jnp.dot(zimg, ztxt.T)
  logits = logits * extras["t"] + extras["b"]
  eye = jnp.eye(zimg.shape[0])
  # Standard sigmoid computes everything twice, once assuming positive
  # labels and once assuming negative ones. But here we know exactly where
  # to find positives (on "me" diagonal) and negatives (everywhere else),
  # so compute each one's loss only once:
  m1_diag1 = -jnp.ones_like(logits) + 2 * eye
  loglik = jax.nn.log_sigmoid(m1_diag1 * logits)
  # Normalize by npos per column, but that's one, so just sum.
  nll = -jnp.sum(loglik, axis=-1)
  # NOTE: same as concat'ing me/ot along axis -1 above.
  l = jnp.mean(nll)
  return l

I am wondering if you can add a chunked implementation of the SigLIP sigmoid loss.
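
For what it's worth, a chunked variant of the above could be sketched in JAX with jax.lax.ppermute, following the neighbour-swap scheme described in the paper. Everything below is illustrative rather than the repo's API: chunked_sigmoid_loss and the axis name "devices" are made-up names, and num_devices must be a static Python int (e.g. jax.device_count()).

import jax
import jax.numpy as jnp

# Illustrative sketch only: a chunked sigmoid loss under pmap, where each
# device holds one chunk of zimg/ztxt and text chunks are rotated around the
# device ring with ppermute (which is differentiable, so gradients also flow
# back to each chunk's producing device).
def chunked_sigmoid_loss(zimg, ztxt, t, b, num_devices, axis_name="devices"):
  def block_nll(logits, positive_diag):
    if positive_diag:
      # The device's own block: +1 on the diagonal (positives), -1 elsewhere.
      labels = 2 * jnp.eye(logits.shape[0], dtype=logits.dtype) - 1
    else:
      # A foreign chunk: every image-text pair is a negative.
      labels = -jnp.ones_like(logits)
    return -jnp.sum(jax.nn.log_sigmoid(labels * logits), axis=-1)

  # Loss against the device's own text chunk (contains the positives).
  nll = block_nll(jnp.dot(zimg, ztxt.T) * t + b, positive_diag=True)
  # Rotate text chunks around the ring; after num_devices - 1 swaps every
  # image chunk has been scored against every text chunk.
  perm = [(i, (i + 1) % num_devices) for i in range(num_devices)]
  for _ in range(num_devices - 1):
    ztxt = jax.lax.ppermute(ztxt, axis_name, perm)
    nll += block_nll(jnp.dot(zimg, ztxt.T) * t + b, positive_diag=False)
  return jnp.mean(nll)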
