import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

import torch

from modeling_flax_vqgan import VQModel
from configuration_vqgan import VQGANConfig


regex = r"\w+[.]\d+"


def rename_key(key):
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
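
# Illustration (hypothetical key, in the style of taming-transformers VQGAN checkpoints):
#   rename_key("decoder.up.0.block.1.norm1.weight") -> "decoder.up_0.block_1.norm1.weight"
# i.e. every "<name>.<digit>" segment is rewritten with an underscore so the keys line up
# with the Flax parameter names.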


# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert PyTorch tensors to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameter names to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
        ):
            pt_tensor = pt_tensor[None, None, None, :]
        elif (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
        ):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer: PyTorch (out_ch, in_ch, kh, kw) -> Flax (kh, kw, in_ch, out_ch)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer: PyTorch (out, in) -> Flax (in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weights so that a warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
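
# Note: flatten_dict / unflatten_dict (from flax.traverse_util) convert between nested
# parameter dicts and flat dicts keyed by tuples, which is why the PyTorch keys above are
# handled as tuples. Illustrative only (the nested names are hypothetical):
#   flatten_dict({"decoder": {"conv_in": {"kernel": k}}}) == {("decoder", "conv_in", "kernel"): k}
#   unflatten_dict({("decoder", "conv_in", "kernel"): k}) == {"decoder": {"conv_in": {"kernel": k}}}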


def convert_model(config_path, pt_state_dict_path, save_path):
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    # Load the PyTorch checkpoint and drop the training-loss weights (keys starting
    # with "loss"), which have no counterpart in the Flax model; rename the rest.
    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    keys = list(state_dict.keys())
    for key in keys:
        if key.startswith("loss"):
            state_dict.pop(key)
            continue
        renamed_key = rename_key(key)
        state_dict[renamed_key] = state_dict.pop(key)

    state = convert_pytorch_state_dict_to_flax(state_dict, model)
    # convert_pytorch_state_dict_to_flax already returns a nested (unflattened) dict,
    # so assign it directly; calling unflatten_dict on it a second time would mangle the keys.
    model.params = state
    model.save_pretrained(save_path)
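

# Example invocation (a sketch; all paths are placeholders, not part of the original script):
#   convert_model(
#       config_path="path/to/vqgan_config",       # readable by VQGANConfig.from_pretrained
#       pt_state_dict_path="path/to/last.ckpt",   # torch checkpoint containing a "state_dict" entry
#       save_path="path/to/flax-vqgan",           # output directory for the converted Flax model
#   )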