Update for SDXL Turbo support (#634)
atakaha authored Jan 15, 2024
1 parent 7380e2b commit 4bcdef8
Showing 3 changed files with 39 additions and 7 deletions.
7 changes: 5 additions & 2 deletions examples/stable-diffusion/README.md
@@ -131,7 +131,7 @@ python text_to_image_generation.py \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
--bf16
```

> HPU graphs are recommended when generating images in batches to get the fastest possible generations.
@@ -192,7 +192,10 @@ python text_to_image_generation.py \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
--bf16 \
--num_inference_steps 1 \
--guidance_scale 0.0 \
--timestep_spacing trailing
```

> HPU graphs are recommended when generating images in batches to get the fastest possible generations.
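
The three new flags mirror how SDXL Turbo is intended to be run: a single denoising step, classifier-free guidance disabled (`--guidance_scale 0.0`), and `trailing` timestep spacing so that the lone step starts from the noisiest end of the training schedule. The sketch below illustrates how the three spacing choices typically map inference steps onto a 1000-step training schedule; it follows the convention used by diffusers-style schedulers and is an illustration only, not the Gaudi scheduler's actual code.

```python
import numpy as np

def sample_timesteps(num_inference_steps, num_train_timesteps=1000,
                     spacing="leading", steps_offset=1):
    """Illustrative diffusers-style mapping of inference steps onto the
    training schedule (not the exact Gaudi scheduler implementation)."""
    if spacing == "linspace":
        # Spread the steps evenly over the whole schedule, then reverse.
        return np.linspace(0, num_train_timesteps - 1, num_inference_steps)[::-1].round()
    if spacing == "leading":
        # Count up from 0 in equal strides, then add the configured offset.
        ratio = num_train_timesteps // num_inference_steps
        return (np.arange(num_inference_steps) * ratio)[::-1] + steps_offset
    if spacing == "trailing":
        # Count down from the final (noisiest) timestep in equal strides.
        ratio = num_train_timesteps / num_inference_steps
        return np.round(np.arange(num_train_timesteps, 0, -ratio)) - 1
    raise ValueError(f"unknown timestep spacing: {spacing}")

# With a single step, only "trailing" starts at the end of the schedule:
print(sample_timesteps(1, spacing="linspace"))  # [0.]
print(sample_timesteps(1, spacing="leading"))   # [1]
print(sample_timesteps(1, spacing="trailing"))  # [999.]
```

With `num_inference_steps=1`, only `trailing` selects timestep 999; `leading` and `linspace` would start near timestep 0, which is why the README pins `--timestep_spacing trailing` for Turbo.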
16 changes: 13 additions & 3 deletions examples/stable-diffusion/text_to_image_generation.py
100644 → 100755
@@ -61,6 +61,13 @@ def main():
help="Name of scheduler",
)

parser.add_argument(
"--timestep_spacing",
default="linspace",
choices=["linspace", "leading", "trailing"],
type=str,
help="The way the timesteps should be scaled.",
)
# Pipeline arguments
parser.add_argument(
"--prompts",
@@ -207,14 +214,17 @@ def main():
logger.setLevel(logging.INFO)

# Initialize the scheduler and the generation pipeline
kwargs = {"timestep_spacing": args.timestep_spacing}
if args.scheduler == "euler_discrete":
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(
args.model_name_or_path, subfolder="scheduler", **kwargs
)
elif args.scheduler == "euler_ancestral_discrete":
scheduler = GaudiEulerAncestralDiscreteScheduler.from_pretrained(
args.model_name_or_path, subfolder="scheduler"
args.model_name_or_path, subfolder="scheduler", **kwargs
)
else:
scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler", **kwargs)

kwargs = {
"scheduler": scheduler,
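
With this change the script forwards `--timestep_spacing` to whichever scheduler the user selects instead of relying on the scheduler's saved default. For readers who prefer to call the pipeline directly, a minimal sketch of the equivalent programmatic setup is given below; the `optimum.habana.diffusers` import path, the `stabilityai/sdxl-turbo` checkpoint name, and the prompt are assumptions for illustration, not part of this diff.

```python
# A minimal sketch of the SDXL Turbo setup the example script builds from its
# arguments. Import path and checkpoint name are assumptions, not taken from
# this commit.
from optimum.habana.diffusers import (
    GaudiEulerAncestralDiscreteScheduler,
    GaudiStableDiffusionXLPipeline,
)

model_name = "stabilityai/sdxl-turbo"  # assumed SDXL Turbo checkpoint

# timestep_spacing is forwarded to the scheduler, as in the script above.
scheduler = GaudiEulerAncestralDiscreteScheduler.from_pretrained(
    model_name, subfolder="scheduler", timestep_spacing="trailing"
)

pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
    model_name,
    scheduler=scheduler,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)

# One denoising step, no classifier-free guidance -- the SDXL Turbo recipe.
images = pipeline(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=1,
    guidance_scale=0.0,
).images
```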
23 changes: 21 additions & 2 deletions tests/test_diffusers.py
@@ -818,7 +818,7 @@ class GaudiStableDiffusionXLPipelineTester(TestCase):
Tests the StableDiffusionXLPipeline for Gaudi.
"""

def get_dummy_components(self, time_cond_proj_dim=None):
def get_dummy_components(self, time_cond_proj_dim=None, timestep_spacing="leading"):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(2, 4),
@@ -844,7 +844,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
timestep_spacing=timestep_spacing,
)
torch.manual_seed(0)
vae = AutoencoderKL(
@@ -933,6 +933,25 @@ def test_stable_diffusion_xl_euler_ancestral(self):
expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791])
self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)

def test_stable_diffusion_xl_turbo_euler_ancestral(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(timestep_spacing="trailing")
gaudi_config = GaudiConfig(use_torch_autocast=False)

sd_pipe = GaudiStableDiffusionXLPipeline(use_habana=True, gaudi_config=gaudi_config, **components)
sd_pipe.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)

sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images[0]

image_slice = image[-3:, -3:, -1]

self.assertEqual(image.shape, (64, 64, 3))
expected_slice = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791])
self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)

@parameterized.expand(["pil", "np", "latent"])
def test_stable_diffusion_xl_output_types(self, output_type):
components = self.get_dummy_components()
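
The new `test_stable_diffusion_xl_turbo_euler_ancestral` test builds its dummy scheduler config with `timestep_spacing="trailing"` and then rebuilds the scheduler through `from_config`, which carries that setting over. Outside of the test suite, the same swap can be applied to an already loaded pipeline; the override form below follows the usual diffusers `ConfigMixin.from_config` behaviour and is a sketch, with `pipe` assumed to be a `GaudiStableDiffusionXLPipeline` loaded as in the earlier example.

```python
from optimum.habana.diffusers import GaudiEulerAncestralDiscreteScheduler

# `pipe` is assumed to be a GaudiStableDiffusionXLPipeline loaded as in the
# earlier sketch. from_config rebuilds the scheduler from its config dict and,
# per the usual diffusers ConfigMixin behaviour, accepts keyword overrides.
pipe.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
```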
