@@ -54,13 +54,18 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_bat
5454 self .num_hpu_blocks = None
5555 self .max_model_len = max_model_len
5656 self .initialized = True
57-
5857 self .fallback_bs_base_step = 2
5958 self .fallback_seq_base_step = 32
6059 self .fallback_blocks_base_step = 32
6160
6261 ### GENERATE BUCKETS FUNCTIONS ###
6362
def read_from_file(self, is_prompt):
    """Load precomputed buckets from the file named by VLLM_BUCKETING_FROM_FILE.

    :param is_prompt: True to load prompt-phase buckets, False for decode-phase.
    :return: whatever FileBucketingStrategy.get_buckets yields for that file/phase.
    """
    # Local import keeps the file-strategy dependency lazy, matching the
    # sibling generate_* methods which also import their strategy on demand.
    from vllm_gaudi.extension.bucketing.file_strategy import FileBucketingStrategy
    bucket_file = get_config().VLLM_BUCKETING_FROM_FILE
    return FileBucketingStrategy().get_buckets(bucket_file, is_prompt)
6469 def get_bucketing_strategy (self ):
6570 strategy = None
6671 # TODO - we can use different strategies for decode and prompt
@@ -78,6 +83,8 @@ def get_bucketing_strategy(self):
7883
def generate_unified_buckets(self):
    """Generate buckets using the unified attention bucketing strategy.

    NOTE(review): only the head of this method is visible in this view; the
    remainder of the body (past the strategy construction) is unchanged.
    """
    if self.initialized:
        # Bucketing-from-file is incompatible with the unified strategy, so
        # fail fast. The original code was `assert "<message>"`, which asserts
        # a non-empty string literal — always truthy, so the guard never fired.
        # Assert on the actual condition and keep the string as the message.
        # (Note: asserts are stripped under `python -O`; raise instead if this
        # must hold in optimized runs.)
        assert not get_config().VLLM_BUCKETING_FROM_FILE, \
            "Unified attention doesn't support bucketing from file"
        from vllm_gaudi.extension.bucketing.unified import UnifiedBucketingStrategy
        strategy = UnifiedBucketingStrategy()
@@ -105,20 +112,29 @@ def generate_unified_buckets(self):
105112
def generate_prompt_buckets(self):
    """Populate self.prompt_buckets, either from a bucket file or a strategy.

    When VLLM_BUCKETING_FROM_FILE is set, buckets are read verbatim from the
    file and the range arguments stay empty; otherwise ranges are derived from
    the configured bucketing strategy. Skips entirely when bucketing is off.
    """
    if not self.initialized:
        logger().info("Bucketing is off - skipping prompt buckets generation")
        return

    buckets_from_file = None
    bs_range, query_range, ctx_range = [], [], []

    if get_config().VLLM_BUCKETING_FROM_FILE:
        buckets_from_file = self.read_from_file(is_prompt=True)
    else:
        strategy = self.get_bucketing_strategy()
        bs_cfg, query_cfg, ctx_cfg = strategy.get_prompt_cfgs(
            max_num_prefill_seqs=self.max_num_prefill_seqs,
            block_size=self.block_size,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_model_len=self.max_model_len)
        # Expand each config into its concrete value range.
        bs_range = strategy.get_range(bs_cfg)
        query_range = strategy.get_range(query_cfg)
        ctx_range = strategy.get_range(ctx_cfg)

    self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
                                           self.max_num_seqs, self.max_num_prefill_seqs,
                                           self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
                                           buckets_from_file)
    self.log_generate_info(True)
@@ -127,24 +143,33 @@ def generate_prompt_buckets(self):
127143
def generate_decode_buckets(self):
    """Populate self.decode_buckets, either from a bucket file or a strategy.

    Mirrors generate_prompt_buckets: file-sourced buckets bypass the strategy
    (ranges stay empty), otherwise ranges come from get_decode_cfgs. Skips
    entirely when bucketing is off.
    """
    if not self.initialized:
        logger().info("Bucketing is off - skipping decode buckets generation")
        return

    buckets_from_file = None
    bs_range, query_range, ctx_range = [], [], []

    if get_config().VLLM_BUCKETING_FROM_FILE:
        buckets_from_file = self.read_from_file(is_prompt=False)
    else:
        strategy = self.get_bucketing_strategy()
        bs_cfg, query_cfg, ctx_cfg = strategy.get_decode_cfgs(
            max_num_seqs=self.max_num_seqs,
            block_size=self.block_size,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_model_len=self.max_model_len,
            max_blocks=self.num_hpu_blocks)
        bs_range = strategy.get_range(bs_cfg)
        query_range = strategy.get_range(query_cfg)
        ctx_range = strategy.get_range(ctx_cfg)

        # With contiguous PA, make sure the context range reaches the full
        # HPU block count. This check belongs inside the strategy branch:
        # in the from-file path ctx_range is empty and ctx_range[-1] would
        # raise IndexError.
        if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks:
            ctx_range.append(self.num_hpu_blocks)

    self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
                                           self.max_num_seqs, self.max_num_prefill_seqs,
                                           self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
                                           buckets_from_file)
    self.log_generate_info(False)
@@ -225,8 +250,17 @@ def get_bucketing_manager():
225250 return instance
226251
227252
228- def generate_buckets (bs_range , query_range , ctx_range , is_prompt , max_model_len , max_num_seqs , max_num_prefill_seqs ,
229- max_num_batched_tokens , block_size , max_blocks ):
253+ def generate_buckets (bs_range ,
254+ query_range ,
255+ ctx_range ,
256+ is_prompt ,
257+ max_model_len ,
258+ max_num_seqs ,
259+ max_num_prefill_seqs ,
260+ max_num_batched_tokens ,
261+ block_size ,
262+ max_blocks ,
263+ file_buckets = None ):
230264 use_merged_prefill = get_config ().merged_prefill
231265 use_contiguous_pa = get_config ().use_contiguous_pa
232266
@@ -307,15 +341,23 @@ def get_filters(is_prompt, use_merged_prefill, use_contiguous_pa):
307341 buckets_2d = set ()
308342 omitted_buckets = set ()
309343 filters = get_filters (is_prompt , use_merged_prefill , use_contiguous_pa )
310- for bs_idx , bs in enumerate (bs_range ):
311- for query_idx , query in enumerate (query_range ):
312- buckets_2d .update (
313- expand_to_neighbor_buckets (bs_idx , bs_range , query_idx , query_range , max_num_batched_tokens ))
314-
315- for bs , query in buckets_2d :
316- for ctx in ctx_range :
317- if all (bucket_filter (bs , query , ctx ) for bucket_filter in filters ):
318- buckets .add ((bs , query , ctx ))
344+
345+ if file_buckets :
346+ for bs , query , blocks in file_buckets :
347+ if all (bucket_filter (bs , query , blocks ) for bucket_filter in filters ):
348+ buckets .add ((bs , query , blocks ))
349+ else :
350+ for bs_idx , bs in enumerate (bs_range ):
351+ for ctx_idx , ctx in enumerate (ctx_range ):
352+ local_buckets = expand_to_neighbor_buckets (bs_idx , bs_range , ctx_idx , ctx_range ,
353+ max_num_batched_tokens ) if not is_prompt else {(bs , ctx )}
354+ buckets_2d .update (local_buckets )
355+
356+ for bs , ctx in buckets_2d :
357+ for query in query_range :
358+ if all (bucket_filter (bs , query , ctx ) for bucket_filter in filters ):
359+ buckets .add ((bs , query , ctx ))
360+
319361 if not buckets :
320362 phase = 'prompt' if is_prompt else 'decode'
321363 for bucket in omitted_buckets :
0 commit comments