
Commit d5cb15a

logit_bias

1 parent: 5438967
File tree: 4 files changed, +37 -30 lines


vllm/config/__init__.py

Lines changed: 24 additions & 21 deletions

```diff
@@ -2651,24 +2651,46 @@ class PoolerConfig:
     ## for embeddings models
     normalize: Optional[bool] = None
     """
-    Whether to normalize the embeddings outputs.
+    Whether to normalize the embeddings outputs. Defaults to True.
     """
     dimensions: Optional[int] = None
     """
     Reduce the dimensions of embeddings if model
-    support matryoshka representation.
+    support matryoshka representation. Defaults to None.
+    """
+    enable_chunked_processing: Optional[bool] = None
+    """
+    Whether to enable chunked processing for long inputs that exceed the model's
+    maximum position embeddings. When enabled, long inputs will be split into
+    chunks, processed separately, and then aggregated using weighted averaging.
+    This allows embedding models to handle arbitrarily long text without CUDA
+    errors. Defaults to False.
+    """
+    max_embed_len: Optional[int] = None
+    """
+    Maximum input length allowed for embedding generation. When set, allows
+    inputs longer than max_embed_len to be accepted for embedding models.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
+    Defaults to None (i.e. set to max_model_len).
     """

     ## for classification models
     activation: Optional[bool] = None
     """
     Whether to apply activation function to the classification outputs.
+    Defaults to True.
+    """
+    logit_bias: Optional[float] = None
+    """
+    If provided, classification logit biases. Defaults to None.
     """

     ## for reward models
     softmax: Optional[bool] = None
     """
     Whether to apply softmax to the reward outputs.
+    Defaults to True.
     """
     step_tag_id: Optional[int] = None
     """
@@ -2683,25 +2705,6 @@ class PoolerConfig:
     ``math-shepherd-mistral-7b-prm`` model.
     """

-    enable_chunked_processing: Optional[bool] = None
-    """
-    Whether to enable chunked processing for long inputs that exceed the model's
-    maximum position embeddings. When enabled, long inputs will be split into
-    chunks, processed separately, and then aggregated using weighted averaging.
-    This allows embedding models to handle arbitrarily long text without CUDA
-    errors. Defaults to False.
-    """
-
-    max_embed_len: Optional[int] = None
-    """
-    Maximum input length allowed for embedding generation. When set, allows
-    inputs longer than max_embed_len to be accepted for embedding models.
-    This parameter enables accepting long inputs without requiring
-    VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
-    max_embed_len, it will be handled according to the original max_model_len
-    validation logic. Defaults to None (i.e. set to max_model_len).
-    """
-
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
```

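Taken together, the docstring cleanups and the new field make `PoolerConfig` the single place where classification calibration lives. A minimal sketch of setting the new field (direct construction is assumed here; served deployments typically pass the same keys as JSON through vLLM's `--override-pooler-config` flag):

```python
from vllm.config import PoolerConfig

# Sketch: a classification pooler config with an explicit logit bias.
# Per this diff, the bias is subtracted from the classifier logits
# before the activation is applied.
pooler_config = PoolerConfig(
    activation=True,  # apply the activation to classification outputs
    logit_bias=2.65,  # shift logits down by 2.65 before the sigmoid
)
```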
vllm/model_executor/layers/pooler.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -633,9 +633,13 @@ def __init__(
     ) -> None:
         super().__init__()

+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+
         self.pooling = pooling
         self.classifier = classifier
         self.act_fn = act_fn or PoolerClassify()
+        self.logit_bias: Optional[float] = vllm_config.model_config.pooler_config.logit_bias

     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"classify", "score"}
@@ -654,6 +658,9 @@ def forward(
         pooled_data = self.classifier(pooled_data)
         # pooled_data shape: [batchsize, num_labels]

+        if self.logit_bias:
+            pooled_data -= self.logit_bias
+
         pooling_params = get_pooling_params(pooling_metadata)
         flags = [p.activation for p in pooling_params]
```

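The subtraction happens before the activation, so with a sigmoid it amounts to shifting the decision threshold. A self-contained sketch of the arithmetic (plain Python, mirroring the truthiness check in the diff):

```python
import math

def classify_prob(logit: float, logit_bias: float | None = None) -> float:
    """Mirrors the pooler behavior above: subtract the bias from the raw
    classifier logit, then apply a sigmoid to get a probability."""
    if logit_bias:  # same truthiness check as the diff
        logit = logit - logit_bias
    return 1.0 / (1.0 + math.exp(-logit))

print(classify_prob(2.65, logit_bias=2.65))  # 0.5: the bias recenters the scale
print(classify_prob(2.65))                   # ~0.93 without the bias
```

Note that `if self.logit_bias:` also treats a configured bias of `0.0` as unset, which is harmless here since subtracting zero would be a no-op.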
vllm/model_executor/models/config.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -210,8 +210,10 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         config = vllm_config.model_config.hf_config
-
         config.num_labels = 1
+        pooler_config = vllm_config.model_config.pooler_config
+        if pooler_config.logit_bias is None:
+            pooler_config.logit_bias = 2.65


 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
```

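The `None` check is what turns the previously hard-coded 2.65 into a default rather than a constant: an explicitly configured `logit_bias` survives `verify_and_update_config`. A small sketch of that precedence pattern (hypothetical standalone types, not vLLM's):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class _PoolerCfg:  # stand-in for vllm.config.PoolerConfig
    logit_bias: Optional[float] = None

def apply_jina_vl_default(cfg: _PoolerCfg) -> None:
    # Model-specific default, applied only if the user left the field unset.
    if cfg.logit_bias is None:
        cfg.logit_bias = 2.65

user_cfg = _PoolerCfg(logit_bias=0.5)
apply_jina_vl_default(user_cfg)
assert user_cfg.logit_bias == 0.5  # explicit user value wins

default_cfg = _PoolerCfg()
apply_jina_vl_default(default_cfg)
assert default_cfg.logit_bias == 2.65  # falls back to the model default
```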
vllm/model_executor/models/jina_vl.py

Lines changed: 3 additions & 8 deletions

```diff
@@ -92,17 +92,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        # logit bias for sigmoid normalization
-        self.LOGIT_BIAS = 2.65
-
         self.score = JinaVLScorer(config)
         self.pooler = DispatchPooler({
             "encode":
             Pooler.for_encode(pooler_config),
             "classify":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
             "score":
-            Pooler.for_classify(pooler_config, classifier=None),
+            Pooler.for_classify(pooler_config, classifier=self.score),
         })

     @classmethod
@@ -137,9 +134,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             **kwargs,
         )
-
-        logits = self.score(hidden_states) - self.LOGIT_BIAS
-        return logits
+        return hidden_states

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
```

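With `classifier=self.score` wired into the pooler and the bias applied there, `forward` now returns raw hidden states and the user-facing scoring path should be unchanged. A hedged usage sketch; the model id and the exact `LLM.score` output shape are assumptions based on vLLM's cross-encoder scoring API, not part of this commit:

```python
from vllm import LLM

# Assumed model id for the JinaVL reranker; adjust for your deployment.
llm = LLM(model="jinaai/jina-reranker-m0", trust_remote_code=True)

# score(query, documents) runs the "score" pooling task; the 2.65 bias is
# now subtracted inside the pooler rather than in the model's forward().
outputs = llm.score("slm markdown", ["small models for document conversion"])
print(outputs[0].outputs.score)
```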