
Unintuitive Logic in masked_fill Function of DistilBERT Model Implementation #2721

@sondalex

Description

The masked_fill function in the DistilBERT model implementation currently has unintuitive logic:

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
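
For context, Tensor::where_cond selects elements from its first argument where the mask is nonzero and from its second argument where it is zero. Under the current naming, a mask value of 1 therefore overwrites the score with the fill value, which is why callers must invert the tokenizer's mask (1 = attend, 0 = padding) before calling forward. A minimal standalone sketch of the semantics, with made-up values:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // where_cond: 1 selects from the first argument, 0 from the second.
    let mask = Tensor::new(&[1u32, 0, 1], &device)?;
    let fill = Tensor::new(&[f32::NEG_INFINITY; 3], &device)?;
    let scores = Tensor::new(&[0.5f32, 1.5, 2.5], &device)?;
    let out = mask.where_cond(&fill, &scores)?;
    // Prints [-inf, 1.5, -inf]: positions where mask == 1 were overwritten.
    println!("{out}");
    Ok(())
}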

In the current setup, the user must invert the attention mask obtained from the tokenizer before passing it to model.forward. This is confusing because it differs from the transformers implementation, where the mask is passed as-is (1 for tokens to attend to, 0 for padding).

...
let text: Vec<&str> = vec![...];
let encoded = tokenizer.encode_batch(text.clone(), true)?;
let input_ids = encoded.iter().map(|v| v.get_ids().to_vec()).collect::<Vec<_>>();
let input_ids = Tensor::new(input_ids, &device)?;
let attention_mask = encoded.iter().map(|encoding| encoding.get_attention_mask().to_vec()).collect::<Vec<_>>();
let attention_mask = Tensor::new(attention_mask, &device)?;

let (batch_size, feature_size) = input_ids.dims2()?;

// Invert the attention mask for correct behavior --> counterintuitive
let attention_mask = attention_mask.eq(0u32)?.reshape((batch_size, 1, 1, feature_size))?;

let output = model.forward(&input_ids, &attention_mask)?;
...

Proposition:

Replace the masked_fill function with:

fn masked_fill(on_true: &Tensor, mask: &Tensor, on_false: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_false = Tensor::new(on_false, on_true.device())?.broadcast_as(shape.dims())?;
    // Keep on_true where the mask is 1; fill with on_false where it is 0.
    let m = mask.where_cond(on_true, &on_false)?;
    Ok(m)
}
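
With this signature, the tokenizer's mask could be passed through unchanged. A sketch of the caller side, assuming the attention call site inside the model is updated to match (e.g. masked_fill(&scores, &mask, f32::NEG_INFINITY)):

...
// No inversion needed: 1 = attend, 0 = padding, exactly as the tokenizer produces.
let attention_mask = attention_mask.reshape((batch_size, 1, 1, feature_size))?;
let output = model.forward(&input_ids, &attention_mask)?;
...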
