@@ -15,6 +15,8 @@
import enum
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import inspect
import re

from packaging import version
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
@@ -95,6 +97,7 @@
    LlamaModelPatcher,
    LlavaImageEmbeddingModelPatcher,
    LlavaQwen2ImageEmbeddingsModelPatcher,
    MambaPatcher,
    MiniCPM3Patcher,
    MiniCPMModelPatcher,
    MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2883,126 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return DeepseekPatcher(self, model, model_kwargs=model_kwargs)


class MambaCacheDummyInputGenerator(DummyInputGenerator):
    """
    Generates dummy Mamba cache inputs (per-layer SSM and convolution states) for decoder-only
    state-space models.
    """

    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")

    def __init__(
        self,
        task: str,
        normalized_config,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        self.normalized_config = normalized_config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        # Cache shapes are read from the model config: SSM states are
        # [batch, intermediate_size, state_size] and conv states are
        # [batch, intermediate_size, conv_kernel].
        self.intermediate_size = self.normalized_config.config.intermediate_size
        self.ssm_state_size = self.normalized_config.config.state_size
        self.conv_kernel_size = self.normalized_config.config.conv_kernel

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "past_ssm_states":
            # One SSM state tensor per layer: [batch_size, intermediate_size, ssm_state_size].
            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
            return [
                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "past_conv_states":
            # One rolling convolution buffer per layer: [batch_size, intermediate_size, conv_kernel_size].
            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
            return [
                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "cache_position":
            return self.random_int_tensor(
                shape=[self.conv_kernel_size],
                max_value=self.sequence_length,
                framework=framework,
                dtype=int_dtype,
            )

        raise ValueError(f"Unsupported input name {input_name}")

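# Illustrative sketch (values assumed, not from the source): for a model config with
# num_layers=2, intermediate_size=1536, state_size=16 and conv_kernel=4,
#
#     gen = MambaCacheDummyInputGenerator("text-generation", normalized_config, batch_size=2)
#     gen.generate("past_ssm_states")   # list of 2 float tensors, each [2, 1536, 16]
#     gen.generate("past_conv_states")  # list of 2 float tensors, each [2, 1536, 4]
#     gen.generate("cache_position")    # int64 tensor of shape [4]
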
@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MambaOpenVINOConfig(TextDecoderOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.use_past_in_inputs:
            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
            self.add_past_key_values(common_inputs, direction="inputs")
            # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
            common_inputs["cache_position"] = {0: "cache_sequence_length"}
        else:
            common_inputs = {
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                # "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "cache_position": {0: "cache_sequence_length"},
            }
        return common_inputs

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """
        Fills `inputs_or_outputs` mapping with the Mamba cache (SSM and convolution states) dynamic
        axes, considering the direction.

        Args:
            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
                The mapping to fill.
            direction (`str`):
                either "inputs" or "outputs", it specifies whether `inputs_or_outputs` is the input mapping or the
                output mapping, this is important for axes naming.
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        if direction == "inputs":
            ssm_name = "past_ssm_states"
            conv_name = "past_conv_states"
        else:
            ssm_name = "present_ssm_states"
            conv_name = "present_conv_states"

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}

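    # For a hypothetical two-layer model, the `inputs` mapping above therefore ends up with
    # "input_ids", "past_ssm_states.0", "past_ssm_states.1", "past_conv_states.0",
    # "past_conv_states.1" and "cache_position", where each cache entry is dynamic only
    # along the batch axis.
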
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        # MambaPatcher (imported above) is expected to expose the Mamba cache as the
        # explicit past/present state tensors this config declares.
        return MambaPatcher(self, model, model_kwargs=model_kwargs)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

        dummy_inputs = {}
        # The per-layer cache entries ("past_ssm_states.0", ...) are generated in one shot
        # below, so they are filtered out here and replaced by the two aggregate input names.
        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
        if self.use_past_in_inputs and self.use_cache_branch is not False:
            input_names.extend(["past_ssm_states", "past_conv_states"])

        for input_name in input_names:
            input_was_inserted = False
            for dummy_input_gen in dummy_inputs_generators:
                if dummy_input_gen.supports_input(input_name):
                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
                        dummy_input_gen,
                        input_name,
                        framework,
                        input_shapes=kwargs,
                    )
                    input_was_inserted = True
                    break
            if not input_was_inserted:
                raise RuntimeError(
                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
                )

        return dummy_inputs
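
# Hedged end-to-end sketch (checkpoint name and constructor kwargs assumed, not from the source):
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")
#     ov_config = MambaOpenVINOConfig(config, task="text-generation-with-past", use_past=True)
#     dummy_inputs = ov_config.generate_dummy_inputs(framework="pt", batch_size=2)
#
# `dummy_inputs` would then contain input_ids, cache_position and the per-layer
# past_ssm_states.* / past_conv_states.* tensors declared by the `inputs` property.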