 15 |  15 |
 16 |  16 |   import copy
 17 |  17 |   import functools
    |  18 | + from functools import reduce
 18 |  19 |   from typing import Any, Callable
 19 |  20 |
 20 |  21 |   from ai_edge_torch.hlfb import StableHLOCompositeBuilder

@@ -228,25 +229,25 @@ def embedding(*args, **kwargs):
228 | 229 |       full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
229 | 230 |       _, embedding_dim = full_kwargs["weight"].size()
230 | 231 |       idx = full_kwargs["indices"]
231 |     | -     # TODO(b/356458830): Handle relative positional encoding
232 |     | -     if len(idx.size()) == 2:
233 |     | -       idx = idx.type(torch.int)
234 |     | -       B, T = idx.size()
235 |     | -
236 |     | -       idx = torch.reshape(idx, (B * T,))
237 |     | -
238 |     | -       builder = StableHLOCompositeBuilder("odml.embedding_lookup")
239 |     | -       full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
240 |     | -           idx,
241 |     | -           full_kwargs["weight"],
242 |     | -       )
243 |     | -       output = op(**full_kwargs)
244 |     | -       output = builder.mark_outputs(output)
245 |     | -
246 |     | -       output = torch.reshape(output, (B, T, embedding_dim))
247 |     | -       return output
248 |     | -     else:
249 |     | -       return op(**full_kwargs)
    | 232 | +
    | 233 | +     # Explicitly cast to INT32. This places the CastOp outside of the HLFB.
    | 234 | +     idx = idx.type(torch.int)
    | 235 | +     original_idx_shape = idx.size()
    | 236 | +
    | 237 | +     # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
    | 238 | +     idx = torch.reshape(idx, (idx.numel(),))
    | 239 | +
    | 240 | +     builder = StableHLOCompositeBuilder("odml.embedding_lookup")
    | 241 | +     full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
    | 242 | +         idx,
    | 243 | +         full_kwargs["weight"],
    | 244 | +     )
    | 245 | +     output = op(**full_kwargs)
    | 246 | +     output = builder.mark_outputs(output)
    | 247 | +
    | 248 | +     # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
    | 249 | +     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
    | 250 | +     return output
250 | 251 |
251 | 252 |     node.target = embedding
252 | 253 |
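For readability, the rewritten handler assembled from the hunk above is sketched below. This is only an illustrative reconstruction: `op` and `args_mapper` are assumed to be bound in the enclosing lowering pass, as the context lines suggest, and the function is intended to run under torch.export tracing as an fx node target rather than as standalone eager code.

def embedding(*args, **kwargs):
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
  _, embedding_dim = full_kwargs["weight"].size()
  idx = full_kwargs["indices"]

  # Cast to INT32 and flatten to 1D before marking composite inputs, so the
  # resulting CastOp and ReshapeOp stay outside of the HLFB.
  idx = idx.type(torch.int)
  original_idx_shape = idx.size()
  idx = torch.reshape(idx, (idx.numel(),))

  # Only the lookup between mark_inputs and mark_outputs is wrapped into the
  # odml.embedding_lookup composite.
  builder = StableHLOCompositeBuilder("odml.embedding_lookup")
  full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
      idx,
      full_kwargs["weight"],
  )
  output = op(**full_kwargs)
  output = builder.mark_outputs(output)

  # Reshape back outside the composite: (B, T) indices yield a
  # (B, T, embedding_dim) result, matching the removed rank-2 path.
  output = torch.reshape(output, (*original_idx_shape, embedding_dim))
  return output

Unlike the removed branch, which built the composite only for rank-2 (B, T) index tensors and otherwise fell back to a plain op(**full_kwargs) call, this version flattens indices of any rank via numel() and restores the original shape afterwards, so every embedding lookup routed through this handler is captured as an odml.embedding_lookup composite.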