Skip to content

Commit 2a321f4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d540917 commit 2a321f4

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

src/autoembedder/evaluator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ def _predict(
3737
device = torch.device(
3838
"cuda"
3939
if torch.cuda.is_available()
40-
else "mps"
41-
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
42-
else "cpu"
40+
else (
41+
"mps"
42+
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
43+
else "cpu"
44+
)
4345
)
4446

4547
with torch.no_grad():

src/autoembedder/learner.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,12 @@ def fit(
362362
torch.device(
363363
"cuda"
364364
if torch.cuda.is_available()
365-
else "mps"
366-
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
367-
else "cpu"
365+
else (
366+
"mps"
367+
if torch.backends.mps.is_available()
368+
and parameters.get("use_mps", False)
369+
else "cpu"
370+
)
368371
)
369372
)
370373
if (
@@ -437,10 +440,12 @@ def fit(
437440
map_location=torch.device(
438441
"cuda"
439442
if torch.cuda.is_available()
440-
else "mps"
441-
if torch.backends.mps.is_available()
442-
and parameters.get("use_mps", False)
443-
else "cpu"
443+
else (
444+
"mps"
445+
if torch.backends.mps.is_available()
446+
and parameters.get("use_mps", False)
447+
else "cpu"
448+
)
444449
),
445450
)
446451
Checkpoint.load_objects(

src/autoembedder/model.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def model_input(
5858
device = torch.device(
5959
"cuda"
6060
if torch.cuda.is_available()
61-
else "mps"
62-
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
63-
else "cpu"
61+
else (
62+
"mps"
63+
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
64+
else "cpu"
65+
)
6466
)
6567
cat = []
6668
cont = []

0 commit comments

Comments
 (0)