Skip to content

Commit 2b90b9f

Browse files
authored
remove .pin_memory() in obj_pos of SAM2Base to resolve and error in MPS (#495)
In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in #487. (close #487)
1 parent 722d1d1 commit 2b90b9f

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

sam2/modeling/sam2_base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -628,10 +628,8 @@ def _prepare_memory_conditioned_features(
628628
if self.add_tpos_enc_to_obj_ptrs:
629629
t_diff_max = max_obj_ptrs_in_encoder - 1
630630
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
631-
obj_pos = (
632-
torch.tensor(pos_list)
633-
.pin_memory()
634-
.to(device=device, non_blocking=True)
631+
obj_pos = torch.tensor(pos_list).to(
632+
device=device, non_blocking=True
635633
)
636634
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
637635
obj_pos = self.obj_ptr_tpos_proj(obj_pos)

0 commit comments

Comments
 (0)