Skip to content

Commit 6208d62

Browse files
authored
Minor code cleaning for SamplingParams (vllm-project#99)
1 parent 42f1042 commit 6208d62

File tree

1 file changed

+50
-49
lines changed

1 file changed

+50
-49
lines changed

cacheflow/sampling_params.py

+50-49
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Set
1+
from typing import Set
22

33

44
class SamplingParams:
@@ -16,54 +16,6 @@ def __init__(
1616
max_tokens: int = 16,
1717
logprobs: int = 0,
1818
) -> None:
19-
if n < 1:
20-
raise ValueError(f"n must be at least 1, got {n}.")
21-
if not -2.0 <= presence_penalty <= 2.0:
22-
raise ValueError(
23-
f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
24-
if not -2.0 <= frequency_penalty <= 2.0:
25-
raise ValueError(
26-
f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
27-
if temperature < 0.0:
28-
raise ValueError(
29-
f"temperature must be non-negative, got {temperature}.")
30-
if not 0.0 < top_p <= 1.0:
31-
raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
32-
if top_k < -1 or top_k == 0:
33-
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
34-
f"got {top_k}.")
35-
if max_tokens < 1:
36-
raise ValueError(
37-
f"max_tokens must be at least 1, got {max_tokens}.")
38-
if logprobs < 0:
39-
raise ValueError(
40-
f"logprobs must be non-negative, got {logprobs}.")
41-
42-
if use_beam_search:
43-
if n == 1:
44-
raise ValueError(
45-
"n must be greater than 1 when using beam search.")
46-
if temperature > 0.0:
47-
raise ValueError(
48-
"temperature must be 0 when using beam search.")
49-
if top_p < 1.0:
50-
raise ValueError(
51-
"top_p must be 1 when using beam search.")
52-
if top_k != -1:
53-
raise ValueError(
54-
"top_k must be -1 when using beam search.")
55-
elif temperature == 0.0:
56-
# Zero temperature means greedy sampling.
57-
if n > 1:
58-
raise ValueError(
59-
"n must be 1 when using greedy sampling.")
60-
if top_p < 1.0:
61-
raise ValueError(
62-
"top_p must be 1 when using greedy sampling.")
63-
if top_k != -1:
64-
raise ValueError(
65-
"top_k must be -1 when using greedy sampling.")
66-
6719
self.n = n
6820
self.presence_penalty = presence_penalty
6921
self.frequency_penalty = frequency_penalty
@@ -75,6 +27,55 @@ def __init__(
7527
self.max_tokens = max_tokens
7628
self.logprobs = logprobs
7729

30+
self._verify_args()
31+
if self.use_beam_search:
32+
self._verity_beam_search()
33+
elif self.temperature == 0.0:
34+
# Zero temperature means greedy sampling.
35+
self._verify_greedy_sampling()
36+
37+
def _verify_args(self) -> None:
38+
if self.n < 1:
39+
raise ValueError(f"n must be at least 1, got {self.n}.")
40+
if not -2.0 <= self.presence_penalty <= 2.0:
41+
raise ValueError("presence_penalty must be in [-2, 2], got "
42+
f"{self.presence_penalty}.")
43+
if not -2.0 <= self.frequency_penalty <= 2.0:
44+
raise ValueError("frequency_penalty must be in [-2, 2], got "
45+
f"{self.frequency_penalty}.")
46+
if self.temperature < 0.0:
47+
raise ValueError(
48+
f"temperature must be non-negative, got {self.temperature}.")
49+
if not 0.0 < self.top_p <= 1.0:
50+
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
51+
if self.top_k < -1 or self.top_k == 0:
52+
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
53+
f"got {self.top_k}.")
54+
if self.max_tokens < 1:
55+
raise ValueError(
56+
f"max_tokens must be at least 1, got {self.max_tokens}.")
57+
if self.logprobs < 0:
58+
raise ValueError(
59+
f"logprobs must be non-negative, got {self.logprobs}.")
60+
61+
def _verity_beam_search(self) -> None:
62+
if self.n == 1:
63+
raise ValueError("n must be greater than 1 when using beam search.")
64+
if self.temperature > 0.0:
65+
raise ValueError("temperature must be 0 when using beam search.")
66+
if self.top_p < 1.0:
67+
raise ValueError("top_p must be 1 when using beam search.")
68+
if self.top_k != -1:
69+
raise ValueError("top_k must be -1 when using beam search.")
70+
71+
def _verify_greedy_sampling(self) -> None:
72+
if self.n > 1:
73+
raise ValueError("n must be 1 when using greedy sampling.")
74+
if self.top_p < 1.0:
75+
raise ValueError("top_p must be 1 when using greedy sampling.")
76+
if self.top_k != -1:
77+
raise ValueError("top_k must be -1 when using greedy sampling.")
78+
7879
def __repr__(self) -> str:
7980
return (f"SamplingParams(n={self.n}, "
8081
f"presence_penalty={self.presence_penalty}, "

0 commit comments

Comments
 (0)