Skip to content

Commit c6eb22a

Browse files
authored
Handle ND EmbeddingLookup indices (google-ai-edge#117)
* Handle ND EmbeddingLookup indices
* Address comments
* Address comments
1 parent 6bb40ca commit c6eb22a

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
import copy
1717
import functools
18+
from functools import reduce
1819
from typing import Any, Callable
1920

2021
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
@@ -228,25 +229,25 @@ def embedding(*args, **kwargs):
228229
full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
229230
_, embedding_dim = full_kwargs["weight"].size()
230231
idx = full_kwargs["indices"]
231-
# TODO(b/356458830): Handle relative positional encoding
232-
if len(idx.size()) == 2:
233-
idx = idx.type(torch.int)
234-
B, T = idx.size()
235-
236-
idx = torch.reshape(idx, (B * T,))
237-
238-
builder = StableHLOCompositeBuilder("odml.embedding_lookup")
239-
full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
240-
idx,
241-
full_kwargs["weight"],
242-
)
243-
output = op(**full_kwargs)
244-
output = builder.mark_outputs(output)
245-
246-
output = torch.reshape(output, (B, T, embedding_dim))
247-
return output
248-
else:
249-
return op(**full_kwargs)
232+
233+
# Explicitly cast to INT32. This places the CastOp outside of the HLFB.
234+
idx = idx.type(torch.int)
235+
original_idx_shape = idx.size()
236+
237+
# Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
238+
idx = torch.reshape(idx, (idx.numel(),))
239+
240+
builder = StableHLOCompositeBuilder("odml.embedding_lookup")
241+
full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
242+
idx,
243+
full_kwargs["weight"],
244+
)
245+
output = op(**full_kwargs)
246+
output = builder.mark_outputs(output)
247+
248+
# Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
249+
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
250+
return output
250251

251252
node.target = embedding
252253

0 commit comments

Comments (0)