Skip to content

Commit 4f9313d

Browse files
authored
initial flax pndm schedular (huggingface#492)
* initial flax pndm * fix typo * use state * return state * add FlaxSchedulerOutput * fix style * add flax imports * make style * fix typos * return created state * make style * add torch/flax imports * docs * fixed typo * remove tensor_format * round instead of cast * ets is jnp array * remove copy
1 parent 7415c5c commit 4f9313d

File tree

4 files changed

+453
-9
lines changed

4 files changed

+453
-9
lines changed

__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .utils import (
2+
is_flax_available,
23
is_inflect_available,
34
is_onnx_available,
45
is_scipy_available,
@@ -60,3 +61,8 @@
6061
from .pipelines import StableDiffusionOnnxPipeline
6162
else:
6263
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
64+
65+
if is_flax_available():
66+
from .schedulers import FlaxPNDMScheduler
67+
else:
68+
from .utils.dummy_flax_objects import * # noqa F403

schedulers/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,27 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ..utils import is_scipy_available
16-
from .scheduling_ddim import DDIMScheduler
17-
from .scheduling_ddpm import DDPMScheduler
18-
from .scheduling_karras_ve import KarrasVeScheduler
19-
from .scheduling_pndm import PNDMScheduler
20-
from .scheduling_sde_ve import ScoreSdeVeScheduler
21-
from .scheduling_sde_vp import ScoreSdeVpScheduler
22-
from .scheduling_utils import SchedulerMixin
2315

16+
from ..utils import is_flax_available, is_scipy_available, is_torch_available
17+
18+
19+
if is_torch_available():
20+
from .scheduling_ddim import DDIMScheduler
21+
from .scheduling_ddpm import DDPMScheduler
22+
from .scheduling_karras_ve import KarrasVeScheduler
23+
from .scheduling_pndm import PNDMScheduler
24+
from .scheduling_sde_ve import ScoreSdeVeScheduler
25+
from .scheduling_sde_vp import ScoreSdeVpScheduler
26+
from .scheduling_utils import SchedulerMixin
27+
else:
28+
from ..utils.dummy_pt_objects import * # noqa F403
29+
30+
if is_flax_available():
31+
from .scheduling_pndm_flax import FlaxPNDMScheduler
32+
else:
33+
from ..utils.dummy_flax_objects import * # noqa F403
2434

2535
if is_scipy_available():
2636
from .scheduling_lms_discrete import LMSDiscreteScheduler
2737
else:
28-
from ..utils.dummy_scipy_objects import * # noqa F403
38+
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403

0 commit comments

Comments
 (0)