Skip to content

Commit a57c986

Browse files
committed
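Support multi-head attention (MHA) layers defined in (Q, V) format, where the key tensor is omitted and defaults to the value tensor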
1 parent 72538ad commit a57c986

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

hls4ml/converters/keras_v3/hgq2/multi_head_attention.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,13 @@ def handle(
2626
from hgq.layers import QEinsum
2727
from keras import KerasTensor
2828

29-
assert len(in_tensors) in (3, 4), 'MultiHead layer must have 3 (Q, K, V) or 4 (Q, K, V, M) input tensors'
29+
# fmt: off
30+
assert len(in_tensors) in (2, 3, 4,), (
31+
'MultiHead layer must have 2 (Q, V), 3 (Q, V, K) or 4 (Q, V, K, M) input tensors'
32+
)
33+
# fmt: on
3034
assert len(out_tensors) == 1, 'Attention score output is not supported yet'
31-
assert len(in_tensors) == 3, 'Mask tensor is not supported yet'
35+
assert len(in_tensors) <= 3, 'Mask tensor is not supported yet'
3236
tensor_q, *_ = in_tensors
3337
tensor_O, *tensor_attn = out_tensors
3438
unique_name: str = layer.name
@@ -50,6 +54,8 @@ def handle(
5054
tensor_q = bound.arguments['query']
5155
tensor_k = bound.arguments['key']
5256
tensor_v = bound.arguments['value']
57+
if tensor_k is None:
58+
tensor_k = tensor_v
5359
tensor_q_mask = bound.arguments['query_mask']
5460
tensor_k_mask = bound.arguments['key_mask']
5561
tensor_v_mask = bound.arguments['value_mask']

test/pytest/test_hgq2_mha.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from hgq.config import QuantizerConfigScope
1111
from hgq.layers import QMultiHeadAttention
12-
from hgq.utils import trace_mode
12+
from hgq.utils import trace_minmax
1313

1414
from hls4ml.converters import convert_from_keras_model
1515

@@ -30,8 +30,7 @@ def test_hgq2_mha(strategy):
3030
data_k = np.random.randn(10000, 12, 7).astype(np.float32) * 3
3131
data = [data_q, data_v, data_k]
3232

33-
with trace_mode(model):
34-
r_keras = model.predict(data, batch_size=1000)
33+
r_keras = trace_minmax(model, data, return_results=True)
3534

3635
model_hls = convert_from_keras_model(
3736
model,

0 commit comments

Comments
 (0)