@@ -197,29 +197,35 @@ def call_with_cache(
             the decoding cache.
         """
 
+        use_openvino = keras.config.backend() == "openvino"
+
         def embed_and_scale_tokens(token_ids):
             x = self.backbone.token_embedding(token_ids)
             return x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
 
-        def make_apply_fn(layer):
-            def apply_transformer_layer(inputs):
-                x = inputs["x"]
-                current_cache = inputs["current_cache"]
-                index = inputs["cache_update_index"]
-                x, next_cache = layer(
-                    x, cache=current_cache, cache_update_index=index
+        def apply_transformer_layers(inputs):
+            x = inputs["x"]
+            cache = inputs["cache"]
+            cache_update_index = inputs["cache_update_index"]
+            caches = []
+            for i, transformer_layer in enumerate(
+                self.backbone.transformer_layers
+            ):
+                current_cache = cache[:, i, ...]
+                x, next_cache = transformer_layer(
+                    x,
+                    cache=current_cache,
+                    cache_update_index=cache_update_index,
                 )
-                return x, next_cache
+                caches.append(next_cache)
 
-            return apply_transformer_layer
+            cache = ops.stack(caches, axis=1)
+            return x, cache
 
-        def finalize_generation_step(inputs):
-            x = self.backbone.layer_norm(inputs["x"])
-            cache = ops.stack(inputs["caches"], axis=1)
+        def finalize_generation_step(x):
+            hidden_states = x = self.backbone.layer_norm(x)
             logits = self.backbone.token_embedding(x, reverse=True)
-            return logits, x, cache
-
-        use_openvino = keras.config.backend() == "openvino"
+            return logits, hidden_states
 
         if use_openvino:
             token_ids = ops.convert_to_numpy(token_ids)
@@ -233,56 +239,58 @@ def finalize_generation_step(inputs):
                 )
             else:
                 ov_cache = self._ov_mem.get("cache")
-                if ov_cache is not None and cache.shape == ov_cache.shape:
+                if ov_cache is not None and cache.shape == ov_cache.shape:
                     return None, self._ov_mem["hidden_states"], ov_cache
                 x = self.ov_infer(token_ids, embed_and_scale_tokens)
         else:
             x = embed_and_scale_tokens(token_ids)
 
-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-
-            inputs = {
-                "x": x,
-                "current_cache": current_cache,
-                "cache_update_index": cache_update_index,
-            }
-
-            apply_fn = make_apply_fn(transformer_layer)
-
-            if use_openvino:
-                if token_ids.shape[1] == 1:
-                    x, next_cache = self.ov_infer(
-                        inputs,
-                        apply_fn,
-                        disc=True,
-                        name=f"layer_{i}",
-                    )
-                else:
-                    x, next_cache = self.ov_infer(inputs, apply_fn)
+        if use_openvino:
+            if token_ids.shape[1] == 1:
+                x, cache = self.ov_infer(
+                    {
+                        "x": x,
+                        "cache": cache,
+                        "cache_update_index": cache_update_index,
+                    },
+                    apply_transformer_layers,
+                    cache=True,
+                    name="apply_transformer_layers",
+                )
             else:
-                x, next_cache = apply_fn(inputs)
-
-            caches.append(next_cache)
+                x, cache = self.ov_infer(
+                    {
+                        "x": x,
+                        "cache": cache,
+                        "cache_update_index": cache_update_index,
+                    },
+                    apply_transformer_layers,
+                )
+            self._ov_mem["cache"] = cache
+        else:
+            x, cache = apply_transformer_layers(
+                {
+                    "x": x,
+                    "cache": cache,
+                    "cache_update_index": cache_update_index,
+                }
+            )
 
-        inputs = {"x": x, "caches": caches}
         if use_openvino:
             if token_ids.shape[1] == 1:
-                logits, hidden_states, cache = self.ov_infer(
-                    inputs,
+                logits, hidden_states = self.ov_infer(
+                    x,
                     finalize_generation_step,
                     cache=True,
                     name="finalize_generation_step",
                 )
             else:
-                logits, hidden_states, cache = self.ov_infer(
-                    inputs, finalize_generation_step
+                logits, hidden_states = self.ov_infer(
+                    x, finalize_generation_step
                 )
-            self._ov_mem["cache"] = cache
             self._ov_mem["hidden_states"] = hidden_states
         else:
-            logits, hidden_states, cache = finalize_generation_step(inputs)
+            logits, hidden_states = finalize_generation_step(x)
 
         return logits, hidden_states, cache
 
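Note on the refactor (a sketch, not part of the diff): the per-layer caches now travel as a single tensor with a layer axis. apply_transformer_layers slices out each layer's cache with cache[:, i, ...] and restacks the updated slices with ops.stack(caches, axis=1) before returning, so one function (and one OpenVINO graph) covers all layers instead of one per layer. A minimal, self-contained illustration of that layout using keras.ops; the "+ 1.0" update is a hypothetical stand-in for the real attention-cache update, not the model's logic:

from keras import ops

batch, num_layers, seq_len, dim = 2, 4, 8, 16
# Combined cache: one slice per transformer layer, stacked on axis 1.
cache = ops.zeros((batch, num_layers, seq_len, dim))

caches = []
for i in range(num_layers):
    current_cache = cache[:, i, ...]   # slice handed to layer i
    next_cache = current_cache + 1.0   # stand-in for the layer's cache update
    caches.append(next_cache)

cache = ops.stack(caches, axis=1)      # same (batch, num_layers, ...) layout
print(ops.shape(cache))                # (2, 4, 8, 16)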