Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,62 @@
InferenceState,
PNDMScheduler,
StableDiffusionPipeline,
UNet2D
UNet2D,
StableDiffusionSafetyCheckerModel,
)
from stable_diffusion_jax.convert_diffusers_to_jax import convert_diffusers_to_jax


# convert diffusers checkpoint to jax
pt_path = "path_to_diffusers_pt_ckpt"
fx_path = "save_path"
convert_diffusers_to_jax(pt_path, fx_path)
# Local checkout until weights are available in the Hub
flax_path = "/sddata/sd-v1-4-flax"

num_samples = 8
num_inference_steps = 50
guidance_scale = 7.5

devices = jax.devices()[:1]

# inference with jax
dtype = jnp.bfloat16
clip_model, clip_params = FlaxCLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", _do_init=False, dtype=dtype
)
unet, unet_params = UNet2D.from_pretrained(f"{fx_path}/unet", _do_init=False, dtype=dtype)
vae, vae_params = AutoencoderKL.from_pretrained(f"{fx_path}/vae", _do_init=False, dtype=dtype)
safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{fx_path}/safety_model", _do_init=False, dtype=dtype)
unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", _do_init=False, dtype=dtype)
vae, vae_params = AutoencoderKL.from_pretrained(f"{flax_path}/vae", _do_init=False, dtype=dtype)
safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{flax_path}/safety_checker", _do_init=False, dtype=dtype)

config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
scheduler = PNDMScheduler()

latents_shape = (
num_samples,
unet.config.sample_size,
unet.config.sample_size,
unet.config.in_channels,
)

scheduler = PNDMScheduler.from_config(f"{flax_path}/scheduler")
scheduler_state = scheduler.set_timesteps(
scheduler.state,
latents_shape,
num_inference_steps = num_inference_steps,
offset = 1,
)

# create inference state and replicate it across all TPU devices
inference_state = InferenceState(text_encoder_params=clip_params, unet_params=unet_params, vae_params=vae_params)
inference_state = replicate(inference_state)
inference_state = InferenceState(
text_encoder_params=clip_params,
unet_params=unet_params,
vae_params=vae_params,
scheduler_state=scheduler_state,
)
inference_state = replicate(inference_state, devices=devices)


# create pipeline
pipe = StableDiffusionPipeline(text_encoder=clip_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vae=vae)



# prepare inputs
num_samples = 8
p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"

input_ids = tokenizer(
Expand All @@ -59,23 +78,20 @@
prng_seed = jax.random.PRNGKey(42)

# shard inputs and rng
input_ids = shard(input_ids)
uncond_input_ids = shard(uncond_input_ids)
prng_seed = jax.random.split(prng_seed, 8)
# Simply use shard if using the default devices
input_ids = jax.device_put_sharded([input_ids], devices)
uncond_input_ids = jax.device_put_sharded([uncond_input_ids], devices)
prng_seed = jax.random.split(prng_seed, len(devices))

# pmap the sample function
num_inference_steps = 50
guidance_scale = 1.0

sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4, 5))
sample = jax.pmap(pipe.sample, static_broadcasted_argnums=(4,))

# sample images
images = sample(
input_ids,
uncond_input_ids,
prng_seed,
inference_state,
num_inference_steps,
guidance_scale,
)

Expand All @@ -87,3 +103,4 @@
images = np.asarray(images).reshape((num_samples, 512, 512, 3))

pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save("example.png")
33 changes: 33 additions & 0 deletions jax-vs-torch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Loop test

Procedure:
- Run `torch-micro-loop.py` in a Linux box for reference. It will save the initial state (latents and text embeddings), and the final latents.
- Run `jax-micro-loop.py` in a TPU. It will load the initial state, run the loop and compare the results.

Notes:
- "cpu" and "tpu" devices in a TPU give the same results unless I'm doing something wrong.
- I fixed a small difference in the scheduler (`offset=1`), but its impact is low. I think it looks correct: I ran it with some fixed inputs and the results were the same. I'll repeat with a more complex text.
- Running a full `for` loop instead of a `fori_loop` produces double the difference. I don't understand the reason why.
- In another test, I looped without involving the scheduler (the output from the unet was the input for the next cycle) and the differences were much smaller. They did accumulate as iterations grow (I saved intermediate results every 10 steps).

Summary of differences:

* Torch cuda vs JAX TPU
```
Max: 3.999680519104004
Mean: 0.32840099930763245
```

* Torch CPU (in TPU device) vs JAX TPU
```
Max: 2.8781819343566895
Mean: 0.3067624568939209
```

* Torch CPU (Linux) vs JAX TPU
```
Max: 2.8606467247009277
Mean: 0.30655285716056824
```

If I hardcode step 20 in the scheduler call, the max difference goes down to `~1`.
153 changes: 153 additions & 0 deletions jax-vs-torch/jax-micro-loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import jax
import jax.numpy as jnp
import os

# Backend name; passed to jax.devices(device) further down and used in the
# name of the saved image.
device = "cuda"
# True: run the pmap'd function with jax.lax.fori_loop.
# False: run a plain Python loop that prints an intermediate value every step.
run_parallel = True  # Set to False for interactive / debugging
# dtype used when loading the UNet and VAE weights.
dtype = jnp.float32
enable_x64_timesteps = True

if enable_x64_timesteps:
    # Experimental: enable int64 for timestep.
    # This will perform some computations in float64, then we'll truncate to float32.
    jax.config.update("jax_enable_x64", True)

if device == "cpu":
    # Make sure we really use the CPU
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    jax.config.update('jax_platform_name', 'cpu')

# Report what JAX sees; this is informational only and does not select a device.
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
# assert device_type.startswith("TPU"), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

from flax.jax_utils import replicate
from flax.training.common_utils import shard

import numpy as np

from PIL import Image
from stable_diffusion_jax.scheduling_pndm import PNDMSchedulerState
from stable_diffusion_jax import (
AutoencoderKL,
PNDMScheduler,
UNet2D
)

from tqdm import tqdm
import pickle

# Local checkout
# Directory holding the converted Flax weights (unet/, vae/, scheduler/ subfolders are read below).
flax_path = "/sddata/sd-v1-4-flax"
# Directory holding the pickled reference tensors read by this script.
tensors_path = "/sddata/sd-tests/tensors"

# _do_init=False returns the module together with its parameters as a separate value.
unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", _do_init=False, dtype=dtype)
scheduler = PNDMScheduler.from_config(f"{flax_path}/scheduler")
# Copy of the scheduler's state dict taken right after loading; every run below
# starts from a fresh copy of it.
initial_state = scheduler.state.state_dict.copy()

# Using jax.debug.print() makes it crash :()
def mini_sample(
    text_embeddings: jnp.ndarray,
    latents: jnp.ndarray,
    unet_params,
    scheduler_state_dict: dict,
    num_inference_steps: int = 50,
    break_after: int = 51,
):
    """Run the UNet + scheduler denoising loop and return the resulting latents.

    Args:
        text_embeddings: passed to the UNet as ``encoder_hidden_states``.
        latents: starting latents; their shape is also given to
            ``scheduler.set_timesteps``.
        unet_params: parameters handed to the UNet on every call.
        scheduler_state_dict: state dict used to rebuild a ``PNDMSchedulerState``.
        num_inference_steps: number of inference steps configured on the scheduler.
        break_after: upper bound on the number of loop iterations.

    Returns:
        A ``(latents, scheduler_state_dict)`` tuple holding the values after the
        last executed iteration.
    """
    state = scheduler.set_timesteps(
        PNDMSchedulerState.from_state_dict(scheduler_state_dict),
        latents.shape,
        num_inference_steps,
        offset=1,
    )
    scheduler_state_dict = state.state_dict

    def loop_body(step, args):
        sample, state_dict = args
        # Timesteps are read as int64; the module enables x64 for this experiment.
        t = jnp.array(state_dict["timesteps"], dtype=jnp.int64)[step]
        timestep = jnp.broadcast_to(t, sample.shape[0])

        # predict the noise residual
        noise_pred = unet(
            sample,
            timestep,
            encoder_hidden_states=text_embeddings,
            params=unet_params,
        )

        # compute the previous noisy sample x_t -> x_t-1
        step_output, state_dict = scheduler.step(state_dict, noise_pred, t, sample)
        # Truncate back to float32 after any float64 intermediate computation.
        sample = jnp.array(step_output["prev_sample"], dtype=jnp.float32)
        return sample, state_dict

    n = min(len(state.timesteps), break_after)

    if not run_parallel:
        # Plain Python loop: slower, but lets us print an intermediate value each step.
        for step in range(n):
            latents, scheduler_state_dict = loop_body(step, (latents, scheduler_state_dict))
            print(f"{step}: {latents[0, 0, 0, 1]}")
        return latents, scheduler_state_dict

    latents, scheduler_state_dict = jax.lax.fori_loop(
        0, n, loop_body, (latents, scheduler_state_dict)
    )
    return latents, scheduler_state_dict

# Parallel version of mini_sample.
# in_axes: text_embeddings, latents and unet_params are mapped over their leading
# axis; scheduler_state_dict (None) is broadcast to every device.
# static_broadcasted_argnums: num_inference_steps and break_after are static.
p_sample = jax.pmap(mini_sample, in_axes=(0, 0, 0, None), static_broadcasted_argnums=(4,5))

# Run tests on a single device
devices = jax.devices(device)[:1]

def read_latents():
    """Load the pickled initial latents and return them with axes permuted to (0, 2, 3, 1).

    The file is read from ``tensors_path``. The transpose presumably converts a
    channels-first tensor to channels-last -- confirm against the script that
    wrote the file.
    """
    with open(f'{tensors_path}/latents_7667_cuda', 'rb') as f:
        raw = pickle.load(f)

    return jnp.transpose(jnp.array(raw), (0, 2, 3, 1))

# Initial latents saved by the reference run.
latents = read_latents()

# Text embeddings saved by the reference run.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(f'{tensors_path}/embeddings_7667', 'rb') as f:
    text_embeddings = pickle.load(f)
text_embeddings = jnp.array(text_embeddings)

if run_parallel:
    # Each input is placed on the selected device(s) with a leading device axis,
    # which is what the pmap'd function maps over.
    latents = jax.device_put_sharded([latents], devices)
    unet_params = replicate(unet_params, devices)
    text_embeddings = jax.device_put_sharded([text_embeddings], devices)

num_inference_steps = 50
# Use the pmap'd function when running sharded, the plain Python function otherwise.
sample_fn = p_sample if run_parallel else mini_sample
# Each entry is a step index; `step + 1` is passed as `break_after`, so the loop
# can be stopped early by listing smaller values here.
for step in [50]:
    # Start every run from the saved initial latents.
    latents = read_latents()
    if run_parallel:
        latents = jax.device_put_sharded([latents], devices)

    # Fresh copy so runs do not share scheduler state.
    scheduler_state_dict = initial_state.copy()
    latents, _ = sample_fn(
        text_embeddings,
        latents,
        unet_params,
        scheduler_state_dict,
        num_inference_steps,
        step + 1,
    )

    # NOTE(review): the leading axis is dropped in both modes; without pmap this
    # presumably removes the batch axis rather than a device axis -- confirm.
    latents = latents[0]  # unshard
    # Named `latents_slice` rather than `slice` so the builtin is not shadowed.
    latents_slice = latents[0, 0, 0, :] if run_parallel else latents[0, 0, :]
    print(f"Step: {step}: {latents_slice}")


# scheduler_state_dict = initial_state.copy()
# latents, _ = p_sample(text_embeddings, latents, unet_params, scheduler_state_dict, num_inference_steps, 51)

# Rescale the latents before decoding.
# NOTE(review): 0.18215 is presumably the VAE latent scaling factor -- confirm.
latents = 1 / 0.18215 * latents

vae, vae_params = AutoencoderKL.from_pretrained(f"{flax_path}/vae", _do_init=False, dtype=dtype)
images = vae.decode(latents, params=vae_params)

# convert images to PIL images
# Shift/scale then clamp to [0, 1] (assumes the decoder output is roughly in
# [-1, 1] -- TODO confirm), then quantize to 8-bit.
images = images / 2 + 0.5
images = jnp.clip(images, 0, 1)
images = (images * 255).round().astype("uint8")
# The reshape hard-codes a single 512x512 image with 3 channels.
images = np.asarray(images).reshape((1, 512, 512, 3))

pil_images = [Image.fromarray(image) for image in images]
pil_images[0].save(f"jax_{device}_fixes.png")


# Load the final latents saved by the reference run and permute them to the same
# axis order used for the JAX latents.
with open(f"{tensors_path}/latents_7667_cuda_final", "rb") as f:
    reference = pickle.load(f)
torch_latents = jnp.transpose(jnp.array(reference), (0, 2, 3, 1))

assert torch_latents.shape == latents.shape, "Wrong shapes"

# Element-wise absolute difference, computed once and reduced three ways.
abs_diff = jnp.abs(latents - torch_latents)
print(f"Sum: {jnp.sum(abs_diff)}")
print(f"Max: {jnp.max(abs_diff)}")
print(f"Mean: {jnp.mean(abs_diff)}")
84 changes: 84 additions & 0 deletions jax-vs-torch/jax-micro-scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import jax
import torch

from stable_diffusion_jax.scheduling_pndm import PNDMSchedulerState
# Report what JAX sees; this is informational only and does not select a device.
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
# assert device_type.startswith("TPU"), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

# Run tests on a single device of the specified family
device = "cuda"
# Only the first device of that backend is kept.
devices = jax.devices(device)[:1]
# devices = jax.devices()[:1]
print(f"Running tests on {devices}")

import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from stable_diffusion_jax import PNDMScheduler

# Local checkout
# Directory holding the converted Flax weights; only the scheduler/ subfolder is read here.
flax_path = "/sddata/sd-v1-4-flax"

# NOTE(review): dtype is not referenced by the rest of the visible script.
dtype = jnp.float32
scheduler = PNDMScheduler.from_config(f"{flax_path}/scheduler")
# Copy of the scheduler's state dict taken right after loading; the run below
# starts from a fresh copy of it.
initial_state = scheduler.state.state_dict.copy()

# Using jax.debug.print() makes it crash :()
def mini_sample(
    latents: jnp.ndarray,
    # noise: jnp.ndarray,
    scheduler_state_dict: dict,
    num_inference_steps: int = 50,
    break_after: int = 50,
):
    """Step the scheduler alone, feeding it seeded random noise instead of UNet output.

    Args:
        latents: starting latents; their shape is also given to
            ``scheduler.set_timesteps``.
        scheduler_state_dict: state dict used to rebuild a ``PNDMSchedulerState``.
        num_inference_steps: number of inference steps configured on the scheduler.
        break_after: upper bound on the number of loop iterations.

    Returns:
        A ``(latents, scheduler_state_dict)`` tuple holding the values after the
        last executed iteration.
    """
    state = scheduler.set_timesteps(
        PNDMSchedulerState.from_state_dict(scheduler_state_dict),
        latents.shape,
        num_inference_steps,
        offset=1,
    )
    scheduler_state_dict = state.state_dict

    # Prepare noise samples
    # A fixed torch seed makes the noise sequence identical on every run.
    torch.manual_seed(42)
    noise = [torch.randn((1, 4)) for _ in range(num_inference_steps + 1)]

    def loop_body(step, args):
        sample, state_dict = args
        t = jnp.array(state_dict["timesteps"])[step]

        # Move this step's torch noise onto the selected JAX device(s).
        noise_sample = jax.device_put_sharded(
            [jnp.array(noise[step].numpy())], devices
        )

        # compute the previous noisy sample x_t -> x_t-1
        step_output, state_dict = scheduler.step(state_dict, noise_sample, t, sample)
        # `t` is returned as well so the caller can print it.
        return step_output["prev_sample"], state_dict, t

    n = min(len(state.timesteps), break_after)

    # A Python loop is used here (not jax.lax.fori_loop) so every step can be printed.
    for step in range(n):
        latents, scheduler_state_dict, t = loop_body(step, (latents, scheduler_state_dict))
        print(f"{step} [{t}]: {latents}")

    return latents, scheduler_state_dict

# Parallel version of mini_sample (not called below; the function is run directly).
# in_axes: latents are mapped over their leading axis; scheduler_state_dict is
# broadcast (None). The previous value (0, 0, None) still had an entry for the
# commented-out `noise` parameter, which would have mapped the state dict over axis 0.
# static_broadcasted_argnums: num_inference_steps and break_after are static.
p_sample = jax.pmap(mini_sample, in_axes=(0, None), static_broadcasted_argnums=(2, 3))

num_inference_steps = 50

# Fixed starting latents so the run is reproducible.
latents = jnp.array([[0.34100962, -1.0947237, -1.778018, 0.43691084]])
latents = jax.device_put_sharded([latents], devices)
# Fresh copy so the loaded scheduler state is not shared with this run.
scheduler_state_dict = initial_state.copy()
latents, _ = mini_sample(latents, scheduler_state_dict, num_inference_steps, num_inference_steps + 1)


# for step in range(2):
# latents = jnp.array([[0.34100962, -1.0947237, -1.778018, 0.43691084]])
# latents = jax.device_put_sharded([latents], devices)

# scheduler_state_dict = initial_state.copy()
# latents, _ = mini_sample(latents, scheduler_state_dict, num_inference_steps, step+1)
# print(f"Step: {step}: {latents}")
Binary file added jax-vs-torch/tensors/embeddings_7667
Binary file not shown.
Binary file added jax-vs-torch/tensors/latents_7667
Binary file not shown.
Binary file added jax-vs-torch/tensors/latents_7667_final
Binary file not shown.
Loading