
Commit c1e1646

Update utils.py
1 parent 471565f commit c1e1646

1 file changed (+22, -1)

unsloth/kernels/utils.py (+22, -1)
@@ -63,6 +63,25 @@ def get_lora_parameters(proj):
     pass
 
 
+def get_lora_parameters_bias(proj):
+    # For DPO or disabled adapters
+    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
+    W = base_layer.weight
+    bias = base_layer.bias
+
+    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+        return W, QUANT_STATE(W), None, None, None, bias
+    pass
+
+    active_adapter = proj.active_adapters[0] if \
+        hasattr(proj, "active_adapters") else proj.active_adapter
+    A = proj.lora_A [active_adapter].weight
+    B = proj.lora_B [active_adapter].weight
+    s = proj.scaling[active_adapter]
+    return W, QUANT_STATE(W), A, B, s, bias
+pass
+
+
 def fast_dequantize(W, quant_state = None, out = None):
     if quant_state is None: return W
     if type(quant_state) is not list:
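The new helper mirrors get_lora_parameters but additionally returns the base layer's bias as a sixth element. A minimal sketch of that return contract, assuming unsloth is installed, with a plain torch.nn.Linear standing in for a PEFT-wrapped projection (it has no disable_adapters attribute, so the early-return branch fires and the adapter slots come back as None):

import torch.nn as nn

from unsloth.kernels.utils import get_lora_parameters_bias

# A plain Linear has no base_layer and no disable_adapters attribute,
# so the first branch returns immediately with empty adapter slots.
layer = nn.Linear(8, 4, bias = True)
W, W_quant, A, B, s, bias = get_lora_parameters_bias(layer)

assert A is None and B is None and s is None
assert bias is layer.bias
# W_quant comes from QUANT_STATE(W), which should be None here since an
# unquantized fp32 weight carries no bitsandbytes quant_state attribute.
print(W.shape, bias.shape)  # torch.Size([4, 8]) torch.Size([4])

With a live PEFT adapter, the same call instead fills the middle slots with the active adapter's lora_A/lora_B weights and scaling.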
@@ -181,7 +200,7 @@ def fast_gemv(X, W, quant_state, out = None):
 
 def fast_linear_forward(proj, X, temp_lora = None, out = None):
 
-    W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
+    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
     bsz, q_len, in_dim = X.shape
     if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
 
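For reference, the tuple unpacked here feeds the usual LoRA algebra, out = X @ W.T + s * ((X @ A.T) @ B.T), with the bias now appended at the end of the function. A dense reference sketch in plain PyTorch (not the actual kernel path; shapes follow PEFT's convention of W: (out, in), A: (r, in), B: (out, r)):

import torch

def lora_forward_reference(X, W, A, B, s, bias):
    # Base projection plus the scaled low-rank update, then the bias.
    out = X @ W.t()
    if A is not None:
        out = out + s * ((X @ A.t()) @ B.t())
    if bias is not None:
        out = out + bias
    return out

bsz, in_dim, out_dim, r = 2, 8, 4, 2
X = torch.randn(bsz, 1, in_dim)
W = torch.randn(out_dim, in_dim)
A = torch.randn(r, in_dim)
B = torch.randn(out_dim, r)
bias = torch.randn(out_dim)
print(lora_forward_reference(X, W, A, B, 0.5, bias).shape)  # torch.Size([2, 1, 4])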
@@ -216,6 +235,8 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
         out = out.view(bsz, 1, out_dim)
     pass
 
+    if bias is not None: out += bias
+
     return out
 pass
 
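Note that the bias is applied only on this single-token path: when q_len != 1, the function returns early through matmul_lora above, before this line is reached. A usage sketch, again assuming unsloth is installed, checking the unquantized q_len == 1 path against the layer's own forward:

import torch
import torch.nn as nn

from unsloth.kernels.utils import fast_linear_forward

layer = nn.Linear(8, 4, bias = True)
X = torch.randn(2, 1, 8)  # q_len == 1 takes the fast single-token path

out = fast_linear_forward(layer, X)
# With the bias now applied, this should match the ordinary forward.
print(torch.allclose(out, layer(X), atol = 1e-5))  # expected: True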