
Commit 298a319

remove unwanted code
1 parent 2044408 commit 298a319

6 files changed: +54 -219 lines

keras_hub/src/models/causal_lm.py

Lines changed: 21 additions & 57 deletions
@@ -58,11 +58,6 @@ class CausalLM(Task):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # only OpenVINO needs these declarations
-        if keras.config.backend() == "openvino":
-            self._ov_mem = {}
-            self.struct_outputs = None
-            self.ov_infer = None
 
     def compile(
         self,
@@ -170,78 +165,47 @@ def parameterize_inputs(inputs):
                 else:
                     raise TypeError(f"Unknown input type: {type(inputs)}")
 
-            def set_struct_outputs(inputs, fn):
+            def get_struct_outputs(inputs, stop_token_ids, fn):
                 struct_params = parameterize_inputs(inputs)
-                self.struct_outputs = fn(struct_params)
-                return struct_params, self.struct_outputs
+                struct_outputs = fn(struct_params, stop_token_ids)
+                return struct_params, struct_outputs
 
-            def get_outputs_from_model(inputs, model):
+            def get_outputs_from_model(
+                inputs, struct_outputs, compile_ov_model
+            ):
                 flatten_inputs = tree.flatten(inputs)
                 assert OpenVINOKerasTensor not in inputs, (
                     "inputs should be numpy arrays"
                 )
-                outputs = model(flatten_inputs)
+                outputs = compile_ov_model(flatten_inputs)
                 outputs = unpack_singleton(
-                    tree.pack_sequence_as(
-                        self.struct_outputs, outputs.to_tuple()
-                    )
+                    tree.pack_sequence_as(struct_outputs, outputs.to_tuple())
                 )
                 return outputs
 
-            def get_model(inputs, fn, ov_model=None, compiled=False):
-                struct_params, _ = set_struct_outputs(inputs, fn)
-
-                if ov_model is not None:
-                    assert compiled, (
-                        "if you pass a model, you should make compiled=True"
-                    )
-                    return ov.compile_model(ov_model, "CPU")
-
+            def ov_infer(inputs, struct_params, struct_outputs):
                 parameters = [
                     p.output.get_node() for p in tree.flatten(struct_params)
                 ]
                 results = [
                     ov_opset.result(r.output)
-                    for r in tree.flatten(self.struct_outputs)
+                    for r in tree.flatten(struct_outputs)
                 ]
 
                 ov_model = ov.Model(results=results, parameters=parameters)
-                if not compiled:
-                    return ov_model
-
-                return ov.compile_model(ov_model, "CPU")
-
-            def ov_infer(
-                inputs,
-                fn,
-                cache=False,
-                name=None,
-            ):
-                compiled_model = None
-                if cache:
-                    assert name is not None, (
-                        "you should provide name of the model being cached"
-                    )
-                    if self._ov_mem.get(name) is None:
-                        self._ov_mem[name] = get_model(
-                            inputs, fn, compiled=True
-                        )
-                    else:
-                        set_struct_outputs(inputs, fn)
-                    compiled_model = self._ov_mem[name]
-                else:
-                    compiled_model = get_model(inputs, fn, compiled=True)
-                outputs = get_outputs_from_model(inputs, compiled_model)
-                del compiled_model
-                return outputs
-
-            self.ov_infer = ov_infer
+                compile_ov_model = ov.compile_model(ov_model, "CPU")
+                return get_outputs_from_model(
+                    inputs, struct_outputs, compile_ov_model
+                )
 
             def wrapped_generate_function(inputs, stop_token_ids=None):
-                outputs = self.generate_step(inputs, stop_token_ids)
-                for k, v in outputs.items():
-                    outputs[k] = ops.convert_to_numpy(v)
-                return outputs
+                for k, v in inputs.items():
+                    if isinstance(v, OpenVINOKerasTensor):
+                        inputs[k] = ops.convert_to_numpy(v)
+                struct_params, struct_outputs = get_struct_outputs(
+                    inputs, stop_token_ids, self.generate_step
+                )
+                return ov_infer(inputs, struct_params, struct_outputs)
 
             self.generate_function = wrapped_generate_function
         if keras.config.backend() == "torch":
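
For orientation, the rewritten helpers boil down to a single trace-build-compile-run pass over OpenVINO's graph API on each generate call, with no per-model caching between calls. Below is a minimal, self-contained sketch of that pattern. It is our illustration, not keras_hub code, and assumes the same openvino package and opset14 module the Keras OpenVINO backend uses:

    import numpy as np
    import openvino as ov
    import openvino.runtime.opset14 as ov_opset

    # Trace a trivial graph: one parameter node feeding one result node.
    param = ov_opset.parameter(shape=[1, 4], dtype=np.float32)
    doubled = ov_opset.multiply(param, ov_opset.constant(np.float32(2.0)))
    result = ov_opset.result(doubled.output(0))

    # Build and compile for CPU, as ov_infer does with the traced generate_step.
    model = ov.Model(results=[result], parameters=[param])
    compiled = ov.compile_model(model, "CPU")

    # Run on flat numpy inputs; to_tuple() recovers the positional outputs.
    outputs = compiled([np.ones((1, 4), dtype=np.float32)])
    print(outputs.to_tuple()[0])  # [[2. 2. 2. 2.]]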

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 17 additions & 94 deletions
@@ -196,102 +196,22 @@ def call_with_cache(
             the final hidden representation of the input tokens, and `cache` is
             the decoding cache.
         """
-
-        use_openvino = keras.config.backend() == "openvino"
-
-        def embed_and_scale_tokens(token_ids):
-            x = self.backbone.token_embedding(token_ids)
-            return x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
-
-        def apply_transformer_layers(inputs):
-            x = inputs["x"]
-            cache = inputs["cache"]
-            cache_update_index = inputs["cache_update_index"]
-            caches = []
-            for i, transformer_layer in enumerate(
-                self.backbone.transformer_layers
-            ):
-                current_cache = cache[:, i, ...]
-                x, next_cache = transformer_layer(
-                    x,
-                    cache=current_cache,
-                    cache_update_index=cache_update_index,
-                )
-                caches.append(next_cache)
-
-            cache = ops.stack(caches, axis=1)
-            return x, cache
-
-        def finalize_generation_step(x):
-            hidden_states = x = self.backbone.layer_norm(x)
-            logits = self.backbone.token_embedding(x, reverse=True)
-            return logits, hidden_states
-
-        if use_openvino:
-            token_ids = ops.convert_to_numpy(token_ids)
-            cache = ops.convert_to_numpy(cache)
-            if token_ids.shape[1] == 1:
-                x = self.ov_infer(
-                    token_ids,
-                    embed_and_scale_tokens,
-                    cache=True,
-                    name="embed_and_scale_tokens",
-                )
-            else:
-                ov_cache = self._ov_mem.get("cache")
-                if ov_cache is not None and cache.shape == ov_cache.shape:
-                    return None, self._ov_mem["hidden_states"], ov_cache
-                x = self.ov_infer(token_ids, embed_and_scale_tokens)
-        else:
-            x = embed_and_scale_tokens(token_ids)
-
-        if use_openvino:
-            if token_ids.shape[1] == 1:
-                x, cache = self.ov_infer(
-                    {
-                        "x": x,
-                        "cache": cache,
-                        "cache_update_index": cache_update_index,
-                    },
-                    apply_transformer_layers,
-                    cache=True,
-                    name="apply_transformer_layers",
-                )
-            else:
-                x, cache = self.ov_infer(
-                    {
-                        "x": x,
-                        "cache": cache,
-                        "cache_update_index": cache_update_index,
-                    },
-                    apply_transformer_layers,
-                )
-            self._ov_mem["cache"] = cache
-        else:
-            x, cache = apply_transformer_layers(
-                {
-                    "x": x,
-                    "cache": cache,
-                    "cache_update_index": cache_update_index,
-                }
+        x = self.backbone.token_embedding(token_ids)
+        x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            current_cache = cache[:, i, ...]
+            x, next_cache = transformer_layer(
+                x,
+                cache=current_cache,
+                cache_update_index=cache_update_index,
             )
+            caches.append(next_cache)
 
-        if use_openvino:
-            if token_ids.shape[1] == 1:
-                logits, hidden_states = self.ov_infer(
-                    x,
-                    finalize_generation_step,
-                    cache=True,
-                    name="finalize_generation_step",
-                )
-            else:
-                logits, hidden_states = self.ov_infer(
-                    x, finalize_generation_step
-                )
-            self._ov_mem["hidden_states"] = hidden_states
-        else:
-            logits, hidden_states = finalize_generation_step(x)
-
+        cache = ops.stack(caches, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
         return logits, hidden_states, cache
 
     def _build_cache(self, token_ids):
@@ -338,6 +258,9 @@ def next(prompt, cache, index):
             cache_update_index = index - 1
             batch_size = ops.shape(prompt)[0]
             prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            if keras.config.backend() == "openvino":
+                # Avoid the dynamic output shape of OpenVINO's slice.
+                prompt = ops.reshape(prompt, [batch_size, 1])
             logits, hidden_states, cache = self.call_with_cache(
                 prompt,
                 cache,
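
The three-line addition to next() is the commit's only functional change to this file's generation loop. Under the OpenVINO backend, ops.slice with a traced start index can report a dynamic output shape, which does not fit the statically shaped graph compiled for token-by-token decoding; the reshape is numerically a no-op that pins the shape back to [batch_size, 1]. A hedged sketch of the pattern in isolation (the helper name is ours):

    from keras import ops

    def take_next_token(prompt, index):
        # Slice out the single token at a traced position, as next() does.
        batch_size = ops.shape(prompt)[0]
        token = ops.slice(prompt, [0, index - 1], [batch_size, 1])
        # A no-op numerically, but it replaces the dynamic dimension that
        # OpenVINO's slice can report with the expected [batch_size, 1] shape.
        return ops.reshape(token, [batch_size, 1])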

keras_hub/src/models/gemma/gemma_causal_lm_test.py

Lines changed: 0 additions & 30 deletions
@@ -64,10 +64,6 @@ def test_causal_lm_basics(self):
             expected_output_shape=(2, 8, 11),
         )
 
-    @pytest.mark.skipif(
-        keras.config.backend() == "openvino",
-        reason="OpenVINO is for inference only",
-    )
     def test_cache_correctness(self):
         token_ids = self.input_data["token_ids"]
         padding_mask = ops.ones_like(self.input_data["padding_mask"])
@@ -97,9 +93,6 @@ def test_generate(self):
         causal_lm.preprocessor = None
         outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
-        if keras.config.backend() == "openvino":
-            for k, v in prompt_ids.items():
-                prompt_ids[k] = ops.convert_to_numpy(v)
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
             prompt_ids["token_ids"][:, :4],
@@ -139,9 +132,6 @@ def test_generate_with_bfloat16(self):
         causal_lm.preprocessor = None
         outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
         # Assert prompt is in output in token id space.
-        if keras.config.backend() == "openvino":
-            for k, v in prompt_ids.items():
-                prompt_ids[k] = ops.convert_to_numpy(v)
         self.assertAllEqual(
             outputs["token_ids"][:, :4],
             prompt_ids["token_ids"][:, :4],
@@ -163,12 +153,6 @@ def wrapper(*args, **kwargs):
             """Modify output logits to always favor end_token_id"""
             logits, hidden_states, cache = call_with_cache(*args, **kwargs)
             index = self.preprocessor.tokenizer.end_token_id
-            if keras.config.backend() == "openvino":
-                """Set all logits to a large negative number
-                to avoid NaNs produced by ov.einsum"""
-                logits = ops.ones_like(logits) * ops.convert_to_tensor(
-                    -1e9, dtype=logits.dtype
-                )
             update = ops.ones_like(logits)[:, :, index] * 1.0e9
             update = ops.expand_dims(update, axis=-1)
             logits = ops.slice_update(logits, (0, 0, index), update)
@@ -188,12 +172,6 @@ def wrapper(*args, **kwargs):
             """Modify output logits to always favor end_token_id"""
             logits, hidden_states, cache = call_with_cache(*args, **kwargs)
             index = self.preprocessor.tokenizer.end_token_id
-            if keras.config.backend() == "openvino":
-                """Set all logits to a large negative number
-                to avoid NaNs produced by ov.einsum"""
-                logits = ops.ones_like(logits) * ops.convert_to_tensor(
-                    -1e9, dtype=logits.dtype
-                )
             update = ops.ones_like(logits)[:, :, index] * 1.0e9
             update = ops.expand_dims(update, axis=-1)
             logits = ops.slice_update(logits, (0, 0, index), update)
@@ -237,10 +215,6 @@ def test_all_presets(self):
             input_data=self.input_data,
         )
 
-    @pytest.mark.skipif(
-        keras.config.backend() == "openvino",
-        reason="OpenVINO is for inference only",
-    )
     def test_score_logits(self):
         # Setup prompts, models, and associated expected shapes.
         prompts = ["the quick brown fox", "the quick brown fox"]
@@ -263,10 +237,6 @@ def test_score_logits(self):
 
         self.assertEqual(ops.shape(scores), expected_score_shape)
 
-    @pytest.mark.skipif(
-        keras.config.backend() == "openvino",
-        reason="OpenVINO is for inference only",
-    )
     def test_score_loss(self):
         # Setup prompts, models, and associated expected shapes.
         prompts = ["the quick brown fox", "the quick brown fox"]

keras_hub/src/models/mistral/mistral_causal_lm.py

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ def next(prompt, cache, index):
             cache_update_index = index - 1
             batch_size = ops.shape(prompt)[0]
             prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            if keras.config.backend() == "openvino":
+                # Avoid the dynamic output shape of OpenVINO's slice.
+                prompt = ops.reshape(prompt, [batch_size, 1])
             logits, hidden_states, cache = self.call_with_cache(
                 prompt,
                 cache,
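
Taken together, the commit leaves OpenVINO generation stateless: each generate call converts inputs to numpy, traces generate_step, compiles, and runs. A hedged usage sketch follows; the preset name is illustrative and assumes the weights are available, and the backend must be selected before keras is imported:

    import os

    # Select the backend before keras (or keras_hub) is imported.
    os.environ["KERAS_BACKEND"] = "openvino"

    import keras_hub

    causal_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
    print(causal_lm.generate("the quick brown fox", max_length=30))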
