Skip to content

Commit 9465f38

Browse files
committed
These lines overwrite the device argument of the function, rendering it useless.
1 parent aec40d8 commit 9465f38

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

HIPT_4K/hipt_4k.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def __init__(self,
3131
device4k=torch.device('cuda:1')):
3232

3333
super().__init__()
34-
self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
35-
self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
34+
self.model256 = get_vit256(pretrained_weights=model256_path, device=device256).to(device256)
35+
self.model4k = get_vit4k(pretrained_weights=model4k_path, device=device4k).to(device4k)
3636
self.device256 = device256
3737
self.device4k = device4k
3838

HIPT_4K/hipt_model_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cuda:0
3232
"""
3333

3434
checkpoint_key = 'teacher'
35-
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
3635
model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
3736
for p in model256.parameters():
3837
p.requires_grad = False
@@ -68,7 +67,6 @@ def get_vit4k(pretrained_weights, arch='vit4k_xs', device=torch.device('cuda:1')
6867
"""
6968

7069
checkpoint_key = 'teacher'
71-
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
7270
model4k = vits4k.__dict__[arch](num_classes=0)
7371
for p in model4k.parameters():
7472
p.requires_grad = False

0 commit comments

Comments
 (0)