Skip to content

Commit a8a02cd

Browse files
jkanieckiadobrzynmichalkuligowski
authored
Buckets from file - alpha version (#375) (#496)
Cherry-pick of d84e734 --------- Signed-off-by: Agata Dobrzyniewicz <[email protected]> Co-authored-by: Agata Dobrzyniewicz <[email protected]> Co-authored-by: Michał Kuligowski <[email protected]>
1 parent 0001620 commit a8a02cd

File tree

4 files changed

+159
-38
lines changed

4 files changed

+159
-38
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# This is a sample bucketing file
2+
3+
# Add buckets as follows:
4+
# Prompt: (batch_size, query_len, context)
5+
# Decode: (batch_size, 1, context)
6+
7+
# You can also use lists to end up with Cartesian product like so:
8+
# (1, [256, 512], [0, 4, 8])
9+
# In this case you will end up with 6 buckets
10+
# You can also use Python's range to create similar lists
11+
# range(min, max, step)
12+
# Examples are shown below
13+
14+
# Note: bucketing from file is not supported with unified attention
15+
# use '#' to comment out lines
16+
17+
18+
# Buckets:
19+
(1, 2048, 0)
20+
(1, [256, 512], [0, 4, 8])
21+
22+
(64, 1, 1024)
23+
(32, 1, 512)
24+
(1, 1, range(256, 512, 32))

vllm_gaudi/extension/bucketing/common.py

Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,18 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_bat
5454
self.num_hpu_blocks = None
5555
self.max_model_len = max_model_len
5656
self.initialized = True
57-
5857
self.fallback_bs_base_step = 2
5958
self.fallback_seq_base_step = 32
6059
self.fallback_blocks_base_step = 32
6160

6261
### GENERATE BUCKETS FUNCTIONS ###
6362

63+
def read_from_file(self, is_prompt):
64+
file_name = get_config().VLLM_BUCKETING_FROM_FILE
65+
from vllm_gaudi.extension.bucketing.file_strategy import (FileBucketingStrategy)
66+
strategy = FileBucketingStrategy()
67+
return strategy.get_buckets(file_name, is_prompt)
68+
6469
def get_bucketing_strategy(self):
6570
strategy = None
6671
# TODO - we can use different strategies for decode and prompt
@@ -78,6 +83,8 @@ def get_bucketing_strategy(self):
7883

7984
def generate_unified_buckets(self):
8085
if self.initialized:
86+
if get_config().VLLM_BUCKETING_FROM_FILE:
87+
assert "Unified attention doesn't support bucketing from file"
8188
from vllm_gaudi.extension.bucketing.unified import (UnifiedBucketingStrategy)
8289
strategy = UnifiedBucketingStrategy()
8390

@@ -105,20 +112,29 @@ def generate_unified_buckets(self):
105112

106113
def generate_prompt_buckets(self):
107114
if self.initialized:
108-
strategy = self.get_bucketing_strategy()
109-
110-
bs_cfg, query_cfg, ctx_cfg = strategy.get_prompt_cfgs(max_num_prefill_seqs=self.max_num_prefill_seqs,
111-
block_size=self.block_size,
112-
max_num_batched_tokens=self.max_num_batched_tokens,
113-
max_model_len=self.max_model_len)
114-
115-
bs_range = strategy.get_range(bs_cfg)
116-
query_range = strategy.get_range(query_cfg)
117-
ctx_range = strategy.get_range(ctx_cfg)
115+
buckets_from_file = None
116+
bs_range = []
117+
query_range = []
118+
ctx_range = []
119+
if get_config().VLLM_BUCKETING_FROM_FILE:
120+
buckets_from_file = self.read_from_file(is_prompt=True)
121+
else:
122+
strategy = self.get_bucketing_strategy()
123+
124+
bs_cfg, query_cfg, ctx_cfg = strategy.get_prompt_cfgs(
125+
max_num_prefill_seqs=self.max_num_prefill_seqs,
126+
block_size=self.block_size,
127+
max_num_batched_tokens=self.max_num_batched_tokens,
128+
max_model_len=self.max_model_len)
129+
130+
bs_range = strategy.get_range(bs_cfg)
131+
query_range = strategy.get_range(query_cfg)
132+
ctx_range = strategy.get_range(ctx_cfg)
118133

119134
self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
120135
self.max_num_seqs, self.max_num_prefill_seqs,
121-
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks)
136+
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
137+
buckets_from_file)
122138
self.log_generate_info(True)
123139
else:
124140
logger().info("Bucketing is off - skipping prompt buckets generation")
@@ -127,24 +143,33 @@ def generate_prompt_buckets(self):
127143

128144
def generate_decode_buckets(self):
129145
if self.initialized:
130-
strategy = self.get_bucketing_strategy()
131-
132-
bs_cfg, query_cfg, ctx_cfg = strategy.get_decode_cfgs(max_num_seqs=self.max_num_seqs,
133-
block_size=self.block_size,
134-
max_num_batched_tokens=self.max_num_batched_tokens,
135-
max_model_len=self.max_model_len,
136-
max_blocks=self.num_hpu_blocks)
137-
138-
bs_range = strategy.get_range(bs_cfg)
139-
query_range = strategy.get_range(query_cfg)
140-
ctx_range = strategy.get_range(ctx_cfg)
141-
142-
if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks:
143-
ctx_range.append(self.num_hpu_blocks)
146+
buckets_from_file = None
147+
bs_range = []
148+
query_range = []
149+
ctx_range = []
150+
if get_config().VLLM_BUCKETING_FROM_FILE:
151+
buckets_from_file = self.read_from_file(is_prompt=False)
152+
else:
153+
strategy = self.get_bucketing_strategy()
154+
155+
bs_cfg, query_cfg, ctx_cfg = strategy.get_decode_cfgs(
156+
max_num_seqs=self.max_num_seqs,
157+
block_size=self.block_size,
158+
max_num_batched_tokens=self.max_num_batched_tokens,
159+
max_model_len=self.max_model_len,
160+
max_blocks=self.num_hpu_blocks)
161+
162+
bs_range = strategy.get_range(bs_cfg)
163+
query_range = strategy.get_range(query_cfg)
164+
ctx_range = strategy.get_range(ctx_cfg)
165+
166+
if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks:
167+
ctx_range.append(self.num_hpu_blocks)
144168

145169
self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
146170
self.max_num_seqs, self.max_num_prefill_seqs,
147-
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks)
171+
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
172+
buckets_from_file)
148173
self.log_generate_info(False)
149174
else:
150175
logger().info("Bucketing is off - skipping decode buckets generation")
@@ -225,8 +250,17 @@ def get_bucketing_manager():
225250
return instance
226251

227252

228-
def generate_buckets(bs_range, query_range, ctx_range, is_prompt, max_model_len, max_num_seqs, max_num_prefill_seqs,
229-
max_num_batched_tokens, block_size, max_blocks):
253+
def generate_buckets(bs_range,
254+
query_range,
255+
ctx_range,
256+
is_prompt,
257+
max_model_len,
258+
max_num_seqs,
259+
max_num_prefill_seqs,
260+
max_num_batched_tokens,
261+
block_size,
262+
max_blocks,
263+
file_buckets=None):
230264
use_merged_prefill = get_config().merged_prefill
231265
use_contiguous_pa = get_config().use_contiguous_pa
232266

@@ -307,15 +341,23 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa):
307341
buckets_2d = set()
308342
omitted_buckets = set()
309343
filters = get_filters(is_prompt, use_merged_prefill, use_contiguous_pa)
310-
for bs_idx, bs in enumerate(bs_range):
311-
for query_idx, query in enumerate(query_range):
312-
buckets_2d.update(
313-
expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num_batched_tokens))
314-
315-
for bs, query in buckets_2d:
316-
for ctx in ctx_range:
317-
if all(bucket_filter(bs, query, ctx) for bucket_filter in filters):
318-
buckets.add((bs, query, ctx))
344+
345+
if file_buckets:
346+
for bs, query, blocks in file_buckets:
347+
if all(bucket_filter(bs, query, blocks) for bucket_filter in filters):
348+
buckets.add((bs, query, blocks))
349+
else:
350+
for bs_idx, bs in enumerate(bs_range):
351+
for ctx_idx, ctx in enumerate(ctx_range):
352+
local_buckets = expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range,
353+
max_num_batched_tokens) if not is_prompt else {(bs, ctx)}
354+
buckets_2d.update(local_buckets)
355+
356+
for bs, ctx in buckets_2d:
357+
for query in query_range:
358+
if all(bucket_filter(bs, query, ctx) for bucket_filter in filters):
359+
buckets.add((bs, query, ctx))
360+
319361
if not buckets:
320362
phase = 'prompt' if is_prompt else 'decode'
321363
for bucket in omitted_buckets:
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import itertools
2+
import operator
3+
import os
4+
import math
5+
import ast
6+
from dataclasses import dataclass, field
7+
from typing import List, Tuple
8+
9+
from vllm_gaudi.extension.logger import logger as logger
10+
from vllm_gaudi.extension.runtime import get_config
11+
12+
13+
class FileBucketingStrategy:
    """Builds warmup buckets from a user-supplied bucketing file.

    Each meaningful line of the file must evaluate to a 3-tuple
    ``(batch_size, query_len, ctx)``.  Any element may be a list or a
    ``range``, in which case the Cartesian product of all elements is
    expanded into individual buckets.  A bucket whose ``query_len`` is 1
    is classified as a decode bucket; everything else is a prompt
    bucket.  Blank lines and lines starting with ``#`` are ignored.
    """

    @staticmethod
    def _expand(value):
        # Normalize one bucket axis to a list: lists pass through,
        # ranges are materialized, scalars are wrapped.
        if isinstance(value, list):
            return value
        if isinstance(value, range):
            return list(value)
        return [value]

    def get_buckets(self, file_name, is_prompt):
        """Parse ``file_name`` and return the requested bucket set.

        Args:
            file_name: path to the bucketing description file.
            is_prompt: return prompt buckets when True, decode otherwise.

        Returns:
            Sorted, de-duplicated list of ``(bs, query, ctx)`` int tuples.
        """
        prompt_buckets = set()
        decode_buckets = set()

        with open(file_name, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue

                # NOTE(security): eval() is required here to support
                # range(...) expressions, which ast.literal_eval cannot
                # parse.  Builtins are disabled and only 'range' is
                # exposed, but the file is still trusted input - never
                # point this at untrusted data.
                try:
                    bucket = eval(line, {"__builtins__": None}, {"range": range})
                except Exception as e:
                    # Use the project logger (lazy %-formatting), not print.
                    logger().warning("Skipping line due to eval error: %s - %s", e, line)
                    continue

                if not isinstance(bucket, tuple) or len(bucket) != 3:
                    logger().warning("Skipping line due to incorrect format - %s", bucket)
                    continue

                bs_vals = self._expand(bucket[0])
                query_vals = self._expand(bucket[1])
                ctx_vals = self._expand(bucket[2])

                # Expand lists/ranges into the Cartesian product of buckets.
                for full_bucket in itertools.product(bs_vals, query_vals, ctx_vals):
                    bs, query, ctx = map(int, full_bucket)
                    # query_len == 1 marks a decode bucket (see sample file).
                    if query == 1:
                        decode_buckets.add((bs, query, ctx))
                    else:
                        prompt_buckets.add((bs, query, ctx))
        return sorted(prompt_buckets) if is_prompt else sorted(decode_buckets)
46+
47+
48+
def ensure_is_list(value):
    """Normalize a bucket axis value into a list.

    Lists are returned unchanged; ranges and tuples are materialized
    into lists; any other (scalar) value is wrapped in a one-element
    list.  Accepting tuples lets a nested ``(a, b)`` axis in the
    bucketing file behave like ``[a, b]`` instead of failing later in
    the ``int()`` conversion of the expanded buckets.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, (range, tuple)):
        return list(value)
    return [value]

vllm_gaudi/extension/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_user_flags():
3434
Env('VLLM_DECODE_BLOCK_BUCKET_STEP', int),
3535
Env('VLLM_DECODE_BLOCK_BUCKET_MAX', int),
3636
Env('VLLM_DECODE_BLOCK_BUCKET_LIMIT', int),
37+
Env('VLLM_BUCKETING_FROM_FILE', str),
3738

3839
# Non-vllm flags that are also important to print
3940
Env('EXPERIMENTAL_WEIGHT_SHARING', str),

0 commit comments

Comments
 (0)