File tree 3 files changed +22
-13
lines changed
3 files changed +22
-13
lines changed Original file line number Diff line number Diff line change @@ -37,9 +37,11 @@ def _predict(
37
37
device = torch .device (
38
38
"cuda"
39
39
if torch .cuda .is_available ()
40
- else "mps"
41
- if torch .backends .mps .is_available () and parameters .get ("use_mps" , False )
42
- else "cpu"
40
+ else (
41
+ "mps"
42
+ if torch .backends .mps .is_available () and parameters .get ("use_mps" , False )
43
+ else "cpu"
44
+ )
43
45
)
44
46
45
47
with torch .no_grad ():
Original file line number Diff line number Diff line change @@ -362,9 +362,12 @@ def fit(
362
362
torch .device (
363
363
"cuda"
364
364
if torch .cuda .is_available ()
365
- else "mps"
366
- if torch .backends .mps .is_available () and parameters .get ("use_mps" , False )
367
- else "cpu"
365
+ else (
366
+ "mps"
367
+ if torch .backends .mps .is_available ()
368
+ and parameters .get ("use_mps" , False )
369
+ else "cpu"
370
+ )
368
371
)
369
372
)
370
373
if (
@@ -437,10 +440,12 @@ def fit(
437
440
map_location = torch .device (
438
441
"cuda"
439
442
if torch .cuda .is_available ()
440
- else "mps"
441
- if torch .backends .mps .is_available ()
442
- and parameters .get ("use_mps" , False )
443
- else "cpu"
443
+ else (
444
+ "mps"
445
+ if torch .backends .mps .is_available ()
446
+ and parameters .get ("use_mps" , False )
447
+ else "cpu"
448
+ )
444
449
),
445
450
)
446
451
Checkpoint .load_objects (
Original file line number Diff line number Diff line change @@ -58,9 +58,11 @@ def model_input(
58
58
device = torch .device (
59
59
"cuda"
60
60
if torch .cuda .is_available ()
61
- else "mps"
62
- if torch .backends .mps .is_available () and parameters .get ("use_mps" , False )
63
- else "cpu"
61
+ else (
62
+ "mps"
63
+ if torch .backends .mps .is_available () and parameters .get ("use_mps" , False )
64
+ else "cpu"
65
+ )
64
66
)
65
67
cat = []
66
68
cont = []
You can’t perform that action at this time.
0 commit comments