Svdquant huggingface checkpoint export support #754
Changes from all commits: 3bc8bc8, 8a2cab2, b7dabf5, 55ca49f, 9ed7768
```diff
@@ -1075,6 +1075,30 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz
     return blocksize


+def svd(weight, rank):
+    original_device = weight.device
+    original_dtype = weight.dtype
+    weight_f64 = weight.to(dtype=torch.float64, device=original_device)
```
Collaborator: do we need f64?

Author (Contributor): I am not sure. I kept what @jingyu-ml had originally. This part is just a refactoring so that I can reuse this code during qkv fusion.

Contributor: Using FP64 for decomposition is slightly slower, but it is more accurate.
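For context on the FP64 point above, here is a minimal, self-contained sketch (not part of this PR) that compares the rank-r reconstruction error when the decomposition runs in float32 versus float64; the 1024x1024 weight and rank 32 are arbitrary stand-ins, not values from the diff.

```python
# Illustrative only: compares low-rank SVD reconstruction error in fp32 vs. fp64.
import torch

torch.manual_seed(0)
weight = torch.randn(1024, 1024, dtype=torch.float32)
rank = 32


def lowrank_error(w: torch.Tensor, rank: int, dtype: torch.dtype) -> float:
    """Frobenius-norm error of the rank-`rank` SVD approximation computed in `dtype`."""
    u, s, vt = torch.linalg.svd(w.to(dtype), full_matrices=False)
    approx = (u[:, :rank] * s[:rank]) @ vt[:rank]
    return (w.double() - approx.double()).norm().item()


print("fp32 error:", lowrank_error(weight, rank, torch.float32))
print("fp64 error:", lowrank_error(weight, rank, torch.float64))
```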
```diff
+    u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
+    us = u[:, :rank] * s[:rank]
+    vt = vt[:rank]
+    us = us.to(device=original_device, dtype=original_dtype)
+    vt = vt.to(device=original_device, dtype=original_dtype)
+    if us.shape[1] < rank or vt.shape[0] < rank:
+        warnings.warn(
+            "The low-rank dimensions do not match the layer dimensions. "
+            "Please verify your configuration and model settings. "
+            f"Rank is {us.shape[1]} and {vt.shape[0]}"
+        )
+        us_temp = torch.zeros((us.shape[0], rank), dtype=us.dtype, device=us.device)
```
Author (Contributor): @jingyu-ml I changed the logic here: when the rank is lower than expected, I now pad with zeros instead of omitting lora_a/b. Do you think this is okay?
```diff
+        vt_temp = torch.zeros((rank, vt.shape[1]), dtype=vt.dtype, device=vt.device)
+        us_temp[:, : us.shape[1]] = us
+        vt_temp[: vt.shape[0], :] = vt
+        us = us_temp
+        vt = vt_temp
+    return us, vt
+
+
 @torch.no_grad()
 def svdquant(
     model: nn.Module,
```
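Following the zero-padding discussion above, here is a minimal, self-contained sketch (it mirrors the helper's padding branch rather than importing it) of the case where the weight has fewer singular values than the requested rank: the factors are padded with zeros so their shapes still match the configured rank, and the extra rows and columns contribute nothing to the product.

```python
# Illustrative only: a 16x8 weight has at most 8 singular values, so a requested
# rank of 32 forces zero-padding of the low-rank factors.
import torch

out_features, in_features, rank = 16, 8, 32
weight = torch.randn(out_features, in_features)

u, s, vt = torch.linalg.svd(weight.double(), full_matrices=False)
k = s.numel()  # k = min(out_features, in_features) = 8 < rank

us = torch.zeros(out_features, rank)
va = torch.zeros(rank, in_features)
us[:, :k] = (u * s).float()   # first k columns hold the real factors
va[:k, :] = vt.float()

# Shapes match the configured rank, and the zero padding does not change the product.
print(us.shape, va.shape)                          # (16, 32), (32, 8)
print(torch.allclose(us @ va, weight, atol=1e-5))  # True: k equals min(out, in)
```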
```diff
@@ -1096,25 +1120,9 @@ def svdquant(
     def postprocess(module, name):
         print_rank_0(f"SVD {name}")
         weight = module.weight.data
-        original_device = weight.device
-        original_dtype = weight.dtype
-        weight_f64 = weight.to(dtype=torch.float64, device=original_device)
-        u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
-        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
-            warnings.warn(
-                "The low-rank dimensions do not match the layer dimensions. "
-                "Please verify your configuration and model settings. "
-                f"SVD will be skipped for this layer {name}."
-            )
-            return
-        us = u[:, :lowrank] * s[:lowrank]
-        vt = vt[:lowrank]
-        module.weight_quantizer.svdquant_lora_a = vt.to(
-            dtype=original_dtype, device=original_device
-        )
-        module.weight_quantizer.svdquant_lora_b = us.to(
-            dtype=original_dtype, device=original_device
-        )
+        us, vt = svd(weight, lowrank)
+        module.weight_quantizer.svdquant_lora_a = vt
+        module.weight_quantizer.svdquant_lora_b = us
         module.weight.data.sub_(
             module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
         )
```