fix configs without #
ljleb committed Aug 3, 2024
1 parent bdb671d commit f8bb603
Showing 7 changed files with 39 additions and 52 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sd-mecha"
version = "0.0.24"
version = "0.0.25"
description = "State dict recipe merger"
readme = "README.md"
authors = [{ name = "ljleb" }]
19 changes: 1 addition & 18 deletions sd_mecha/__main__.py
@@ -115,7 +115,7 @@ def compose(
f.write(composed_recipe)


@click.command(help="Show the available blocks and classes of model architectures")
@click.command(help="Show the available blocks of model architectures")
@click.argument("model_arch", type=click.Choice(extensions.model_arch.get_all() + [""], case_sensitive=False), default="")
@click.option("--verbose", "-v", is_flag=True)
@click.option("--debug", is_flag=True, help="Print the stacktrace when an error occurs.")
@@ -156,23 +156,6 @@ def info(
for b in bs
if f"_{component}_block_" in b
], key=lambda t: natural_sort_key(t[0]))),
"Classes":
sorted(list({
c.split("_", 3)[3]
for cs in model_arch.classes.values()
for c in cs
if f"_{component}_class_" in c
}), key=natural_sort_key)
if not verbose else
dict(sorted([
(c.split("_", 3)[3], sorted([
k for k in model_arch.classes
if c in model_arch.classes[k]
], key=natural_sort_key))
for cs in model_arch.classes.values()
for c in cs
if f"_{component}_class_" in c
], key=lambda t: natural_sort_key(t[0]))),
}
for component in model_arch.components
}, sort_keys=False))
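The surviving block listing in the info command sorts identifiers such as in0, in3 and in11 with natural_sort_key. The helper below is a hypothetical implementation of that kind of natural sort, written only to illustrate the ordering; the actual sd_mecha helper may differ.

import re

def natural_sort_key(s: str):
    # Split into digit and non-digit runs so "in11" sorts after "in2" instead of before it.
    return [int(part) if part.isdigit() else part for part in re.split(r"(\d+)", s)]

blocks = ["in11", "in2", "mid", "in0", "out8"]
print(sorted(blocks, key=natural_sort_key))  # ['in0', 'in2', 'in11', 'mid', 'out8']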
1 change: 0 additions & 1 deletion sd_mecha/extensions/model_arch.py
@@ -46,7 +46,6 @@ def discover_blocks(keys, discovered_block_prefixes, arch_identifier: str):
for block, prefixes in discovered_block_prefixes.items():
if any(prefix.match(key) for prefix in prefixes["patterns"]):
blocks.setdefault(key, set()).add(block)
-break

return blocks

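With the break removed, discover_blocks keeps scanning after the first hit, so a key can be assigned to every block whose pattern matches it, which the updated configs rely on for shared prefixes such as time_embed. Below is a rough standalone sketch of that behaviour; the patterns and keys are made up, arch_identifier is omitted, and the real prefix compilation in sd_mecha is more involved.

import re

discovered_block_prefixes = {
    "mid": {"patterns": [re.compile(r"model\.diffusion_model\.middle_block\."),
                         re.compile(r"model\.diffusion_model\.time_embed\.")]},
    "out11": {"patterns": [re.compile(r"model\.diffusion_model\.output_blocks\.11\."),
                           re.compile(r"model\.diffusion_model\.time_embed\.")]},
}

def discover_blocks(keys, discovered_block_prefixes):
    blocks = {}
    for key in keys:
        for block, prefixes in discovered_block_prefixes.items():
            if any(prefix.match(key) for prefix in prefixes["patterns"]):
                blocks.setdefault(key, set()).add(block)  # no break: keep collecting matches
    return blocks

keys = ["model.diffusion_model.time_embed.0.weight",
        "model.diffusion_model.middle_block.1.proj_in.weight"]
print(discover_blocks(keys, discovered_block_prefixes))
# the time_embed key now lands in both "mid" and "out11"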
22 changes: 16 additions & 6 deletions sd_mecha/merge_methods/__init__.py
@@ -576,15 +576,16 @@ def clamp(
) -> Tensor | SameMergeSpace:
maximums = functools.reduce(torch.maximum, bounds)
minimums = functools.reduce(torch.minimum, bounds)
-centers = (maximums + minimums) / 2
+bounds = torch.stack(bounds)
+average = bounds.mean(dim=0)

if stiffness:
smallest_positive = maximums
largest_negative = minimums

for i, bound in enumerate(bounds):
-smallest_positive = torch.where((smallest_positive >= bound) & (bound >= centers), bound, smallest_positive)
-largest_negative = torch.where((largest_negative <= bound) & (bound <= centers), bound, largest_negative)
+smallest_positive = torch.where((smallest_positive >= bound) & (bound >= average), bound, smallest_positive)
+largest_negative = torch.where((largest_negative <= bound) & (bound <= average), bound, largest_negative)

maximums = weighted_sum.__wrapped__(maximums, smallest_positive, alpha=stiffness)
minimums = weighted_sum.__wrapped__(minimums, largest_negative, alpha=stiffness)
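The stiffness logic now pulls the clamp bounds toward the bound values closest to the elementwise mean of all bounds rather than the midpoint of min and max. The sketch below is a self-contained reading of that hunk; clamp_bounds is a made-up name and torch.lerp stands in for the library's weighted_sum, so treat it as an illustration rather than the sd_mecha code path.

import functools
import torch

def clamp_bounds(bounds: list, stiffness: float = 0.5):
    maximums = functools.reduce(torch.maximum, bounds)
    minimums = functools.reduce(torch.minimum, bounds)
    stacked = torch.stack(bounds)
    average = stacked.mean(dim=0)  # replaces the old (max + min) / 2 midpoint

    smallest_positive = maximums
    largest_negative = minimums
    for bound in stacked:
        # keep the smallest bound still above the average, and the largest bound still below it
        smallest_positive = torch.where((smallest_positive >= bound) & (bound >= average), bound, smallest_positive)
        largest_negative = torch.where((largest_negative <= bound) & (bound <= average), bound, largest_negative)

    # lerp as a stand-in for weighted_sum: alpha=0 keeps the loose min/max, alpha=1 uses the tight bounds
    maximums = torch.lerp(maximums, smallest_positive, stiffness)
    minimums = torch.lerp(minimums, largest_negative, stiffness)
    return minimums, maximums

lo, hi = clamp_bounds([torch.tensor([0.1, -0.2]), torch.tensor([0.5, 0.0]), torch.tensor([0.9, 0.4])])
print(lo, hi)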
@@ -599,9 +600,14 @@ def dropout( # aka n-supermario
probability: Hyper = 0.9,
overlap: Hyper = 1.0,
overlap_emphasis: Hyper = 0.0,
-seed: Hyper = None,
+seed: Hyper = -1,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:
+if seed < 0:
+seed = None
+else:
+seed = int(seed)

deltas = torch.stack((delta0,) + deltas)
rng = np.random.default_rng(seed)

@@ -638,11 +644,15 @@ def ties_sum_with_dropout(
apply_median: Hyper = 0.0,
eps: Hyper = 1e-6,
maxiter: Hyper = 100,
-ftol: Hyper =1e-20,
-seed: Hyper = None,
+ftol: Hyper = 1e-20,
+seed: Hyper = -1,
**kwargs,
) -> Tensor | LiftFlag[MergeSpace.DELTA]:
# Set seed
+if seed < 0:
+seed = None
+else:
+seed = int(seed)
torch.manual_seed(seed)

# Under "Dropout", delta will be 0 by definition. Multiply it (Hadamard product) will return 0 also.
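In both dropout and ties_sum_with_dropout the seed default moves from None to -1: a negative value is treated as unseeded and anything else is cast to int before it reaches the RNG. A minimal illustration of that convention follows; resolve_seed is a hypothetical helper, not the sd_mecha call path itself.

import numpy as np

def resolve_seed(seed: float):
    # Negative sentinel means "no fixed seed"; otherwise coerce the numeric hyperparameter to int.
    return None if seed < 0 else int(seed)

rng_random = np.random.default_rng(resolve_seed(-1))   # fresh entropy each call
rng_fixed = np.random.default_rng(resolve_seed(42.0))  # reproducible stream
print(rng_fixed.random(3))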
16 changes: 8 additions & 8 deletions sd_mecha/models/sd1_ldm.yaml
@@ -28,22 +28,22 @@ merge:
unet:
prefix: model.diffusion_model
blocks:
-in0: input_blocks.0.#
-in3: input_blocks.3.#
-in6: input_blocks.6.#
-in9: input_blocks.9.#
+in0: input_blocks.0
+in3: input_blocks.3
+in6: input_blocks.6
+in9: input_blocks.9
mid:
-- middle_block.#
+- middle_block
- time_embed
out11:
-- output_blocks.11.#
+- output_blocks.11
- time_embed
- out
in*:
-- input_blocks.*.#
+- input_blocks.*
- time_embed
out*:
-- output_blocks.*.#
+- output_blocks.*
- time_embed

keys: sd1_ldm_keys.txt
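Each block entry is a key prefix (or list of prefixes) under the component's prefix; dropping the trailing .# means a block such as in0 now covers every key under input_blocks.0 rather than only those one index segment deep, assuming # stood for a single numeric segment in the old syntax. The sketch below is a purely hypothetical way such a prefix could be turned into a match pattern; the actual matching in model_arch.py compiles its own patterns.

import re

def prefix_to_pattern(component_prefix: str, block_prefix: str) -> re.Pattern:
    # "*" is the wildcard block index, e.g. "input_blocks.*" covers input_blocks.0, input_blocks.1, ...
    escaped = re.escape(f"{component_prefix}.{block_prefix}").replace(r"\*", r"\d+")
    return re.compile(escaped + r"(\.|$)")

pattern = prefix_to_pattern("model.diffusion_model", "input_blocks.0")
print(bool(pattern.match("model.diffusion_model.input_blocks.0.0.weight")))   # True
print(bool(pattern.match("model.diffusion_model.input_blocks.10.0.weight")))  # False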
12 changes: 6 additions & 6 deletions sd_mecha/models/sd3_sgm.yaml
@@ -6,21 +6,21 @@ merge:
prefix: model.diffusion_model
blocks:
in0:
-- joint_blocks.0.#
+- joint_blocks.0
- pos_embed
- t_embedder
- x_embedder
- y_embedder
- context_embedder
in23:
-- joint_blocks.23.#
+- joint_blocks.23
- pos_embed
- t_embedder
- x_embedder
- y_embedder
- final_layer
in*:
-- joint_blocks.*.#
+- joint_blocks.*
- pos_embed
- t_embedder
- x_embedder
@@ -53,12 +53,12 @@ merge:
prefix: text_encoders.t5xxl.transformer
blocks:
in0:
-- encoder.block.0.#.#
+- encoder.block.0
- encoder.embed_tokens
- shared
in23:
-- encoder.block.23.#.#
+- encoder.block.23
- encoder.final_layer_norm
-in*: encoder.block.*.#.#
+in*: encoder.block.*

keys: sd3_sgm_keys.txt
19 changes: 7 additions & 12 deletions sd_mecha/models/sdxl_sgm.yaml
@@ -6,27 +6,22 @@ merge:
unet:
prefix: model.diffusion_model
blocks:
-in0: input_blocks.0.#
-in3: input_blocks.3.#
+in0: input_blocks.0
+in3: input_blocks.3
in6:
-- input_blocks.6.#
-- input_blocks.6.#.transformer_blocks.#
+- input_blocks.6
mid:
-- middle_block.#
-- middle_block.#.transformer_blocks.#
+- middle_block
- time_embed
out8:
-- output_blocks.8.#
-- output_blocks.8.#.transformer_blocks.#
+- output_blocks.8
- time_embed
- out
in*:
-- input_blocks.*.#
-- input_blocks.*.#.transformer_blocks.#
+- input_blocks.*
- time_embed
out*:
-- output_blocks.*.#
-- output_blocks.*.#.transformer_blocks.#
+- output_blocks.*
- time_embed

txt:
