add linformer parser #1360

Open · wants to merge 1 commit into base: main
79 changes: 75 additions & 4 deletions hls4ml/converters/keras_v3/hgq2/multi_head_attention.py
@@ -23,9 +23,6 @@ def handle(
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
from hgq.layers import QEinsum
from keras import KerasTensor

# fmt: off
assert len(in_tensors) in (2, 3, 4,), (
'MultiHead layer must have 2 (Q, V), 3 (Q, V, K) or 4 (Q, V, K, M) input tensors'
@@ -35,7 +32,6 @@ def handle(
assert len(in_tensors) <= 3, 'Mask tensor is not supported yet'
tensor_q, *_ = in_tensors
tensor_O, *tensor_attn = out_tensors
unique_name: str = layer.name

node_index: int = tensor_q._keras_history.node_index # type: ignore
assert all(
@@ -72,6 +68,12 @@ def handle(
)
assert n_mask_def <= 1, f'Layer {layer.name} has {n_mask_def} masks defined, expected at most 1'

return self._handle(layer, tensor_q, tensor_O, node_index, tensor_k, tensor_v)

def _handle(self, layer, tensor_q, tensor_O, node_index, tensor_k, tensor_v):
from hgq.layers import QEinsum
from keras import KerasTensor

unique_name = f'{layer.name}_{node_index}'
to_Q = layer.query_dense
to_K = layer.key_dense
@@ -123,3 +125,72 @@ def handle(
for conf in configs:
conf['name'] = f'{layer.name}_{conf["name"]}'
return configs


@register
class QLinformerAttentionHandler(QMultiHeadAttentionHandler):
handles = ('hgq.layers.linformer_attention.QLinformerAttention',)

def handle(
self,
layer: 'hgq.layers.linformer_attention.QLinformerAttention',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
from keras import KerasTensor

# fmt: off
assert len(in_tensors) in (2, 3, 4,), (
'MultiHead layer must have 2 (Q, V), 3 (Q, V, K) or 4 (Q, V, K, M) input tensors'
)
# fmt: on
assert len(out_tensors) == 1, 'Attention score output is not supported yet'
assert len(in_tensors) <= 3, 'Mask tensor is not supported yet'
tensor_q, *_ = in_tensors
tensor_O, *tensor_attn = out_tensors
unique_name: str = layer.name

node_index: int = tensor_q._keras_history.node_index # type: ignore
assert all(
[node_index == inp._keras_history.node_index for inp in layer.input[1:]]
), f'Critical error handling layer {layer.name}'
node = layer._inbound_nodes[node_index]

args = node.arguments.args
kwargs = node.arguments.kwargs
sig: Signature = layer._call_signature

# map everything to kwargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

tensor_q = bound.arguments['query']
tensor_k = bound.arguments['key']
tensor_v = bound.arguments['value']
if tensor_k is None:
tensor_k = tensor_v
tensor_q_mask = bound.arguments['query_mask']
tensor_k_mask = bound.arguments['key_mask']
tensor_v_mask = bound.arguments['value_mask']
tensor_attn_mask = bound.arguments['attention_mask']
return_scores = bound.arguments['return_attention_scores'] # noqa: F841

n_mask_def = np.sum(
[
tensor_q_mask is not None,
tensor_k_mask is not None,
tensor_v_mask is not None,
tensor_attn_mask is not None,
]
)
assert n_mask_def <= 1, f'Layer {layer.name} has {n_mask_def} masks defined, expected at most 1'

tensor_k_proj = KerasTensor(layer._key_shape_proj, name=f'{unique_name}_k_proj')
tensor_v_proj = KerasTensor(layer._value_shape_proj, name=f'{unique_name}_v_proj')
einsum_dense_handler = QEinsumDenseHandler()

k_proj = einsum_dense_handler(layer._lin_k_proj, [tensor_k], [tensor_k_proj])
v_proj = einsum_dense_handler(layer._lin_v_proj, [tensor_v], [tensor_v_proj])

layers = self._handle(layer, tensor_q, tensor_O, node_index, tensor_k_proj, tensor_v_proj)
return *k_proj, *v_proj, *layers
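
For context, the new handler mirrors the Linformer decomposition: the key and value tensors are first projected to a shorter sequence length by the layer's _lin_k_proj / _lin_v_proj einsum-dense layers, and the ordinary multi-head attention path (the reused _handle from QMultiHeadAttentionHandler) then runs on the projected tensors. A minimal NumPy sketch of that decomposition follows; the shapes and variable names are illustrative only and do not reflect the HGQ API.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative shapes: sequence length n, projected length k, head dim d.
n, k, d = 64, 16, 8
rng = np.random.default_rng(0)
Q = rng.normal(size=(n, d))
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))

# Linformer projections along the sequence axis
# (the role played by _lin_k_proj / _lin_v_proj in the handler).
E = rng.normal(size=(k, n))  # key projection
F = rng.normal(size=(k, n))  # value projection
K_proj = E @ K               # (k, d)
V_proj = F @ V               # (k, d)

# Standard scaled dot-product attention on the projected tensors,
# which is what the delegated multi-head attention path then computes.
scores = softmax(Q @ K_proj.T / np.sqrt(d))  # (n, k) instead of (n, n)
out = scores @ V_proj                        # (n, d)
print(out.shape)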