|
@@ -3,7 +3,7 @@
 import habana_frameworks.torch.core as htcore
 
 from loguru import logger
-from typing import Dict, Union
+from typing import Dict
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.fsm import RegexFSM
@@ -13,7 +13,6 @@
 import time
 
 from transformers import (
-    LogitsWarper,
     LogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
@@ -191,7 +190,7 @@ def filter(self, indices):
 
 class HeterogeneousTemperatureLogitsWarper:
     r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.
 
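
Per the docstring, this warper applies a separate temperature to each sample in the batch. A standalone sketch of per-sample temperature scaling (the function name and shapes are assumptions drawn from the docstring, not this class's actual code):

```python
import torch
from typing import List


def scale_by_temperature(
    scores: torch.FloatTensor, temperatures: List[float]
) -> torch.FloatTensor:
    # scores: [batch_size, vocab_size]; one temperature per batch row.
    t = torch.tensor(temperatures, dtype=scores.dtype, device=scores.device)
    # In-place division mirrors the "runs inplace when possible" note.
    scores.div_(t.unsqueeze(1))
    return scores
```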
|
@@ -220,7 +219,7 @@ def filter(self, indices):
         return None
 
 
-class HeterogeneousTopPLogitsWarper(LogitsWarper):
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
     This version allows for a separate value for each sample and runs inplace when possible.
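
Top-p (nucleus) filtering keeps the smallest set of highest-probability tokens whose cumulative probability reaches the cutoff and masks the rest. A self-contained sketch of the standard formulation (not necessarily this class's exact implementation):

```python
import torch


def top_p_filter(scores: torch.FloatTensor, top_p: float) -> torch.FloatTensor:
    # Sort logits descending and accumulate probability mass.
    sorted_logits, sorted_idx = torch.sort(scores, descending=True, dim=-1)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Mark tokens past the cutoff, shifted right so the token that first
    # crosses the threshold is still kept.
    sorted_mask = cum_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
    return scores.masked_fill(mask, float("-inf"))
```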
|
@@ -279,9 +278,9 @@ def filter(self, indices):
         return None
 
 
-class HeterogeneousTopKLogitsWarper(LogitsWarper):
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+    [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.
 
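
Top-k is the simpler counterpart: only the k largest logits in each row survive. A minimal sketch of the standard formulation:

```python
import torch


def top_k_filter(scores: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    # Find the k-th largest logit per row and mask everything below it.
    kth_value = torch.topk(scores, top_k, dim=-1).values[..., -1, None]
    return scores.masked_fill(scores < kth_value, float("-inf"))
```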
|
@@ -360,9 +359,9 @@ def filter(self, indices):
         return None
 
 
-class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
+    [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
     Generation](https://arxiv.org/abs/2202.00666) for more information.
     This version allows for a separate value for each sample and runs inplace when possible.
     It doesn't validate inputs.
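
Typical decoding, per the linked paper, ranks tokens by how close their surprisal (-log p) is to the entropy of the full distribution and keeps the closest ones until the requested probability mass is covered. A sketch of that computation (standard formulation, not this class's exact code):

```python
import torch


def typical_filter(scores: torch.FloatTensor, mass: float) -> torch.FloatTensor:
    # Entropy of each row's predicted distribution.
    log_probs = torch.log_softmax(scores, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1, keepdim=True)
    # Rank tokens by the distance of their surprisal from the entropy.
    deviation = (-log_probs - entropy).abs()
    _, sorted_idx = torch.sort(deviation, dim=-1)
    cum_probs = probs.gather(-1, sorted_idx).cumsum(dim=-1)
    # Drop tokens once the kept set reaches `mass`, shifted right so the
    # token that crosses the threshold is retained.
    sorted_mask = cum_probs >= mass
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
    return scores.masked_fill(mask, float("-inf"))
```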
|
@@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
     r"""
     A wrapper for logit warpers or processors without heterogeneous parameter support.
     Args:
-        processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
+        processors (`Dict[int, LogitsProcessor]`):
             A mapping of sample indices to logit warpers or processors, to be run sequentially.
     """
 
     def __init__(
         self,
-        processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
+        processors: Dict[int, LogitsProcessor],
     ):
         self.processors = processors
 
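
With the annotation narrowed to `Dict[int, LogitsProcessor]`, usage is unchanged: each entry applies one processor to the batch row at that index, and rows without an entry pass through untouched. A hypothetical usage sketch (the processor choices, tensor shapes, and call signature here are illustrative assumptions, not from this commit):

```python
import torch
from transformers import MinLengthLogitsProcessor, RepetitionPenaltyLogitsProcessor

# Hypothetical batch: 4 sequences of 8 tokens, vocabulary of 32.
input_ids = torch.zeros((4, 8), dtype=torch.long)
scores = torch.randn(4, 32)

# Row 0 gets a minimum-length constraint, row 2 a repetition penalty.
wrapper = HeterogeneousProcessorWrapper(
    processors={
        0: MinLengthLogitsProcessor(min_length=16, eos_token_id=2),
        2: RepetitionPenaltyLogitsProcessor(penalty=1.2),
    }
)
scores = wrapper(input_ids, scores)
```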
|
|