Skip to content

Commit

Permalink
Fix SD2 test (#647)
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss authored and schoi-habana committed Jan 22, 2024
1 parent 30351fa commit 03d7729
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/ci/slow_tests_diffusers.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

python -m pip install --upgrade pip
export RUN_SLOW=true
CUSTOM_BF16_OPS=1 python -m pytest tests/test_diffusers.py -v -s -k "test_no_throughput_regression_autocast"
make slow_tests_diffusers
19 changes: 17 additions & 2 deletions tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile
from io import BytesIO
from pathlib import Path
from unittest import TestCase
from unittest import TestCase, skipUnless

import numpy as np
import requests
Expand All @@ -31,7 +31,7 @@
from parameterized import parameterized
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.testing_utils import slow
from transformers.testing_utils import parse_flag_from_env, slow

from optimum.habana import GaudiConfig
from optimum.habana.diffusers import (
Expand All @@ -58,6 +58,20 @@
TEXTUAL_INVERSION_RUNTIME = 206.32180358597543


_run_custom_bf16_ops_test_ = parse_flag_from_env("CUSTOM_BF16_OPS", default=False)


def custom_bf16_ops(test_case):
"""
Decorator marking a test as needing custom bf16 ops.
Custom bf16 ops must be declared before `habana_frameworks.torch.core` is imported, which is not possible if some other tests are executed before.
Such tests are skipped by default. Set the CUSTOM_BF16_OPS environment variable to a truthy value to run them.
"""
return skipUnless(_run_custom_bf16_ops_test_, "test requires custom bf16 ops")(test_case)


class GaudiPipelineUtilsTester(TestCase):
"""
Tests the features added on top of diffusers/pipeline_utils.py.
Expand Down Expand Up @@ -550,6 +564,7 @@ def test_no_throughput_regression_bf16(self):
self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))
self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_BASELINE_BF16)

@custom_bf16_ops
@slow
def test_no_throughput_regression_autocast(self):
prompts = [
Expand Down

0 comments on commit 03d7729

Please sign in to comment.