Skip to content

Commit 2044408

Browse files
remove disc mechanism
1 parent 692ae90 commit 2044408

File tree

2 files changed

+63
-89
lines changed

2 files changed

+63
-89
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,6 @@ def make_generate_function(self):
138138

139139
self.generate_function = self.generate_step
140140
if keras.config.backend() == "openvino":
141-
import os
142-
import shutil
143-
144141
import numpy as np
145142
import openvino as ov
146143
import openvino.runtime.opset14 as ov_opset
@@ -192,17 +189,13 @@ def get_outputs_from_model(inputs, model):
192189
return outputs
193190

194191
def get_model(inputs, fn, ov_model=None, compiled=False):
195-
config = {
196-
"CACHE_DIR": "openvino_cache",
197-
}
198-
199192
struct_params, _ = set_struct_outputs(inputs, fn)
200193

201194
if ov_model is not None:
202195
assert compiled, (
203196
"if you pass a model, you should make compiled=True"
204197
)
205-
return ov.compile_model(ov_model, "CPU", config)
198+
return ov.compile_model(ov_model, "CPU")
206199

207200
parameters = [
208201
p.output.get_node() for p in tree.flatten(struct_params)
@@ -216,21 +209,12 @@ def get_model(inputs, fn, ov_model=None, compiled=False):
216209
if not compiled:
217210
return ov_model
218211

219-
return ov.compile_model(ov_model, "CPU", config)
220-
221-
def compile_model_disc(inputs, fn, name):
222-
model_path = f"./run_dir/{name}.xml"
223-
if not os.path.exists(model_path):
224-
ov_model = get_model(inputs, fn)
225-
ov.save_model(ov_model, model_path)
226-
model = ov.Core().read_model(model_path)
227-
return get_model(inputs, fn, ov_model=model, compiled=True)
212+
return ov.compile_model(ov_model, "CPU")
228213

229214
def ov_infer(
230215
inputs,
231216
fn,
232217
cache=False,
233-
disc=False,
234218
name=None,
235219
):
236220
compiled_model = None
@@ -245,34 +229,19 @@ def ov_infer(
245229
else:
246230
set_struct_outputs(inputs, fn)
247231
compiled_model = self._ov_mem[name]
248-
elif disc:
249-
assert name is not None, (
250-
                "you should provide the name of the model"
251-
)
252-
compiled_model = compile_model_disc(inputs, fn, name)
253232
else:
254233
compiled_model = get_model(inputs, fn, compiled=True)
255234
outputs = get_outputs_from_model(inputs, compiled_model)
256235
del compiled_model
257236
return outputs
258237

259-
def delete_ov_cache():
260-
for path in ["openvino_cache", "run_dir"]:
261-
if os.path.exists(path):
262-
shutil.rmtree(path, ignore_errors=True)
263-
264238
self.ov_infer = ov_infer
265239

266240
def wrapped_generate_function(inputs, stop_token_ids=None):
267-
final_outputs = []
268-
os.makedirs("./run_dir", exist_ok=True)
269-
for input in inputs:
270-
outputs = self.generate_step(input, stop_token_ids)
271-
for k, v in outputs.items():
272-
outputs[k] = ops.convert_to_numpy(v)
273-
final_outputs.append(outputs)
274-
delete_ov_cache()
275-
return final_outputs
241+
outputs = self.generate_step(inputs, stop_token_ids)
242+
for k, v in outputs.items():
243+
outputs[k] = ops.convert_to_numpy(v)
244+
return outputs
276245

277246
self.generate_function = wrapped_generate_function
278247
if keras.config.backend() == "torch":
@@ -529,10 +498,7 @@ def postprocess(x):
529498
if strip_prompt:
530499
outputs = [strip_prompt_function(generate(x), x) for x in inputs]
531500
else:
532-
if keras.config.backend() == "openvino":
533-
outputs = generate(inputs)
534-
else:
535-
outputs = [generate(x) for x in inputs]
501+
outputs = [generate(x) for x in inputs]
536502

537503
if self.preprocessor is not None:
538504
outputs = [postprocess(x) for x in outputs]

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -197,29 +197,35 @@ def call_with_cache(
197197
the decoding cache.
198198
"""
199199

200+
use_openvino = keras.config.backend() == "openvino"
201+
200202
def embed_and_scale_tokens(token_ids):
201203
x = self.backbone.token_embedding(token_ids)
202204
return x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
203205

204-
def make_apply_fn(layer):
205-
def apply_transformer_layer(inputs):
206-
x = inputs["x"]
207-
current_cache = inputs["current_cache"]
208-
index = inputs["cache_update_index"]
209-
x, next_cache = layer(
210-
x, cache=current_cache, cache_update_index=index
206+
def apply_transformer_layers(inputs):
207+
x = inputs["x"]
208+
cache = inputs["cache"]
209+
cache_update_index = inputs["cache_update_index"]
210+
caches = []
211+
for i, transformer_layer in enumerate(
212+
self.backbone.transformer_layers
213+
):
214+
current_cache = cache[:, i, ...]
215+
x, next_cache = transformer_layer(
216+
x,
217+
cache=current_cache,
218+
cache_update_index=cache_update_index,
211219
)
212-
return x, next_cache
220+
caches.append(next_cache)
213221

214-
return apply_transformer_layer
222+
cache = ops.stack(caches, axis=1)
223+
return x, cache
215224

216-
def finalize_generation_step(inputs):
217-
x = self.backbone.layer_norm(inputs["x"])
218-
cache = ops.stack(inputs["caches"], axis=1)
225+
def finalize_generation_step(x):
226+
hidden_states = x = self.backbone.layer_norm(x)
219227
logits = self.backbone.token_embedding(x, reverse=True)
220-
return logits, x, cache
221-
222-
use_openvino = keras.config.backend() == "openvino"
228+
return logits, hidden_states
223229

224230
if use_openvino:
225231
token_ids = ops.convert_to_numpy(token_ids)
@@ -233,56 +239,58 @@ def finalize_generation_step(inputs):
233239
)
234240
else:
235241
ov_cache = self._ov_mem.get("cache")
236-
if ov_cache is not None and cache.shape == ov_cache.shape:
242+
if ov_cache is not None and cache.shape == ov_cache.shape:
237243
return None, self._ov_mem["hidden_states"], ov_cache
238244
x = self.ov_infer(token_ids, embed_and_scale_tokens)
239245
else:
240246
x = embed_and_scale_tokens(token_ids)
241247

242-
caches = []
243-
for i, transformer_layer in enumerate(self.backbone.transformer_layers):
244-
current_cache = cache[:, i, ...]
245-
246-
inputs = {
247-
"x": x,
248-
"current_cache": current_cache,
249-
"cache_update_index": cache_update_index,
250-
}
251-
252-
apply_fn = make_apply_fn(transformer_layer)
253-
254-
if use_openvino:
255-
if token_ids.shape[1] == 1:
256-
x, next_cache = self.ov_infer(
257-
inputs,
258-
apply_fn,
259-
disc=True,
260-
name=f"layer_{i}",
261-
)
262-
else:
263-
x, next_cache = self.ov_infer(inputs, apply_fn)
248+
if use_openvino:
249+
if token_ids.shape[1] == 1:
250+
x, cache = self.ov_infer(
251+
{
252+
"x": x,
253+
"cache": cache,
254+
"cache_update_index": cache_update_index,
255+
},
256+
apply_transformer_layers,
257+
cache=True,
258+
name="apply_transformer_layers",
259+
)
264260
else:
265-
x, next_cache = apply_fn(inputs)
266-
267-
caches.append(next_cache)
261+
x, cache = self.ov_infer(
262+
{
263+
"x": x,
264+
"cache": cache,
265+
"cache_update_index": cache_update_index,
266+
},
267+
apply_transformer_layers,
268+
)
269+
self._ov_mem["cache"] = cache
270+
else:
271+
x, cache = apply_transformer_layers(
272+
{
273+
"x": x,
274+
"cache": cache,
275+
"cache_update_index": cache_update_index,
276+
}
277+
)
268278

269-
inputs = {"x": x, "caches": caches}
270279
if use_openvino:
271280
if token_ids.shape[1] == 1:
272-
logits, hidden_states, cache = self.ov_infer(
273-
inputs,
281+
logits, hidden_states = self.ov_infer(
282+
x,
274283
finalize_generation_step,
275284
cache=True,
276285
name="finalize_generation_step",
277286
)
278287
else:
279-
logits, hidden_states, cache = self.ov_infer(
280-
inputs, finalize_generation_step
288+
logits, hidden_states = self.ov_infer(
289+
x, finalize_generation_step
281290
)
282-
self._ov_mem["cache"] = cache
283291
self._ov_mem["hidden_states"] = hidden_states
284292
else:
285-
logits, hidden_states, cache = finalize_generation_step(inputs)
293+
logits, hidden_states = finalize_generation_step(x)
286294

287295
return logits, hidden_states, cache
288296

0 commit comments

Comments
 (0)