Skip to content

Stateless scheduler #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,62 @@
InferenceState,
PNDMScheduler,
StableDiffusionPipeline,
UNet2D
UNet2D,
StableDiffusionSafetyCheckerModel,
)
from stable_diffusion_jax.convert_diffusers_to_jax import convert_diffusers_to_jax


# convert diffusers checkpoint to jax
pt_path = "path_to_diffusers_pt_ckpt"
fx_path = "save_path"
convert_diffusers_to_jax(pt_path, fx_path)
# Local checkout until weights are available in the Hub
flax_path = "/sddata/sd-v1-4-flax"

num_samples = 8
num_inference_steps = 50
guidance_scale = 7.5

devices = jax.devices()[:1]

# inference with jax
dtype = jnp.bfloat16
clip_model, clip_params = FlaxCLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", _do_init=False, dtype=dtype
)
unet, unet_params = UNet2D.from_pretrained(f"{fx_path}/unet", _do_init=False, dtype=dtype)
vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype)
safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{fx_path}/safety_model", _do_init=False, dtype=dtype)
unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", _do_init=False, dtype=dtype)
vae, vae_params = AutoencoderKL.from_pretrained(f"{flax_path}/vae", _do_init=False, dtype=dtype)
safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{flax_path}/safety_checker", _do_init=False, dtype=dtype)

config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
scheduler = PNDMScheduler()

latents_shape = (
num_samples,
unet.config.sample_size,
unet.config.sample_size,
unet.config.in_channels,
)

scheduler = PNDMScheduler.from_config(f"{flax_path}/scheduler")
scheduler_state = scheduler.set_timesteps(
scheduler.state,
latents_shape,
num_inference_steps = num_inference_steps,
offset = 1,
)

# create inference state and replicate it across all TPU devices
inference_state = InferenceState(text_encoder_params=clip_params, unet_params=unet_params, vae_params=vae_params)
inference_state = replicate(inference_state)
inference_state = InferenceState(
text_encoder_params=clip_params,
unet_params=unet_params,
vae_params=vae_params,
scheduler_state=scheduler_state,
)
inference_state = replicate(inference_state, devices=devices)


# create pipeline
pipe = StableDiffusionPipeline(text_encoder=clip_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vae=vae)



# prepare inputs
num_samples = 8
p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"

input_ids = tokenizer(
Expand All @@ -59,23 +78,20 @@
prng_seed = jax.random.PRNGKey(42)

# shard inputs and rng
input_ids = shard(input_ids)
uncond_input_ids = shard(uncond_input_ids)
prng_seed = jax.random.split(prng_seed, 8)
# Simply use shard if using the default devices
input_ids = jax.device_put_sharded([input_ids], devices)
uncond_input_ids = jax.device_put_sharded([uncond_input_ids], devices)
prng_seed = jax.random.split(prng_seed, len(devices))

# pmap the sample function
num_inference_steps = 50
guidance_scale = 1.0

sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4, 5))
sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4,))

# sample images
images = sample(
input_ids,
uncond_input_ids,
prng_seed,
inference_state,
num_inference_steps,
guidance_scale,
)

Expand All @@ -87,3 +103,4 @@
images = np.asarray(images).reshape((num_samples, 512, 512, 3))

pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save("example.png")
23 changes: 12 additions & 11 deletions stable_diffusion_jax/pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel

from stable_diffusion_jax.scheduling_pndm import PNDMScheduler
from stable_diffusion_jax.scheduling_pndm import PNDMSchedulerState


@flax.struct.dataclass
class InferenceState:
text_encoder_params: flax.core.FrozenDict
unet_params: flax.core.FrozenDict
vae_params: flax.core.FrozenDict
scheduler_state: PNDMSchedulerState


class StableDiffusionPipeline:
Expand Down Expand Up @@ -41,13 +42,9 @@ def sample(
uncond_input_ids: jnp.ndarray,
prng_seed: jax.random.PRNGKey,
inference_state: InferenceState,
num_inference_steps: int = 50,
guidance_scale: float = 1.0,
debug: bool = False,
):

self.scheduler.set_timesteps(num_inference_steps)

text_embeddings = self.text_encoder(input_ids, params=inference_state.text_encoder_params)[0]
uncond_embeddings = self.text_encoder(uncond_input_ids, params=inference_state.text_encoder_params)[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
Expand All @@ -60,13 +57,14 @@ def sample(
)
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)

def loop_body(step, latents):
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)

t = jnp.array(self.scheduler.timesteps)[step]
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])

# predict the noise residual
Expand All @@ -78,15 +76,18 @@ def loop_body(step, latents):
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
return latents
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents)
latents = latents["prev_sample"]
return latents, scheduler_state

scheduler_state = inference_state.scheduler_state
num_inference_steps = len(scheduler_state.timesteps)
if debug:
# run with python for loop
for i in range(num_inference_steps):
latents = loop_body(i, latents)
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents = jax.lax.fori_loop(0, num_inference_steps, loop_body, latents)
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
95 changes: 77 additions & 18 deletions stable_diffusion_jax/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,84 @@ def step_plms(
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)

if self.counter != 1:
self.ets.append(model_output)
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
# Reference:
# if state.counter != 1:
# state.ets.append(model_output)
# else:
# prev_timestep = timestep
# timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps

prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
timestep = jnp.where(state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep)

# Reference:
# if len(state.ets) == 1 and state.counter == 0:
# model_output = model_output
# state.cur_sample = sample
# elif len(state.ets) == 1 and state.counter == 1:
# model_output = (model_output + state.ets[-1]) / 2
# sample = state.cur_sample
# state.cur_sample = None
# elif len(state.ets) == 2:
# model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
# elif len(state.ets) == 3:
# model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
# else:
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])

def counter_0(state: PNDMSchedulerState):
ets = state.ets.at[0].set(model_output)
return state.replace(
ets = ets,
sample = sample,
model_output = jnp.array(model_output, dtype=jnp.float32),
)

if len(self.ets) == 1 and self.counter == 0:
model_output = model_output
self.cur_sample = sample
elif len(self.ets) == 1 and self.counter == 1:
model_output = (model_output + self.ets[-1]) / 2
sample = self.cur_sample
self.cur_sample = None
elif len(self.ets) == 2:
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
elif len(self.ets) == 3:
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
else:
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
def counter_1(state: PNDMSchedulerState):
return state.replace(
model_output = (model_output + state.ets[0]) / 2,
)

def counter_2(state: PNDMSchedulerState):
ets = state.ets.at[1].set(model_output)
return state.replace(
ets = ets,
model_output = (3 * ets[1] - ets[0]) / 2,
sample = sample,
)

def counter_3(state: PNDMSchedulerState):
ets = state.ets.at[2].set(model_output)
return state.replace(
ets = ets,
model_output = (23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
sample = sample,
)

def counter_other(state: PNDMSchedulerState):
ets = state.ets.at[3].set(model_output)
next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])

ets = ets.at[0].set(ets[1])
ets = ets.at[1].set(ets[2])
ets = ets.at[2].set(ets[3])

return state.replace(
ets = ets,
model_output = next_model_output,
sample = sample,
)

counter = jnp.clip(state.counter, 0, 4)
state = jax.lax.switch(
counter,
[counter_0, counter_1, counter_2, counter_3, counter_other],
state,
)

sample = state.sample
model_output = state.model_output
prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)

prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still use the self.counter variable?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also we should return the state somewhere no?

Expand Down