1
- from typing import Dict , Set
1
+ from typing import Set
2
2
3
3
4
4
class SamplingParams :
@@ -16,54 +16,6 @@ def __init__(
16
16
max_tokens : int = 16 ,
17
17
logprobs : int = 0 ,
18
18
) -> None :
19
- if n < 1 :
20
- raise ValueError (f"n must be at least 1, got { n } ." )
21
- if not - 2.0 <= presence_penalty <= 2.0 :
22
- raise ValueError (
23
- f"presence_penalty must be in [-2, 2], got { presence_penalty } ." )
24
- if not - 2.0 <= frequency_penalty <= 2.0 :
25
- raise ValueError (
26
- f"frequency_penalty must be in [-2, 2], got { frequency_penalty } ." )
27
- if temperature < 0.0 :
28
- raise ValueError (
29
- f"temperature must be non-negative, got { temperature } ." )
30
- if not 0.0 < top_p <= 1.0 :
31
- raise ValueError (f"top_p must be in (0, 1], got { top_p } ." )
32
- if top_k < - 1 or top_k == 0 :
33
- raise ValueError (f"top_k must be -1 (disable), or at least 1, "
34
- f"got { top_k } ." )
35
- if max_tokens < 1 :
36
- raise ValueError (
37
- f"max_tokens must be at least 1, got { max_tokens } ." )
38
- if logprobs < 0 :
39
- raise ValueError (
40
- f"logprobs must be non-negative, got { logprobs } ." )
41
-
42
- if use_beam_search :
43
- if n == 1 :
44
- raise ValueError (
45
- "n must be greater than 1 when using beam search." )
46
- if temperature > 0.0 :
47
- raise ValueError (
48
- "temperature must be 0 when using beam search." )
49
- if top_p < 1.0 :
50
- raise ValueError (
51
- "top_p must be 1 when using beam search." )
52
- if top_k != - 1 :
53
- raise ValueError (
54
- "top_k must be -1 when using beam search." )
55
- elif temperature == 0.0 :
56
- # Zero temperature means greedy sampling.
57
- if n > 1 :
58
- raise ValueError (
59
- "n must be 1 when using greedy sampling." )
60
- if top_p < 1.0 :
61
- raise ValueError (
62
- "top_p must be 1 when using greedy sampling." )
63
- if top_k != - 1 :
64
- raise ValueError (
65
- "top_k must be -1 when using greedy sampling." )
66
-
67
19
self .n = n
68
20
self .presence_penalty = presence_penalty
69
21
self .frequency_penalty = frequency_penalty
@@ -75,6 +27,55 @@ def __init__(
75
27
self .max_tokens = max_tokens
76
28
self .logprobs = logprobs
77
29
30
+ self ._verify_args ()
31
+ if self .use_beam_search :
32
+ self ._verity_beam_search ()
33
+ elif self .temperature == 0.0 :
34
+ # Zero temperature means greedy sampling.
35
+ self ._verify_greedy_sampling ()
36
+
37
+ def _verify_args (self ) -> None :
38
+ if self .n < 1 :
39
+ raise ValueError (f"n must be at least 1, got { self .n } ." )
40
+ if not - 2.0 <= self .presence_penalty <= 2.0 :
41
+ raise ValueError ("presence_penalty must be in [-2, 2], got "
42
+ f"{ self .presence_penalty } ." )
43
+ if not - 2.0 <= self .frequency_penalty <= 2.0 :
44
+ raise ValueError ("frequency_penalty must be in [-2, 2], got "
45
+ f"{ self .frequency_penalty } ." )
46
+ if self .temperature < 0.0 :
47
+ raise ValueError (
48
+ f"temperature must be non-negative, got { self .temperature } ." )
49
+ if not 0.0 < self .top_p <= 1.0 :
50
+ raise ValueError (f"top_p must be in (0, 1], got { self .top_p } ." )
51
+ if self .top_k < - 1 or self .top_k == 0 :
52
+ raise ValueError (f"top_k must be -1 (disable), or at least 1, "
53
+ f"got { self .top_k } ." )
54
+ if self .max_tokens < 1 :
55
+ raise ValueError (
56
+ f"max_tokens must be at least 1, got { self .max_tokens } ." )
57
+ if self .logprobs < 0 :
58
+ raise ValueError (
59
+ f"logprobs must be non-negative, got { self .logprobs } ." )
60
+
61
+ def _verity_beam_search (self ) -> None :
62
+ if self .n == 1 :
63
+ raise ValueError ("n must be greater than 1 when using beam search." )
64
+ if self .temperature > 0.0 :
65
+ raise ValueError ("temperature must be 0 when using beam search." )
66
+ if self .top_p < 1.0 :
67
+ raise ValueError ("top_p must be 1 when using beam search." )
68
+ if self .top_k != - 1 :
69
+ raise ValueError ("top_k must be -1 when using beam search." )
70
+
71
+ def _verify_greedy_sampling (self ) -> None :
72
+ if self .n > 1 :
73
+ raise ValueError ("n must be 1 when using greedy sampling." )
74
+ if self .top_p < 1.0 :
75
+ raise ValueError ("top_p must be 1 when using greedy sampling." )
76
+ if self .top_k != - 1 :
77
+ raise ValueError ("top_k must be -1 when using greedy sampling." )
78
+
78
79
def __repr__ (self ) -> str :
79
80
return (f"SamplingParams(n={ self .n } , "
80
81
f"presence_penalty={ self .presence_penalty } , "
0 commit comments