Skip to content

Commit 40a0f00

Browse files
committed
add lic & readme, code reformat & review
1 parent 29da7bd commit 40a0f00

File tree

6 files changed

+117
-64
lines changed

6 files changed

+117
-64
lines changed

Euler.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
1-
from diffusers import EulerDiscreteScheduler
2-
from torch import Tensor
3-
import torch
41
from typing import Callable, List, Optional, Tuple, Union, Dict, Any, Literal
2+
3+
import torch
4+
from torch import Tensor
5+
from diffusers import EulerDiscreteScheduler
56
from diffusers.utils import randn_tensor
67
from diffusers.configuration_utils import ConfigMixin
78
from diffusers.schedulers.scheduling_utils import SchedulerMixin
9+
10+
811
class Euler(EulerDiscreteScheduler, SchedulerMixin, ConfigMixin):
9-
history_d=0
10-
momentum=0.95
11-
momentum_hist=0.75
12-
def init_hist_d(self,x:Tensor) -> Union[Literal[0], Tensor]:
12+
13+
history_d = 0
14+
momentum = 0.95
15+
momentum_hist = 0.75
16+
17+
def init_hist_d(self, x:Tensor) -> Union[Literal[0], Tensor]:
1318
# memorize delta momentum
14-
if self.history_d == 0: self.history_d = 0
19+
if self.history_d == 0: self.history_d = 0
1520
elif self.history_d == 'rand_init': self.history_d = x
1621
elif self.history_d == 'rand_new': self.history_d = torch.randn_like(x)
1722
else: raise ValueError(f'unknown momentum_hist_init: {self.history_d}')
23+
1824
def momentum_step(self, x:Tensor, d:Tensor, dt:Tensor):
19-
hd=self.history_d
25+
hd = self.history_d
2026
# correct current `d` with momentum
2127
p = 1.0 - self.momentum
2228
self.momentum_d = (1.0 - p) * d + p * hd
@@ -30,8 +36,10 @@ def momentum_step(self, x:Tensor, d:Tensor, dt:Tensor):
3036
hd = self.momentum_d
3137
else:
3238
hd = (1.0 - q) * hd + q * self.momentum_d
33-
self.history_d=hd
39+
self.history_d = hd
40+
3441
return x
42+
3543
def step(
3644
self,
3745
model_output: torch.FloatTensor,
@@ -124,10 +132,8 @@ def step(
124132
derivative = (sample - pred_original_sample) / sigma_hat
125133

126134
dt = self.sigmas[step_index + 1] - sigma_hat
127-
128-
prev_sample = self.momentum_step(sample,derivative,dt)
129-
if not return_dict:
130-
return (prev_sample,)
131-
132-
output={prev_sample:prev_sample, pred_original_sample:pred_original_sample}
133-
return output
135+
prev_sample = self.momentum_step(sample, derivative, dt)
136+
137+
if not return_dict: return (prev_sample,)
138+
output = { prev_sample: prev_sample, pred_original_sample: pred_original_sample }
139+
return output

EulerA.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
1-
from diffusers import EulerAncestralDiscreteScheduler
2-
from torch import Tensor
3-
import torch
41
from typing import Callable, List, Optional, Tuple, Union, Dict, Any, Literal
2+
3+
import torch
4+
from torch import Tensor
5+
from diffusers import EulerAncestralDiscreteScheduler
56
from diffusers.utils import randn_tensor
67
from diffusers.configuration_utils import ConfigMixin
78
from diffusers.schedulers.scheduling_utils import SchedulerMixin
9+
10+
811
class EulerA(EulerAncestralDiscreteScheduler, SchedulerMixin, ConfigMixin):
9-
history_d=0
10-
momentum=0.95
11-
momentum_hist=0.75
12-
def init_hist_d(self,x:Tensor) -> Union[Literal[0], Tensor]:
12+
13+
history_d = 0
14+
momentum = 0.95
15+
momentum_hist = 0.75
16+
17+
def init_hist_d(self, x:Tensor) -> Union[Literal[0], Tensor]:
1318
# memorize delta momentum
14-
if self.history_d == 0: self.history_d = 0
19+
if self.history_d == 0: self.history_d = 0
1520
elif self.history_d == 'rand_init': self.history_d = x
1621
elif self.history_d == 'rand_new': self.history_d = torch.randn_like(x)
1722
else: raise ValueError(f'unknown momentum_hist_init: {self.history_d}')
23+
1824
def momentum_step(self, x:Tensor, d:Tensor, dt:Tensor):
19-
hd=self.history_d
25+
hd = self.history_d
2026
# correct current `d` with momentum
2127
p = 1.0 - self.momentum
2228
self.momentum_d = (1.0 - p) * d + p * hd
@@ -30,8 +36,10 @@ def momentum_step(self, x:Tensor, d:Tensor, dt:Tensor):
3036
hd = self.momentum_d
3137
else:
3238
hd = (1.0 - q) * hd + q * self.momentum_d
33-
self.history_d=hd
39+
self.history_d = hd
40+
3441
return x
42+
3543
def step(
3644
self,
3745
model_output: torch.FloatTensor,
@@ -101,15 +109,13 @@ def step(
101109
derivative = (sample - pred_original_sample) / sigma
102110

103111
dt = sigma_down - sigma
104-
105-
prev_sample = self.momentum_step(sample,derivative,dt)
112+
prev_sample = self.momentum_step(sample, derivative, dt)
106113

107114
device = model_output.device
108115
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
109116

110117
prev_sample = prev_sample + noise * sigma_up
111118

112-
if not return_dict:
113-
return (prev_sample,)
114-
output={prev_sample:prev_sample, pred_original_sample:pred_original_sample}
115-
return output
119+
if not return_dict: return (prev_sample,)
120+
output = { prev_sample: prev_sample, pred_original_sample: pred_original_sample }
121+
return output

LICENSE

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
MIT License
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy
4+
of this software and associated documentation files (the "Software"), to deal
5+
in the Software without restriction, including without limitation the rights
6+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
copies of the Software, and to permit persons to whom the Software is
8+
furnished to do so, subject to the following conditions:
9+
10+
The above copyright notice and this permission notice shall be included in all
11+
copies or substantial portions of the Software.
12+
13+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19+
SOFTWARE.

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# modified-euler-samplers-for-sonar-diffusers
2+
3+
----
4+
5+
This repo ports the **momentum mechanism** of [sd-webui-sonar](https://github.com/Kahsolt/stable-diffusion-webui-sonar) on `Euler` and `Euler a` sampler to [huggingface/diffusers](https://github.com/huggingface/diffusers).

example (partial).py

-31
This file was deleted.

example.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from schedulers.EulerA import EulerA
2+
# Initialize the Celery app
3+
4+
controlnet = ControlNetModel.from_pretrained(
5+
"lllyasviel/control_v11p_sd15_openpose",
6+
torch_dtype=torch.float16,
7+
local_files_only=True,
8+
)
9+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
10+
"runwayml/stable-diffusion-v1-5",
11+
controlnet=controlnet,
12+
local_files_only=True,
13+
torch_dtype=torch.float16,
14+
safety_checker=None,
15+
requires_safety_checker=False,
16+
).to('cuda')
17+
18+
# import one of the 2 schedulers from this repo
19+
pipe.scheduler = EulerA.from_config(pipe.scheduler.config)
20+
21+
# choose from [0, 'rand_new', 'rand_init']
22+
pipe.scheduler.history_d = 'rand_new'
23+
# number should be between -1 and 1
24+
pipe.scheduler.momentum = 0.95
25+
# number should be between -1 and 1
26+
pipe.scheduler.momentum_hist = 0.75
27+
28+
buffer = open('img0.png', 'rb')
29+
buffer.seek(0)
30+
image_bytes = buffer.read()
31+
images = Image.open(BytesIO(image_bytes))
32+
33+
start_time = time.time()
34+
generator = torch.manual_seed(2733424006)
35+
image=pipe(
36+
"A person standing in a field of flowers, 4k, realistic",
37+
images,
38+
num_inference_steps=20,
39+
height=512,
40+
width=512,
41+
generator=generator
42+
).images[0]
43+
end_time = time.time()
44+
execution_time = end_time - start_time
45+
print("Execution time: {:.2f} seconds".format(execution_time))
46+
47+
# print(image)
48+
image.save('img1.png', format='PNG')

0 commit comments

Comments
 (0)