Skip to content

Commit

Permalink
add primitive devices util
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Feb 14, 2025
1 parent 7edc4dd commit a2f145d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)
from .callbacks import detailed_loss_callback
from .comp_utils import expected_calibration_error
from .devices import devices
from .dict_utils import (
convert_args,
convert_kwargs,
Expand Down
22 changes: 22 additions & 0 deletions bayesflow/utils/devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import keras


def devices() -> list:
"""Returns a list of available GPU devices."""
match keras.backend.backend():
case "jax":
import jax

return jax.devices("gpu")
case "tensorflow":
import tensorflow as tf

return tf.config.list_physical_devices("GPU")
case "torch":
import torch

return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
case "numpy":
return []
case _:
raise NotImplementedError(f"Backend {keras.backend.backend()} not supported.")

0 comments on commit a2f145d

Please sign in to comment.