forked from happy-jihye/Cartoon-StyleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclosed_form_factorization.py
36 lines (26 loc) · 979 Bytes
/
closed_form_factorization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import torch
def _modulation_weights(state_dict):
    """Collect the style-modulation weight tensors from a generator state dict.

    A key is selected when it contains "modulation" and "weight" and does not
    contain "to_rgbs".  Tensors are returned in the state dict's iteration
    order.

    NOTE(review): "to_rgb1.*.modulation.weight" (if present) does not contain
    the substring "to_rgbs" and is therefore still selected; confirm whether
    excluding only the "to_rgbs" layers is intended.

    Raises:
        ValueError: if no key matches, instead of the opaque error that
            ``torch.cat`` raises on an empty list.
    """
    weights = [
        v
        for k, v in state_dict.items()
        if "modulation" in k and "to_rgbs" not in k and "weight" in k
    ]
    if not weights:
        raise ValueError(
            "no modulation weights found in the 'g_ema' state dict "
            "(expected keys containing 'modulation' and 'weight')"
        )
    return weights


def _closed_form_factor(state_dict):
    """Return the closed-form latent directions for a generator state dict.

    The selected modulation weights are stacked row-wise into one matrix W.
    The right singular vectors of W (the columns of ``torch.svd(W).V``) are
    returned on the CPU.  They are the eigenvectors of W^T W, and torch.svd
    orders them by decreasing singular value.
    """
    W = torch.cat(_modulation_weights(state_dict), 0)
    # torch.svd is deprecated in favour of torch.linalg.svd, but it is kept
    # here so the saved eigenvectors stay identical to earlier runs.
    return torch.svd(W).V.to("cpu")


def main(argv=None):
    """Command-line entry point.

    Loads the checkpoint given by the positional ``ckpt`` argument.  Computes
    the closed-form factorization of its "g_ema" generator, then saves
    ``{"ckpt": <checkpoint path>, "eigvec": <directions>}`` to
    ``--factor_name`` (default "factor.pt").

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Extract factor/eigenvectors of latent spaces using closed form factorization"
    )
    parser.add_argument(
        "--factor_name", type=str, default="factor.pt", help="name of the result factor file"
    )
    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
    args = parser.parse_args(argv)

    # -----------------------------------
    # Make Eigenvector of Latent spaces
    # -----------------------------------
    # Load onto the CPU so checkpoints saved from a GPU run also load on
    # CPU-only machines; the result is stored on the CPU regardless.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may
    # reject checkpoints that also pickle non-tensor objects; confirm against
    # the checkpoints in use.
    ckpt = torch.load(args.ckpt, map_location="cpu")
    try:
        generator_state = ckpt["g_ema"]
    except KeyError:
        parser.error(f"checkpoint {args.ckpt!r} has no 'g_ema' entry")

    try:
        eigvec = _closed_form_factor(generator_state)
    except ValueError as err:
        parser.error(str(err))

    torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.factor_name)


if __name__ == "__main__":
    main()