Skip to content

Commit 4645678

Browse files
authored
Hotfix gaudi2 with newer transformers. (#3176)
1 parent ad765cd commit 4645678

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

backends/gaudi/server/text_generation_server/utils/logits_process.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import habana_frameworks.torch.core as htcore
44

55
from loguru import logger
6-
from typing import Dict, Union
6+
from typing import Dict
77
from text_generation_server.pb.generate_pb2 import GrammarType
88

99
from outlines.fsm.fsm import RegexFSM
@@ -13,7 +13,6 @@
1313
import time
1414

1515
from transformers import (
16-
LogitsWarper,
1716
LogitsProcessor,
1817
TemperatureLogitsWarper,
1918
TopKLogitsWarper,
@@ -191,7 +190,7 @@ def filter(self, indices):
191190

192191
class HeterogeneousTemperatureLogitsWarper:
193192
r"""
194-
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
193+
[`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
195194
This version allows for a separate value for each sample and runs inplace when possible.
196195
It doesn't validate inputs.
197196
@@ -220,7 +219,7 @@ def filter(self, indices):
220219
return None
221220

222221

223-
class HeterogeneousTopPLogitsWarper(LogitsWarper):
222+
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
224223
"""
225224
[`LogitsProcessor`] that performs top-p, i.e. restricting to the top tokens whose cumulative probability stays within prob_cut_off.
226225
This version allows for a separate value for each sample and runs inplace when possible.
@@ -279,9 +278,9 @@ def filter(self, indices):
279278
return None
280279

281280

282-
class HeterogeneousTopKLogitsWarper(LogitsWarper):
281+
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
283282
r"""
284-
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
283+
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
285284
This version allows for a separate value for each sample and runs inplace when possible.
286285
It doesn't validate inputs.
287286
@@ -360,9 +359,9 @@ def filter(self, indices):
360359
return None
361360

362361

363-
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
362+
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
364363
r"""
365-
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
364+
[`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
366365
Generation](https://arxiv.org/abs/2202.00666) for more information.
367366
This version allows for a separate value for each sample and runs inplace when possible.
368367
It doesn't validate inputs.
@@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
454453
r"""
455454
A wrapper for logits processors (including warpers) without heterogeneous parameter support.
456455
Args:
457-
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
456+
processors (`Dict[int, LogitsProcessor]`):
458457
A mapping of sample indices to logit warpers or processors, to be run sequentially.
459458
"""
460459

461460
def __init__(
462461
self,
463-
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
462+
processors: Dict[int, LogitsProcessor],
464463
):
465464
self.processors = processors
466465

0 commit comments

Comments
 (0)