@@ -268,7 +268,7 @@ class SingletonInputsAdapter:
268
268
def prompt(self) -> Optional[str]:
    """Return the ``"prompt"`` entry of the wrapped inputs.

    Both the ``"token"`` and ``"multimodal"`` input variants are handled the
    same way: the value stored under ``"prompt"`` is returned, or ``None``
    when that key is absent.

    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    # Compare the discriminator tag directly on the mapping.
    if inputs["type"] == "token" or inputs["type"] == "multimodal":
        return inputs.get("prompt")

    assert_never(inputs)  # type: ignore[arg-type]
@@ -277,7 +277,7 @@ def prompt(self) -> Optional[str]:
277
277
def prompt_token_ids(self) -> List[int]:
    """Return the ``"prompt_token_ids"`` entry of the wrapped inputs.

    For ``"token"`` and ``"multimodal"`` inputs the stored list is returned;
    a new empty list is returned when the key is absent.

    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    # Compare the discriminator tag directly on the mapping.
    if inputs["type"] == "token" or inputs["type"] == "multimodal":
        return inputs.get("prompt_token_ids", [])

    assert_never(inputs)  # type: ignore[arg-type]
@@ -286,7 +286,7 @@ def prompt_token_ids(self) -> List[int]:
286
286
def token_type_ids(self) -> List[int]:
    """Return the ``"token_type_ids"`` entry of the wrapped inputs.

    For ``"token"`` and ``"multimodal"`` inputs the stored list is returned;
    a new empty list is returned when the key is absent.

    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    # Compare the discriminator tag directly on the mapping.
    if inputs["type"] == "token" or inputs["type"] == "multimodal":
        return inputs.get("token_type_ids", [])

    assert_never(inputs)  # type: ignore[arg-type]
@@ -295,7 +295,7 @@ def token_type_ids(self) -> List[int]:
295
295
def prompt_embeds(self) -> Optional[torch.Tensor]:
    """Return the prompt embeddings of the wrapped inputs.

    Always ``None`` for ``"token"`` and ``"multimodal"`` inputs; nothing is
    read from the mapping besides its ``"type"`` tag.

    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    # Compare the discriminator tag directly on the mapping.
    if inputs["type"] == "token" or inputs["type"] == "multimodal":
        return None

    assert_never(inputs)  # type: ignore[arg-type]
@@ -304,9 +304,10 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
304
304
def multi_modal_data(self) -> "MultiModalDataDict":
    """Return the multi-modal data carried by the wrapped inputs.

    * ``"token"`` inputs: the value under ``"multi_modal_data"``.
    * ``"multimodal"`` inputs: the value under ``"mm_kwargs"``.

    In both cases a new empty dict is returned when the key is absent.
    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    if inputs["type"] == "token":
        return inputs.get("multi_modal_data", {})

    # Separate guard rather than ``elif``: the branch above always returns.
    if inputs["type"] == "multimodal":
        return inputs.get("mm_kwargs", {})

    assert_never(inputs)  # type: ignore[arg-type]
@@ -315,9 +316,10 @@ def multi_modal_data(self) -> "MultiModalDataDict":
315
316
def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
    """Return the processed multi-modal inputs of the wrapped inputs.

    * ``"token"`` inputs: the value under ``"multi_modal_inputs"``.
    * ``"multimodal"`` inputs: the value under ``"mm_kwargs"``.

    In both cases a new empty dict is returned when the key is absent.
    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    if inputs["type"] == "token":
        return inputs.get("multi_modal_inputs", {})

    # Separate guard rather than ``elif``: the branch above always returns.
    if inputs["type"] == "multimodal":
        return inputs.get("mm_kwargs", {})

    assert_never(inputs)  # type: ignore[arg-type]
@@ -329,6 +331,7 @@ def multi_modal_hashes(self) -> List[str]:
329
331
if is_token_inputs (inputs ):
330
332
return inputs .get ("multi_modal_hashes" , [])
331
333
elif is_multimodal_inputs (inputs ):
334
+ # only the case when we use MultiModalInputsV2
332
335
return inputs .get ("mm_hashes" , [])
333
336
334
337
assert_never (inputs ) # type: ignore[arg-type]
@@ -337,9 +340,10 @@ def multi_modal_hashes(self) -> List[str]:
337
340
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
    """Return the multi-modal placeholder mapping of the wrapped inputs.

    * ``"token"`` inputs: the value under ``"multi_modal_placeholders"``.
    * ``"multimodal"`` inputs: the value under ``"mm_placeholders"``.

    In both cases a new empty dict is returned when the key is absent.
    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    if inputs["type"] == "token":
        return inputs.get("multi_modal_placeholders", {})

    # Separate guard rather than ``elif``: the branch above always returns.
    if inputs["type"] == "multimodal":
        return inputs.get("mm_placeholders", {})

    assert_never(inputs)  # type: ignore[arg-type]
@@ -348,9 +352,10 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
348
352
def mm_processor_kwargs(self) -> Dict[str, Any]:
    """Return the multi-modal processor keyword arguments.

    * ``"token"`` inputs: the value under ``"mm_processor_kwargs"``, or a
      new empty dict when the key is absent.
    * ``"multimodal"`` inputs: always a new empty dict; nothing is read
      from the mapping.

    Any other ``"type"`` tag falls through to ``assert_never``.
    """
    inputs = self.inputs

    if inputs["type"] == "token":
        return inputs.get("mm_processor_kwargs", {})

    # Separate guard rather than ``elif``: the branch above always returns.
    if inputs["type"] == "multimodal":
        return {}

    assert_never(inputs)  # type: ignore[arg-type]
0 commit comments