@@ -183,17 +183,38 @@ def get_outputs_from_model(
183
183
)
184
184
return outputs
185
185
186
- def ov_infer (inputs , struct_params , struct_outputs ):
187
- parameters = [
188
- p .output .get_node () for p in tree .flatten (struct_params )
189
- ]
190
- results = [
191
- ov_opset .result (r .output )
192
- for r in tree .flatten (struct_outputs )
193
- ]
194
-
195
- ov_model = ov .Model (results = results , parameters = parameters )
196
- compile_ov_model = ov .compile_model (ov_model , "CPU" )
186
def ov_infer(inputs, stop_token_ids, fn):
    """Run `fn` on `inputs` through a cached, compiled OpenVINO model.

    The graph is traced with `get_struct_outputs` on every call, because
    the traced output structure is needed to read the results back. The
    compilation for the "CPU" device happens only the first time a given
    cache key is seen; the compiled model is then stored on the function
    object in `ov_infer._compiled_models`.

    Args:
        inputs: Dict mapping input names to values. Every value must
            expose a `.shape` attribute.
        stop_token_ids: Forwarded unchanged to `get_struct_outputs`.
        fn: Callable forwarded to `get_struct_outputs` to be traced.

    Returns:
        The result of `get_outputs_from_model` for `inputs`, the traced
        output structure and the compiled model.
    """
    struct_params, struct_outputs = get_struct_outputs(
        inputs, stop_token_ids, fn
    )

    # The cache is kept as an attribute of this function object so that
    # it persists between calls without extra state on the caller side.
    if not hasattr(ov_infer, "_compiled_models"):
        ov_infer._compiled_models = {}
    compiled_models = ov_infer._compiled_models

    # Key on everything that is handed to the tracer: the traced callable,
    # the stop tokens, and the name / shape / dtype of every input. The
    # tuple itself is the dict key (not `hash(...)` of it), so two
    # different signatures can never collide onto one compiled model.
    # NOTE(review): the key does not capture model weights; confirm that
    # weights baked into the traced graph cannot change between calls.
    # NOTE(review): the cache is unbounded; presumably input shapes are
    # padded to a fixed length by the caller -- verify.
    cache_key = (
        # Bound methods are re-created on each attribute access; the
        # underlying function is the stable identity.
        getattr(fn, "__func__", fn),
        # `repr` keeps the key hashable whatever container type is used.
        repr(stop_token_ids),
        tuple(
            (name, str(value.shape), str(getattr(value, "dtype", None)))
            for name, value in inputs.items()
        ),
    )

    compiled_model = compiled_models.get(cache_key)
    if compiled_model is None:
        parameters = [
            p.output.get_node() for p in tree.flatten(struct_params)
        ]
        results = [
            ov_opset.result(r.output) for r in tree.flatten(struct_outputs)
        ]
        ov_model = ov.Model(results=results, parameters=parameters)
        compiled_model = ov.compile_model(ov_model, "CPU")
        compiled_models[cache_key] = compiled_model

    return get_outputs_from_model(inputs, struct_outputs, compiled_model)
@@ -202,10 +223,7 @@ def wrapped_generate_function(inputs, stop_token_ids=None):
202
223
for k , v in inputs .items ():
203
224
if isinstance (v , OpenVINOKerasTensor ):
204
225
inputs [k ] = ops .convert_to_numpy (v )
205
- struct_params , struct_outputs = get_struct_outputs (
206
- inputs , stop_token_ids , self .generate_step
207
- )
208
- return ov_infer (inputs , struct_params , struct_outputs )
226
+ return ov_infer (inputs , stop_token_ids , self .generate_step )
209
227
210
228
self .generate_function = wrapped_generate_function
211
229
if keras .config .backend () == "torch" :
0 commit comments