Supporting Custom C++ Classes in torch.compile/torch.export
===========================================================


This tutorial is a follow-on to the
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
introduces the additional steps needed to support custom C++ classes in
``torch.compile``/``torch.export``.

.. warning::

   This feature is in prototype status and is subject to
   backward-compatibility-breaking changes. This tutorial provides a snapshot
   as of PyTorch 2.8. If you run into any issues, please reach out to us on
   GitHub!
Concretely, there are a few steps:

1. Implement an ``__obj_flatten__`` method on the C++ custom class to allow us
   to inspect its state and guard against changes. The method should return a
   tuple of (attribute_name, value) tuples
   (``tuple[tuple[str, value], ...]``).

2. Register a Python fake class using ``@torch._library.register_fake_class``.

   a. Implement "fake methods" for each of the class's C++ methods. Each fake
      method should have the same schema as its C++ counterpart.

   b. Additionally, implement an ``__obj_unflatten__`` classmethod in the
      Python fake class that tells us how to create a fake object from the
      flattened state returned by ``__obj_flatten__``.
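
The shape of the flattened state can be illustrated with a sketch, using plain
Python values in place of tensors (the attribute names ``queue`` and
``init_tensor_`` here are just illustrative):

```python
# Hypothetical flattened state for a class with two attributes.
# __obj_flatten__ returns a tuple of (attribute_name, value) tuples;
# plain Python floats and lists stand in for tensors here.
flattened = (("queue", [1.0, 2.0]), ("init_tensor_", -1.0))

# Each entry pairs an attribute name with its value:
for name, value in flattened:
    print(name, value)

# This shape converts naturally into a dict of attributes:
state = dict(flattened)
print(state["queue"])  # [1.0, 2.0]
```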
| 31 | + |
| 32 | +Here is a breakdown of the diff. Following the guide in |
| 33 | +:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`, |
| 34 | +we can create a thread-safe tensor queue and build it. |
| 35 | + |
| 36 | +.. code-block:: cpp |
| 37 | +
|
| 38 | + // Thread-safe Tensor Queue |
| 39 | +
|
| 40 | + #include <torch/custom_class.h> |
| 41 | + #include <torch/script.h> |
| 42 | +
|
| 43 | + #include <iostream> |
| 44 | + #include <string> |
| 45 | + #include <vector> |
| 46 | +
|
| 47 | + using namespace torch::jit; |
| 48 | +
|
| 49 | + // Thread-safe Tensor Queue |
| 50 | + struct TensorQueue : torch::CustomClassHolder { |
| 51 | + explicit TensorQueue(at::Tensor t) : init_tensor_(t) {} |
| 52 | +
|
| 53 | + explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) { |
| 54 | + init_tensor_ = dict.at(std::string("init_tensor")); |
| 55 | + const std::string key = "queue"; |
| 56 | + at::Tensor size_tensor; |
| 57 | + size_tensor = dict.at(std::string(key + "/size")).cpu(); |
| 58 | + const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>(); |
| 59 | + int64_t queue_size = size_tensor_acc[0]; |
| 60 | +
|
| 61 | + for (const auto index : c10::irange(queue_size)) { |
| 62 | + at::Tensor val; |
| 63 | + queue_[index] = dict.at(key + "/" + std::to_string(index)); |
| 64 | + queue_.push_back(val); |
| 65 | + } |
| 66 | + } |
| 67 | +
|
| 68 | + // Push the element to the rear of queue. |
| 69 | + // Lock is added for thread safe. |
| 70 | + void push(at::Tensor x) { |
| 71 | + std::lock_guard<std::mutex> guard(mutex_); |
| 72 | + queue_.push_back(x); |
| 73 | + } |
| 74 | + // Pop the front element of queue and return it. |
| 75 | + // If empty, return init_tensor_. |
| 76 | + // Lock is added for thread safe. |
| 77 | + at::Tensor pop() { |
| 78 | + std::lock_guard<std::mutex> guard(mutex_); |
| 79 | + if (!queue_.empty()) { |
| 80 | + auto val = queue_.front(); |
| 81 | + queue_.pop_front(); |
| 82 | + return val; |
| 83 | + } else { |
| 84 | + return init_tensor_; |
| 85 | + } |
| 86 | + } |
| 87 | +
|
| 88 | + std::vector<at::Tensor> get_raw_queue() { |
| 89 | + std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end()); |
| 90 | + return raw_queue; |
| 91 | + } |
| 92 | +
|
| 93 | + private: |
| 94 | + std::deque<at::Tensor> queue_; |
| 95 | + std::mutex mutex_; |
| 96 | + at::Tensor init_tensor_; |
| 97 | + }; |
| 98 | +
|
| 99 | + // The torch binding code |
| 100 | + TORCH_LIBRARY(MyCustomClass, m) { |
| 101 | + m.class_<TensorQueue>("TensorQueue") |
| 102 | + .def(torch::init<at::Tensor>()) |
| 103 | + .def("push", &TensorQueue::push) |
| 104 | + .def("pop", &TensorQueue::pop) |
| 105 | + .def("get_raw_queue", &TensorQueue::get_raw_queue); |
| 106 | + } |
| 107 | +
|
**Step 1**: Add an ``__obj_flatten__`` method to the C++ custom class implementation:

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
      std::tuple<std::tuple<std::string, std::vector<at::Tensor>>, std::tuple<std::string, at::Tensor>> __obj_flatten__() {
        return std::tuple(std::tuple("queue", this->get_raw_queue()), std::tuple("init_tensor_", this->init_tensor_.clone()));
      }
      ...
    };

    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          ...
          .def("__obj_flatten__", &TensorQueue::__obj_flatten__);
    }
**Step 2a**: Register a fake class in Python that implements each method:

.. code-block:: python

    from typing import List

    import torch

    # Registered under namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        def __init__(
            self,
            queue: List[torch.Tensor],
            init_tensor_: torch.Tensor
        ) -> None:
            self.queue = queue
            self.init_tensor_ = init_tensor_

        def push(self, tensor: torch.Tensor) -> None:
            self.queue.append(tensor)

        def pop(self) -> torch.Tensor:
            if len(self.queue) > 0:
                return self.queue.pop(0)
            return self.init_tensor_

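The fake methods should mirror the semantics of the C++ implementation. With
tensors stripped out, the queue logic can be sanity-checked in plain Python
(a sketch, not part of the tutorial's code):

```python
# Minimal stand-in for the fake class using plain floats instead of tensors,
# checking that pop() falls back to init_tensor_ once the queue is empty.
class QueueSketch:
    def __init__(self, queue, init_tensor_):
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, x):
        self.queue.append(x)

    def pop(self):
        # Matches the C++ pop(): front element if non-empty, else init_tensor_.
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

q = QueueSketch([], init_tensor_=-1.0)
q.push(1.0)
q.push(2.0)
print(q.pop())  # 1.0 (FIFO order)
print(q.pop())  # 2.0
print(q.pop())  # -1.0 (empty queue falls back to init_tensor_)
```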
**Step 2b**: Implement an ``__obj_unflatten__`` classmethod in the Python fake class:

.. code-block:: python

    # Registered under namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        ...
        @classmethod
        def __obj_unflatten__(cls, flattened_tq):
            return cls(**dict(flattened_tq))

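The flatten/unflatten contract can be exercised end-to-end with plain Python
values in place of tensors (a sketch; the class below mimics the fake class
above without the registration decorator):

```python
# Sketch of the unflatten path: flattened (name, value) pairs are turned
# into constructor keyword arguments via dict().
class FakeTensorQueueSketch:
    def __init__(self, queue, init_tensor_):
        self.queue = queue
        self.init_tensor_ = init_tensor_

    @classmethod
    def __obj_unflatten__(cls, flattened_tq):
        # flattened_tq has the shape produced by the C++ __obj_flatten__:
        # a tuple of (attribute_name, value) tuples.
        return cls(**dict(flattened_tq))

flattened = (("queue", [1.0, 2.0]), ("init_tensor_", -1.0))
fake = FakeTensorQueueSketch.__obj_unflatten__(flattened)
print(fake.queue)         # [1.0, 2.0]
print(fake.init_tensor_)  # -1.0
```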
That's it! Now we can create a module that uses this object and run it with
``torch.compile`` or ``torch.export``:

.. code-block:: python

    import torch

    torch.classes.load_library("build/libcustom_class.so")
    tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))

    class Mod(torch.nn.Module):
        def forward(self, tq, x):
            tq.push(x.sin())
            tq.push(x.cos())
            popped_t = tq.pop()
            assert torch.allclose(popped_t, x.sin())
            return tq, popped_t

    tq, popped_t = torch.compile(Mod(), backend="eager", fullgraph=True)(tq, torch.randn(2, 3))
    # One element (x.cos()) remains in the queue after push/push/pop.
    assert len(tq.get_raw_queue()) == 1

    exported_program = torch.export.export(Mod(), (tq, torch.randn(2, 3)), strict=False)
    exported_program.module()(tq, torch.randn(2, 3))
We can also implement custom ops that take custom classes as inputs. For
example, we could register a custom op ``for_each_add_(tq, tensor)``:

.. code-block:: cpp

    struct TensorQueue : torch::CustomClassHolder {
      ...
      void for_each_add_(at::Tensor inc) {
        for (auto& t : queue_) {
          t.add_(inc);
        }
      }
      ...
    };

    TORCH_LIBRARY_FRAGMENT(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          ...
          .def("for_each_add_", &TensorQueue::for_each_add_);

      m.def(
          "for_each_add_(__torch__.torch.classes.MyCustomClass.TensorQueue foo, Tensor inc) -> ()");
    }

    void for_each_add_(c10::intrusive_ptr<TensorQueue> tq, at::Tensor inc) {
      tq->for_each_add_(inc);
    }

    TORCH_LIBRARY_IMPL(MyCustomClass, CPU, m) {
      m.impl("for_each_add_", for_each_add_);
    }
Since the fake class is implemented in Python, the fake implementation of the
custom op must also be registered in Python:

.. code-block:: python

    @torch.library.register_fake("MyCustomClass::for_each_add_")
    def fake_for_each_add_(tq, inc):
        tq.for_each_add_(inc)
After re-compilation, we can export the custom op with:

.. code-block:: python

    class ForEachAdd(torch.nn.Module):
        def forward(self, tq: torch.ScriptObject, a: torch.Tensor) -> torch.ScriptObject:
            torch.ops.MyCustomClass.for_each_add_(tq, a)
            return tq

    mod = ForEachAdd()
    # empty_tensor_queue() is a helper (definition omitted) that constructs
    # an empty TensorQueue.
    tq = empty_tensor_queue()
    qlen = 10
    for i in range(qlen):
        tq.push(torch.zeros(1))

    ep = torch.export.export(mod, (tq, torch.ones(1)), strict=False)
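For reference, ``for_each_add_`` increments every queued tensor in place. With
Python lists standing in for mutable tensors, the intended effect on the
ten-element queue above can be sketched as:

```python
# Each queued element is incremented in place, mirroring t.add_(inc) in the
# C++ for_each_add_ above. Single-element lists stand in for mutable tensors.
queue = [[0.0] for _ in range(10)]  # ten zero "tensors", as in the export example
inc = 1.0
for t in queue:
    t[0] += inc  # in-place update, like at::Tensor::add_

print(queue[0])    # [1.0]
print(len(queue))  # 10 -- the queue length is unchanged
```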
Why do we need to make a Fake Class?
------------------------------------

Tracing with a real custom object has several major downsides:

1. Operators on real objects can be time-consuming, e.g., the custom object
   might be reading from the network or loading data from the disk.

2. We don't want to mutate the real custom object or create side effects in
   the environment while tracing.

3. It cannot support dynamic shapes.
However, it may be difficult for users to write a fake class, e.g. if the
original class uses some third-party library that determines the output shapes
of the methods, or is complicated and written by others. In such cases, users
can disable the fakification requirement by defining a ``tracing_mode`` method
that returns ``"real"``:

.. code-block:: cpp

    std::string tracing_mode() {
      return "real";
    }

A caveat of fakification concerns **tensor aliasing**: we assume that no
tensor within a torchbind object aliases a tensor outside of the torchbind
object. Therefore, mutating one of these tensors results in undefined
behavior.
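
In plain Python terms, the hazard is ordinary reference sharing: if an
external tensor is stored inside the object without a copy, mutating it
externally silently changes the object's state, which the tracing machinery
cannot observe. A sketch with lists standing in for tensors:

```python
# Reference sharing illustrated with Python lists standing in for tensors.
outer = [0.0]
queue = [outer]   # the "object" stores a reference, not a copy
outer[0] = 1.0    # external in-place mutation
print(queue[0])   # [1.0] -- the mutation is visible inside the "object"

# Storing a copy (cf. init_tensor_.clone() in __obj_flatten__) avoids this:
outer2 = [0.0]
queue2 = [list(outer2)]  # defensive copy
outer2[0] = 1.0
print(queue2[0])  # [0.0] -- the stored state is unaffected
```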