Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit b4b7144

Browse files
ezyangfacebook-github-bot
authored andcommitted
Tracked new device argument in is_pinned/pin_memory for nested tensor
Reviewed By: cpuhrsch Differential Revision: D29467785 fbshipit-source-id: 966c05ef16b7955afee982c168002c2ab7ed644a
1 parent 3afc66d commit b4b7144

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

nestedtensor/csrc/functions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ Tensor NestedTensor__log_softmax(
159159
[&](Tensor a) { return at::_log_softmax(a, dim_, half_to_float); }, self);
160160
}
161161

162-
Tensor NestedTensor_pin_memory(const Tensor& self) {
162+
Tensor NestedTensor_pin_memory(const Tensor& self, c10::optional<Device> device) {
163163
return map_nested_tensor(
164-
[](Tensor tensor) { return at::native::pin_memory(tensor); }, self);
164+
[&](Tensor tensor) { return at::native::pin_memory(tensor, device); }, self);
165165
}
166166

167167
Tensor NestedTensor_flatten(

nestedtensor/csrc/nested_tensor_impl.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ Tensor NestedTensor_contiguous(const Tensor& self, MemoryFormat memory_format) {
135135
std::shared_ptr<NestedTensorStorage>(ps_base));
136136
}
137137

138-
bool NestedTensor_is_pinned(const Tensor& self) {
138+
bool NestedTensor_is_pinned(const Tensor& self, c10::optional<Device> device) {
139+
TORCH_CHECK(
140+
!device.has_value() || device->is_cuda(),
141+
"nested tensor doesn't support non-CUDA pinned memory");
139142
return get_nested_tensor_impl(self)->is_pinned();
140143
}
141144

0 commit comments

Comments
 (0)