@@ -223,21 +223,25 @@ def embedding(*args, **kwargs):
     full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
     _, embedding_dim = full_kwargs["weight"].size()
     idx = full_kwargs["indices"]
-    idx = idx.type(torch.int)
-    B, T = idx.size()
-
-    idx = torch.reshape(idx, (B * T,))
-
-    builder = StableHLOCompositeBuilder("odml.embedding_lookup")
-    full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
-        idx,
-        full_kwargs["weight"],
-    )
-    output = op(**full_kwargs)
-    output = builder.mark_outputs(output)
-
-    output = torch.reshape(output, (B, T, embedding_dim))
-    return output
+    # TODO(b/356458830): Handle relative positional encoding
+    if len(idx.size()) == 2:
+      idx = idx.type(torch.int)
+      B, T = idx.size()
+
+      idx = torch.reshape(idx, (B * T,))
+
+      builder = StableHLOCompositeBuilder("odml.embedding_lookup")
+      full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
+          idx,
+          full_kwargs["weight"],
+      )
+      output = op(**full_kwargs)
+      output = builder.mark_outputs(output)
+
+      output = torch.reshape(output, (B, T, embedding_dim))
+      return output
+    else:
+      return op(**full_kwargs)
 
 
   node.target = embedding
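For context, a minimal standalone sketch of the control flow this change introduces, using plain torch only. Here embedding_lookup is a hypothetical stand-in for the wrapped embedding op(**full_kwargs), and it omits the StableHLOCompositeBuilder input/output marking; it only illustrates the new rank-2 guard, under which the (B, T) token-lookup path is flattened, looked up, and reshaped, while any other index rank (presumably the 1-D indices from relative positional encoding, per the TODO) falls through to the plain op.

import torch
import torch.nn.functional as F


def embedding_lookup(weight: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  # Hypothetical stand-in for the wrapped op; not the library's API.
  _, embedding_dim = weight.size()
  if len(indices.size()) == 2:
    # The (B, T) case: flatten to 1-D, look up, then restore the batch
    # shape, mirroring the branch marked as the odml.embedding_lookup
    # composite in the pass above.
    B, T = indices.size()
    idx = torch.reshape(indices.type(torch.int), (B * T,))
    out = F.embedding(idx, weight)
    return torch.reshape(out, (B, T, embedding_dim))
  # Any other rank falls through to the plain lookup, unwrapped.
  return F.embedding(indices, weight)


weight = torch.randn(10, 4)
print(embedding_lookup(weight, torch.tensor([[1, 2, 3], [4, 5, 6]])).shape)  # torch.Size([2, 3, 4])
print(embedding_lookup(weight, torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])

Guarding on rank means only the common 2-D case gets annotated as a composite; previously the unconditional B, T = idx.size() would raise on 1-D index tensors.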