Skip to content

Commit 248c6e4

Browse files
support model reuse by hashing
1 parent 298a319 commit 248c6e4

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,38 @@ def get_outputs_from_model(
183183
)
184184
return outputs
185185

186-
def ov_infer(inputs, struct_params, struct_outputs):
187-
parameters = [
188-
p.output.get_node() for p in tree.flatten(struct_params)
189-
]
190-
results = [
191-
ov_opset.result(r.output)
192-
for r in tree.flatten(struct_outputs)
193-
]
194-
195-
ov_model = ov.Model(results=results, parameters=parameters)
196-
compile_ov_model = ov.compile_model(ov_model, "CPU")
def ov_infer(inputs, stop_token_ids, fn):
    """Run `fn` through OpenVINO, reusing compiled models between calls.

    The graph is traced with `get_struct_outputs` and compiled for the
    "CPU" device. Compiled models are memoized on the function attribute
    `ov_infer._compiled_models` so that repeated calls with an identical
    input signature skip `ov.Model` construction and `ov.compile_model`.

    Args:
        inputs: dict mapping input names to array-like values exposing
            `.shape` (and normally `.dtype`).
        stop_token_ids: stop token ids forwarded to `get_struct_outputs`;
            may be `None`.
        fn: the step function forwarded to `get_struct_outputs`.

    Returns:
        Whatever `get_outputs_from_model` returns for the compiled model.
    """
    # Tracing still happens on every call: `struct_outputs` is required by
    # `get_outputs_from_model` below, even on a cache hit.
    struct_params, struct_outputs = get_struct_outputs(
        inputs, stop_token_ids, fn
    )

    compiled_models = getattr(ov_infer, "_compiled_models", None)
    if compiled_models is None:
        compiled_models = ov_infer._compiled_models = {}

    # Everything handed to `get_struct_outputs` that can vary per call goes
    # into the key: input names, shapes, dtypes, and the stop token ids.
    # NOTE(review): `fn` is not part of the key; this assumes every call
    # through this closure passes the same step function -- confirm.
    if stop_token_ids is None:
        stop_key = None
    else:
        try:
            stop_key = tuple(stop_token_ids)
        except TypeError:
            # A single scalar id rather than a sequence of ids.
            stop_key = (stop_token_ids,)

    inputs_signature = tuple(
        (name, tuple(value.shape), str(getattr(value, "dtype", None)))
        for name, value in inputs.items()
    )

    # Use the signature tuple itself as the dict key instead of
    # `hash(signature)`: the dict then resolves hash collisions by
    # equality, so two distinct signatures can never share a model.
    model_key = (inputs_signature, stop_key)

    if model_key not in compiled_models:
        parameters = [
            p.output.get_node() for p in tree.flatten(struct_params)
        ]
        results = [
            ov_opset.result(r.output)
            for r in tree.flatten(struct_outputs)
        ]

        ov_model = ov.Model(results=results, parameters=parameters)
        compiled_models[model_key] = ov.compile_model(ov_model, "CPU")

    compile_ov_model = compiled_models[model_key]
    return get_outputs_from_model(
        inputs, struct_outputs, compile_ov_model
    )
@@ -202,10 +223,7 @@ def wrapped_generate_function(inputs, stop_token_ids=None):
202223
for k, v in inputs.items():
203224
if isinstance(v, OpenVINOKerasTensor):
204225
inputs[k] = ops.convert_to_numpy(v)
205-
struct_params, struct_outputs = get_struct_outputs(
206-
inputs, stop_token_ids, self.generate_step
207-
)
208-
return ov_infer(inputs, struct_params, struct_outputs)
226+
return ov_infer(inputs, stop_token_ids, self.generate_step)
209227

210228
self.generate_function = wrapped_generate_function
211229
if keras.config.backend() == "torch":

0 commit comments

Comments
 (0)