@@ -63,6 +63,25 @@ def get_lora_parameters(proj):
 pass


+def get_lora_parameters_bias(proj):
+    # For DPO or disabled adapters
+    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
+    W = base_layer.weight
+    bias = base_layer.bias
+
+    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+        return W, QUANT_STATE(W), None, None, None, bias
+    pass
+
+    active_adapter = proj.active_adapters[0] if \
+        hasattr(proj, "active_adapters") else proj.active_adapter
+    A = proj.lora_A[active_adapter].weight
+    B = proj.lora_B[active_adapter].weight
+    s = proj.scaling[active_adapter]
+    return W, QUANT_STATE(W), A, B, s, bias
+pass
+
+
 def fast_dequantize(W, quant_state = None, out = None):
     if quant_state is None: return W
     if type(quant_state) is not list:
@@ -181,7 +200,7 @@ def fast_gemv(X, W, quant_state, out = None):

 def fast_linear_forward(proj, X, temp_lora = None, out = None):

-    W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
+    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
     bsz, q_len, in_dim = X.shape
     if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
@@ -216,6 +235,8 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
         out = out.view(bsz, 1, out_dim)
     pass

+    if bias is not None: out += bias
+
     return out
 pass

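
For context (not part of the patch): below is a minimal, self-contained sketch of the dense math that the patched `fast_linear_forward` now computes when a LoRA adapter is active, namely the base projection plus the scaled LoRA update, with the bias returned by `get_lora_parameters_bias` added at the end. The plain-`torch` tensors and shapes are illustrative assumptions; the real code path works on (possibly quantized) weights via `fast_dequantize`/`fast_gemv`.

```python
import torch

# Illustrative shapes (assumptions, not from the PR):
# W: (out_dim, in_dim), A: (r, in_dim), B: (out_dim, r), bias: (out_dim,)
bsz, q_len, in_dim, out_dim, r = 2, 1, 16, 32, 4
X    = torch.randn(bsz, q_len, in_dim)
W    = torch.randn(out_dim, in_dim)
A    = torch.randn(r, in_dim)
B    = torch.randn(out_dim, r)
s    = 0.5                      # plays the role of lora_S
bias = torch.randn(out_dim)

# Base projection plus scaled LoRA update ...
out = X @ W.t() + s * ((X @ A.t()) @ B.t())
# ... then the step this PR adds: `if bias is not None: out += bias`
out = out + bias
print(out.shape)  # torch.Size([2, 1, 32])
```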