Commit d5d1dd6

Disable embedding composite for non-2D embedding tables (google-ai-edge#113)
1 parent: 1fad116

1 file changed

ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py

Lines changed: 19 additions & 15 deletions

@@ -223,21 +223,25 @@ def embedding(*args, **kwargs):
     full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
     _, embedding_dim = full_kwargs["weight"].size()
     idx = full_kwargs["indices"]
-    idx = idx.type(torch.int)
-    B, T = idx.size()
-
-    idx = torch.reshape(idx, (B * T,))
-
-    builder = StableHLOCompositeBuilder("odml.embedding_lookup")
-    full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
-        idx,
-        full_kwargs["weight"],
-    )
-    output = op(**full_kwargs)
-    output = builder.mark_outputs(output)
-
-    output = torch.reshape(output, (B, T, embedding_dim))
-    return output
+    # TODO(b/356458830): Handle relative positional encoding
+    if len(idx.size()) == 2:
+      idx = idx.type(torch.int)
+      B, T = idx.size()
+
+      idx = torch.reshape(idx, (B * T,))
+
+      builder = StableHLOCompositeBuilder("odml.embedding_lookup")
+      full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
+          idx,
+          full_kwargs["weight"],
+      )
+      output = op(**full_kwargs)
+      output = builder.mark_outputs(output)
+
+      output = torch.reshape(output, (B, T, embedding_dim))
+      return output
+    else:
+      return op(**full_kwargs)
 
   node.target = embedding
 
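For context: the composite path above assumes rank-2 indices, since it unpacks B, T = idx.size() before flattening for the odml.embedding_lookup composite; indices of any other rank would fail to unpack, hence the new fallback to the plain op. A minimal standalone sketch of the two cases (plain PyTorch with illustrative, assumed shapes; not the pass itself):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)  # assumed: vocab_size=10, embedding_dim=4

# Rank-2 indices: the flatten/lookup/reshape pattern from the pass applies.
idx_2d = torch.randint(0, 10, (2, 3))
B, T = idx_2d.size()
flat = torch.reshape(idx_2d.type(torch.int), (B * T,))
out = F.embedding(flat, weight)
out = torch.reshape(out, (B, T, weight.size(1)))
assert out.shape == (2, 3, 4)

# Rank-3 indices: "B, T = idx.size()" would raise (too many values to
# unpack), so the pass now routes these straight to the underlying op.
idx_3d = torch.randint(0, 10, (2, 3, 5))
out = F.embedding(idx_3d, weight)
assert out.shape == (2, 3, 5, 4)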