@@ -412,9 +412,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
412412 ds .pop (0 )
413413 if callback is not None :
414414 callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
415- cur_order = min (i + 1 , order )
416- coeffs = [linear_multistep_coeff (cur_order , sigmas_cpu , i , j ) for j in range (cur_order )]
417- x = x + sum (coeff * d for coeff , d in zip (coeffs , reversed (ds )))
415+ if sigmas [i + 1 ] == 0 :
416+ # Denoising step
417+ x = denoised
418+ else :
419+ cur_order = min (i + 1 , order )
420+ coeffs = [linear_multistep_coeff (cur_order , sigmas_cpu , i , j ) for j in range (cur_order )]
421+ x = x + sum (coeff * d for coeff , d in zip (coeffs , reversed (ds )))
418422 return x
419423
420424
@@ -1067,7 +1071,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
10671071 d_cur = (x_cur - denoised ) / t_cur
10681072
10691073 order = min (max_order , i + 1 )
1070- if order == 1 : # First Euler step.
1074+ if t_next == 0 : # Denoising step
1075+ x_next = denoised
1076+ elif order == 1 : # First Euler step.
10711077 x_next = x_cur + (t_next - t_cur ) * d_cur
10721078 elif order == 2 : # Use one history point.
10731079 x_next = x_cur + (t_next - t_cur ) * (3 * d_cur - buffer_model [- 1 ]) / 2
@@ -1085,6 +1091,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
10851091
10861092 return x_next
10871093
1094+
10881095#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
10891096#under Apache 2 license
10901097def sample_ipndm_v (model , x , sigmas , extra_args = None , callback = None , disable = None , max_order = 4 ):
@@ -1108,7 +1115,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
11081115 d_cur = (x_cur - denoised ) / t_cur
11091116
11101117 order = min (max_order , i + 1 )
1111- if order == 1 : # First Euler step.
1118+ if t_next == 0 : # Denoising step
1119+ x_next = denoised
1120+ elif order == 1 : # First Euler step.
11121121 x_next = x_cur + (t_next - t_cur ) * d_cur
11131122 elif order == 2 : # Use one history point.
11141123 h_n = (t_next - t_cur )
@@ -1148,6 +1157,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
11481157
11491158 return x_next
11501159
1160+
11511161#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
11521162#under Apache 2 license
11531163@torch .no_grad ()
@@ -1198,6 +1208,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
11981208
11991209 return x_next
12001210
1211+
12011212@torch .no_grad ()
12021213def sample_euler_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None ):
12031214 extra_args = {} if extra_args is None else extra_args
@@ -1404,6 +1415,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
14041415def sample_res_multistep_ancestral_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
14051416 return res_multistep (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , s_noise = s_noise , noise_sampler = noise_sampler , eta = eta , cfg_pp = True )
14061417
1418+
14071419@torch .no_grad ()
14081420def sample_gradient_estimation (model , x , sigmas , extra_args = None , callback = None , disable = None , ge_gamma = 2. , cfg_pp = False ):
14091421 """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1430,19 +1442,19 @@ def post_cfg_function(args):
14301442 if callback is not None :
14311443 callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
14321444 dt = sigmas [i + 1 ] - sigmas [i ]
1433- if i == 0 :
1445+ if sigmas [i + 1 ] == 0 :
1446+ # Denoising step
1447+ x = denoised
1448+ else :
14341449 # Euler method
14351450 if cfg_pp :
14361451 x = denoised + d * sigmas [i + 1 ]
14371452 else :
14381453 x = x + d * dt
1439- else :
1440- # Gradient estimation
1441- if cfg_pp :
1454+
1455+ if i >= 1 :
1456+ # Gradient estimation
14421457 d_bar = (ge_gamma - 1 ) * (d - old_d )
1443- x = denoised + d * sigmas [i + 1 ] + d_bar * dt
1444- else :
1445- d_bar = ge_gamma * d + (1 - ge_gamma ) * old_d
14461458 x = x + d_bar * dt
14471459 old_d = d
14481460 return x