0.0.23
ljleb committed Aug 2, 2024
1 parent 000796b commit b80206b
Showing 4 changed files with 240 additions and 56 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sd-mecha"
- version = "0.0.22"
+ version = "0.0.23"
description = "State dict recipe merger"
readme = "README.md"
authors = [{ name = "ljleb" }]
@@ -27,6 +27,7 @@ dependencies = [
[tool.setuptools]
packages = [
"sd_mecha",
"sd_mecha.merge_methods",
"sd_mecha.extensions",
"sd_mecha.lora",
"sd_mecha.models",
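For context: setuptools ships only the packages listed explicitly above, so the new sd_mecha.merge_methods subpackage has to be added here for the split-out modules to be importable from an installed copy. A minimal sketch of how one might check this, assuming an editable install (pip install -e .) made outside this commit:

# Hypothetical post-install check, not part of this commit: confirm the new
# subpackage is importable from the installed distribution.
import importlib

svd = importlib.import_module("sd_mecha.merge_methods.svd")
print(svd.orthogonal_procrustes)  # prints the function object instead of raising ImportError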
9 changes: 4 additions & 5 deletions sd_mecha/__init__.py
@@ -240,7 +240,6 @@ def add_difference_ties_extended(
)



def copy_region(
a: RecipeNodeOrPath, b: RecipeNodeOrPath, c: Optional[RecipeNodeOrPath] = None, *,
width: Hyper = 1.0,
@@ -293,8 +292,8 @@ def copy_region(

def rotate(
a: RecipeNodeOrPath, b: RecipeNodeOrPath, c: Optional[RecipeNodeOrPath] = None, *,
- alpha: Hyper = 1.0,
- beta: Hyper = 0.0,
+ alignment: Hyper = 1.0,
+ alpha: Hyper = 0.0,
device: Optional[str] = "cuda",
dtype: Optional[torch.dtype] = None,
cache: Optional[Dict[str, torch.Tensor]] = None,
@@ -318,8 +317,8 @@

res = merge_methods.rotate(
a, b,
- alignment=alpha,
- alpha=beta,
+ alignment=alignment,
+ alpha=alpha,
cache=cache,
device=device,
dtype=dtype,
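To summarize the signature change above: the rotate hyperparameter previously exposed as alpha is now named alignment, and the one previously named beta is now alpha, matching the keyword names of merge_methods.rotate. A hedged sketch of a call using the new names; the checkpoint paths are placeholders, not files from this repository:

import sd_mecha

# Placeholder checkpoint paths, for illustration only.
recipe = sd_mecha.rotate(
    "model_a.safetensors",
    "model_b.safetensors",
    alignment=1.0,  # how much of the orthogonal alignment to apply (1.0 applies the full rotation)
    alpha=0.0,      # secondary interpolation hyper, forwarded to merge_methods.rotate as alpha
    device="cuda",  # the default, per the signature above
)

Callers that passed these hypers as alpha and beta before this release need to rename them to alignment and alpha.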
53 changes: 3 additions & 50 deletions sd_mecha/merge_methods.py → sd_mecha/merge_methods/__init__.py
@@ -7,6 +7,7 @@
from torch import Tensor
from typing import Tuple, TypeVar, Dict, Optional
from sd_mecha.hypers import Hyper
+ from .svd import orthogonal_procrustes, fractional_matrix_power
from sd_mecha.merge_space import MergeSpace
from sd_mecha.extensions.merge_method import LiftFlag, convert_to_recipe

@@ -25,15 +26,15 @@ def weighted_sum(
) -> Tensor | SameMergeSpace:
return (1 - alpha) * a + alpha * b

# Isotropic merge / Uniform Soup / Uniform Merge... you name it.
# Compared to a running average, this may run faster.

@convert_to_recipe
def n_average(
*models: Tensor | SameMergeSpace,
**kwargs,
) -> Tensor | SameMergeSpace:
return torch.mean(torch.stack(models), dim=0)


@convert_to_recipe
def slerp(
a: Tensor | SameMergeSpace,
@@ -566,54 +567,6 @@ def rotate(
return a_neurons.reshape_as(a)


def orthogonal_procrustes(a, b, cancel_reflection: bool = False):
svd_driver = "gesvd" if a.is_cuda else None
u, _, v_t = torch.linalg.svd(a.T @ b, driver=svd_driver)

if cancel_reflection:
u[:, -1] /= torch.det(u) * torch.det(v_t)

transform = u @ v_t
if not torch.isfinite(u).all():
raise ValueError(
f"determinant error: {torch.det(transform)}. "
'This can happen when merging on the CPU with the "rotate" method. '
"Consider merging on a cuda device, "
"or try setting alpha to 1 for the problematic blocks. "
"See this related discussion for more info: "
"https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"
)

return transform


def fractional_matrix_power(matrix: Tensor, power: float, cache: Optional[Dict[str, Tensor]] = None):
if cache is not None and "eigenvalues" in cache:
complex_dtype = torch_complex_dtype_map[matrix.dtype]
eigenvalues = cache["eigenvalues"].to(matrix.device, complex_dtype)
eigenvectors = cache["eigenvectors"].to(matrix.device, complex_dtype)
eigenvectors_inv = cache["eigenvectors_inv"].to(matrix.device, complex_dtype)
else:
eigenvalues, eigenvectors = torch.linalg.eig(matrix)
eigenvectors_inv = torch.linalg.inv(eigenvectors)
if cache is not None:
cache["eigenvalues"] = eigenvalues.to("cpu", torch.complex32)
cache["eigenvectors"] = eigenvectors.to("cpu", torch.complex32)
cache["eigenvectors_inv"] = eigenvectors_inv.to("cpu", torch.complex32)

eigenvalues.pow_(power)
result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors_inv
return result.real.to(dtype=matrix.dtype)


torch_complex_dtype_map = {
torch.bfloat16: torch.complex32,
torch.float16: torch.complex32,
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}


@convert_to_recipe
def clamp(
a: Tensor | SameMergeSpace,
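For reference, the comment above n_average can be made concrete with plain tensors: folding models into a running average one at a time and taking the mean of all of them in a single reduction give the same result, the latter being what n_average computes. The sketch below mirrors the function bodies directly rather than going through the @convert_to_recipe-decorated API:

import torch

# Three stand-in "model" tensors, for illustration only.
a, b, c = (torch.randn(4) for _ in range(3))

# Running average: fold the models in one at a time (weighted sums with shrinking alpha).
running = (1 - 1 / 2) * a + (1 / 2) * b        # mean of a and b
running = (1 - 1 / 3) * running + (1 / 3) * c  # mean of a, b and c

# Isotropic / uniform merge: what n_average computes in one reduction.
uniform = torch.mean(torch.stack((a, b, c)), dim=0)

assert torch.allclose(running, uniform, atol=1e-6)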
231 changes: 231 additions & 0 deletions sd_mecha/merge_methods/svd.py
@@ -0,0 +1,231 @@
import torch
from torch import Tensor
from typing import Optional, Tuple, Dict


def orthogonal_procrustes(a, b, cancel_reflection: bool = False):
svd_driver = "gesvd" if a.is_cuda else None
if a.shape[0] + 10 < a.shape[1]:
u, _, v = torch_svd_lowrank(a.T @ b, driver=svd_driver, q=a.shape[0] + 10)
v_t = v.T
del v
else:
u, _, v_t = torch.linalg.svd(a.T @ b, driver=svd_driver)

if cancel_reflection:
u[:, -1] /= torch.det(u) * torch.det(v_t)

transform = u @ v_t
if not torch.isfinite(u).all():
raise ValueError(
f"determinant error: {torch.det(transform)}. "
'This can happen when merging on the CPU with the "rotate" method. '
"Consider merging on a cuda device, "
"or try setting alpha to 1 for the problematic blocks. "
"See this related discussion for more info: "
"https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"
)

return transform


def fractional_matrix_power(matrix: Tensor, power: float, cache: Optional[Dict[str, Tensor]] = None):
if cache is not None and "eigenvalues" in cache:
complex_dtype = torch_complex_dtype_map[matrix.dtype]
eigenvalues = cache["eigenvalues"].to(matrix.device, complex_dtype)
eigenvectors = cache["eigenvectors"].to(matrix.device, complex_dtype)
eigenvectors_inv = cache["eigenvectors_inv"].to(matrix.device, complex_dtype)
else:
eigenvalues, eigenvectors = torch.linalg.eig(matrix)
eigenvectors_inv = torch.linalg.inv(eigenvectors)
if cache is not None:
cache["eigenvalues"] = eigenvalues.to("cpu", torch.complex32)
cache["eigenvectors"] = eigenvectors.to("cpu", torch.complex32)
cache["eigenvectors_inv"] = eigenvectors_inv.to("cpu", torch.complex32)

eigenvalues.pow_(power)
result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors_inv
return result.real.to(dtype=matrix.dtype)


torch_complex_dtype_map = {
torch.bfloat16: torch.complex32,
torch.float16: torch.complex32,
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}


# We redefine torch.svd_lowrank here so that the SVD driver can be specified.
def torch_svd_lowrank(
A: Tensor,
q: Optional[int] = 6,
niter: Optional[int] = 2,
M: Optional[Tensor] = None,
driver: Optional[str] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
q = 6 if q is None else q
m, n = A.shape[-2:]
if M is None:
M_t = None
else:
M_t = transpose(M)
A_t = transpose(A)

# Algorithm 5.1 in Halko et al 2009, slightly modified to reduce
# the number of conjugate and transpose operations
if m < n or n > q:
# computing the SVD approximation of a transpose in
# order to keep B shape minimal (the m < n case) or the V
# shape small (the n > q case)
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = conjugate(Q)
if M is None:
B_t = matmul(A, Q_c)
else:
B_t = matmul(A, Q_c) - matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = torch.linalg.svd(B_t, driver=driver, full_matrices=False)
V = Vh.mH
V = Q.matmul(V)
else:
Q = get_approximate_basis(A, q, niter=niter, M=M)
Q_c = conjugate(Q)
if M is None:
B = matmul(A_t, Q_c)
else:
B = matmul(A_t, Q_c) - matmul(M_t, Q_c)
B_t = transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = torch.linalg.svd(B_t, driver=driver, full_matrices=False)
V = Vh.mH
U = Q.matmul(U)

return U, S, V


def get_approximate_basis(A: Tensor, q: int, niter: Optional[int] = 2, M: Optional[Tensor] = None) -> Tensor:
"""Return tensor :math:`Q` with :math:`q` orthonormal columns such
that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
approximates :math:`A - M`.
.. note:: The implementation is based on Algorithm 4.4 from
Halko et al, 2009.
.. note:: For an adequate approximation of a k-rank matrix
:math:`A`, where k is not known in advance but could be
estimated, the number of :math:`Q` columns, q, can be
chosen according to the following criteria: in general,
:math:`k <= q <= min(2*k, m, n)`. For large low-rank
matrices, take :math:`q = k + 5..10`. If k is
relatively small compared to :math:`min(m, n)`, choosing
:math:`q = k + 0..2` may be sufficient.
.. note:: To obtain repeatable results, reset the seed for the
pseudorandom number generator.
Args:
A (Tensor): the input tensor of size :math:`(*, m, n)`
q (int): the dimension of subspace spanned by :math:`Q`
columns.
niter (int, optional): the number of subspace iterations to
conduct; ``niter`` must be a
nonnegative integer. In most cases, the
default value 2 is more than enough.
M (Tensor, optional): the input tensor's mean of size
:math:`(*, 1, n)`.
References:
- Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
structure with randomness: probabilistic algorithms for
constructing approximate matrix decompositions,
arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
`arXiv <http://arxiv.org/abs/0909.4061>`_).
"""

niter = 2 if niter is None else niter
m, n = A.shape[-2:]
dtype = get_floating_dtype(A)

R = torch.randn(n, q, dtype=dtype, device=A.device)

# The following code could be made faster using torch.geqrf + torch.ormqr
# but geqrf is not differentiable
A_H = transjugate(A)
if M is None:
Q = torch.linalg.qr(matmul(A, R)).Q
for i in range(niter):
Q = torch.linalg.qr(matmul(A_H, Q)).Q
Q = torch.linalg.qr(matmul(A, Q)).Q
else:
M_H = transjugate(M)
Q = torch.linalg.qr(matmul(A, R) - matmul(M, R)).Q
for i in range(niter):
Q = torch.linalg.qr(matmul(A_H, Q) - matmul(M_H, Q)).Q
Q = torch.linalg.qr(matmul(A, Q) - matmul(M, Q)).Q

return Q


def transjugate(A):
"""Return transpose conjugate of a matrix or batches of matrices."""
return conjugate(transpose(A))


def conjugate(A):
"""Return conjugate of tensor A.
.. note:: If A's dtype is not complex, A is returned.
"""
if A.is_complex():
return A.conj()
return A


def transpose(A):
"""Return transpose of a matrix or batches of matrices."""
ndim = len(A.shape)
return A.transpose(ndim - 1, ndim - 2)


def matmul(A: Optional[Tensor], B: Tensor) -> Tensor:
"""Multiply two matrices.
If A is None, return B. A can be sparse or dense. B is always
dense.
"""
if A is None:
return B
if is_sparse(A):
return torch.sparse.mm(A, B)
return torch.matmul(A, B)


def is_sparse(A):
"""Check if tensor A is a sparse tensor"""
if isinstance(A, torch.Tensor):
return A.layout == torch.sparse_coo

error_str = "expected Tensor"
if not torch.jit.is_scripting():
error_str += f" but got {type(A)}"
raise TypeError(error_str)


def get_floating_dtype(A):
"""Return the floating point dtype of tensor A.
Integer types map to float32.
"""
dtype = A.dtype
if dtype in (torch.float16, torch.float32, torch.float64):
return dtype
return torch.float32
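
A few properties of the helpers above, written out as a hedged sketch rather than a test from this repository (random matrices stand in for flattened model weights, and the tolerances are illustrative): orthogonal_procrustes returns the orthogonal transform that best aligns a with b, computed from the SVD of a.T @ b; fractional_matrix_power applies only part of such a transform through an eigendecomposition; and torch_svd_lowrank is the stock torch.svd_lowrank algorithm with the SVD driver argument exposed.

import torch
from sd_mecha.merge_methods.svd import (
    orthogonal_procrustes,
    fractional_matrix_power,
    torch_svd_lowrank,
)

# Random square matrices standing in for two sets of flattened neurons.
a = torch.randn(64, 64, dtype=torch.float64)
b = torch.randn(64, 64, dtype=torch.float64)

# The Procrustes solution is orthogonal; cancel_reflection keeps its determinant at +1
# so that fractional powers of it stay real.
transform = orthogonal_procrustes(a, b, cancel_reflection=True)
identity = torch.eye(64, dtype=torch.float64)
assert torch.allclose(transform @ transform.T, identity, atol=1e-6)

# Applying half of the rotation twice recovers the full rotation.
half = fractional_matrix_power(transform, 0.5)
assert torch.allclose(half @ half, transform, atol=1e-6)

# On an exactly low-rank input, the randomized SVD reconstructs the matrix
# as long as q is at least the true rank.
low_rank = torch.randn(64, 8, dtype=torch.float64) @ torch.randn(8, 256, dtype=torch.float64)
u, s, v = torch_svd_lowrank(low_rank, q=16)
assert torch.allclose(u @ torch.diag(s) @ v.T, low_rank, atol=1e-6)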
