@@ -4510,6 +4510,85 @@ def mamba_mixer_forward(
     return contextualized_states


+def falcon_mamba_mixer_forward(
+    self,
+    input_states,
+    cache_params=None,
+    cache_position: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
+):
+    from transformers.models.falcon_mamba.modeling_falcon_mamba import rms_forward
+
+    batch_size, seq_len, _ = input_states.shape
+    dtype = input_states.dtype
+    # 1. Gated MLP's linear projection
+    projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
+    hidden_states, gate = projected_states.chunk(2, dim=1)
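+    # in_proj computes both the SSM branch and the gate in a single matmul;
+    # the transpose to channel-first matches the layout nn.Conv1d expects.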
+
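+    # zero out padded positions so padding cannot leak through the convolution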
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 2. Convolution sequence transformation
+    if cache_params is not None:
+        ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+        ssm_state = ssm_state.to(hidden_states.device)
+        # use `cache_position.shape[0]` to check whether we are in prefill
+        # stage; it's equivalent to checking `cache_position[0] == 0`, which
+        # breaks dynamo fullgraph constraints
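+        # conv_sequence_transform is a TorchScript module (attached in
+        # FalconMambaPatcher below), so this prefill/decode branch survives tracing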
+        hidden_states, conv_state = self.conv_sequence_transform(
+            hidden_states, cache_position, cache_params.conv_states[self.layer_idx]
+        )
+        cache_params.conv_states[self.layer_idx] = conv_state
+    else:
+        ssm_state = torch.zeros(
+            (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+        )
+        hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
+
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 3. State Space Model sequence transformation
+    # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+    ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+    time_step, B, C = torch.split(
+        ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+    )
+
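+    # FalconMamba's main deviation from vanilla Mamba: B, C and the time step
+    # are passed through a weight-free RMS normalization before the scan.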
+    B = rms_forward(B, variance_epsilon=self.rms_eps)
+    C = rms_forward(C, variance_epsilon=self.rms_eps)
+    time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+    discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
+
+    discrete_time_step = torch.nn.functional.softplus(discrete_time_step)  # [batch, seq_len, intermediate_size]
+    A = -torch.exp(self.A_log.float())
+    B = B.float()
+    D = self.D.float()
+
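+    # The selective scan unrolls the SSM recurrence
+    #   h_t = exp(dt * A) * h_{t-1} + (dt * B) * x_t,   y_t = C @ h_t + D * x_t,
+    # where softplus keeps dt positive and A = -exp(A_log) yields a stable decay.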
+    scan_output, ssm_state = self.selective_scan(
+        ssm_state, hidden_states.float().transpose(1, 2), discrete_time_step, A, B, C, D
+    )
+    scan_output = scan_output.transpose(1, 2)
+    scan_output = scan_output * self.act(gate)
+
+    if cache_params is not None:
+        cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+    # 4. Final linear projection
+    contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+    return contextualized_states
+
+
 class MambaPatcher(ModelPatcher):
     def __init__(
         self,
@@ -4684,3 +4763,26 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.forward = self._model.__orig_forward
         for layer in self._model.backbone.layers:
             layer.mixer.forward = layer.mixer._orig_forward
+
+
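+# Reuses MambaPatcher's model-level patching and, per layer, rebinds the mixer
+# forward to the export-friendly falcon_mamba_mixer_forward defined above.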
+class FalconMambaPatcher(MambaPatcher):
+    def __enter__(self):
+        super().__enter__()
+        selective_scan = SelectiveScan()
+
+        for layer in self._model.backbone.layers:
+            layer.mixer.selective_scan = selective_scan
+            layer.mixer._orig_forward = layer.mixer.forward
+            layer.mixer.forward = types.MethodType(falcon_mamba_mixer_forward, layer.mixer)
+            conv_transform = ConvSequenceTransform(
+                layer.mixer.conv_kernel_size,
+                layer.mixer.use_conv_bias,
+                layer.mixer.conv1d,
+                layer.mixer.act,
+                layer.mixer.conv1d.bias,
+            )
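+            # scripting (rather than tracing) preserves the data-dependent
+            # prefill-vs-decode branch inside ConvSequenceTransform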
+            layer.mixer.conv_sequence_transform = torch.jit.script(conv_transform)