Commit a198274 ("Add custom class pt2 tutorial", 1 parent a96b470)

1 file changed: advanced_source/custom_class_pt2.rst (+275, -0 lines)
Supporting Custom C++ Classes in torch.compile/torch.export
===========================================================

This tutorial is a follow-on to the
:doc:`custom C++ classes <torch_script_custom_classes>` tutorial, and
introduces the additional steps needed to support custom C++ classes in
torch.compile/torch.export.
.. warning::

    This feature is in prototype status and is subject to
    backward-incompatible changes. This tutorial provides a snapshot as of
    PyTorch 2.8. If you run into any issues, please reach out to us on GitHub!
Concretely, there are a few steps:

1. Implement an ``__obj_flatten__`` method on the C++ custom class to allow us
   to inspect its state and guard against changes to it. The method should
   return a tuple of (attribute_name, value) tuples
   (``tuple[tuple[str, value] * n]``).

2. Register a Python fake class using ``@torch._library.register_fake_class``:

   a. Implement a "fake method" for each of the class's C++ methods, with the
      same schema as the C++ implementation.

   b. Additionally, implement an ``__obj_unflatten__`` classmethod in the
      Python fake class to tell us how to create a fake object from the
      flattened state returned by ``__obj_flatten__``.
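Abstracted away from torch, the flatten/unflatten contract above can be sketched in plain Python. This is an illustration only: the ``Toy``-prefixed names are invented for this sketch and are not part of the PyTorch API, and ints stand in for tensors.

```python
# A plain-Python sketch of the flatten/unflatten contract described above.
# ToyQueue / FakeToyQueue are illustrative names, not real PyTorch classes.

class ToyQueue:
    """Stands in for the real C++ class."""
    def __init__(self, queue, init_value):
        self.queue = queue
        self.init_value = init_value

    def __obj_flatten__(self):
        # tuple[tuple[str, value] * n]: one (attribute_name, value) pair
        # per piece of state that tracing should see.
        return (("queue", list(self.queue)), ("init_value", self.init_value))

class FakeToyQueue:
    """Stands in for the registered fake class."""
    def __init__(self, queue, init_value):
        self.queue = queue
        self.init_value = init_value

    @classmethod
    def __obj_unflatten__(cls, flattened):
        # Rebuild the fake object from the (name, value) pairs.
        return cls(**dict(flattened))

real = ToyQueue([1, 2, 3], init_value=-1)
fake = FakeToyQueue.__obj_unflatten__(real.__obj_flatten__())
print(fake.queue, fake.init_value)  # [1, 2, 3] -1
```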
Here is a breakdown of the changes. Following the guide in
:doc:`Extending TorchScript with Custom C++ Classes <torch_script_custom_classes>`,
we can create a thread-safe tensor queue and build it.
.. code-block:: cpp

    // Thread-safe Tensor Queue

    #include <torch/custom_class.h>
    #include <torch/script.h>

    #include <deque>
    #include <iostream>
    #include <mutex>
    #include <string>
    #include <vector>

    using namespace torch::jit;

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      explicit TensorQueue(at::Tensor t) : init_tensor_(t) {}

      explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) {
        init_tensor_ = dict.at(std::string("init_tensor"));
        const std::string key = "queue";
        at::Tensor size_tensor;
        size_tensor = dict.at(std::string(key + "/size")).cpu();
        const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>();
        int64_t queue_size = size_tensor_acc[0];

        for (const auto index : c10::irange(queue_size)) {
          at::Tensor val = dict.at(key + "/" + std::to_string(index));
          queue_.push_back(val);
        }
      }

      // Push an element to the rear of the queue.
      // A lock is used for thread safety.
      void push(at::Tensor x) {
        std::lock_guard<std::mutex> guard(mutex_);
        queue_.push_back(x);
      }
      // Pop the front element of the queue and return it.
      // If the queue is empty, return init_tensor_.
      // A lock is used for thread safety.
      at::Tensor pop() {
        std::lock_guard<std::mutex> guard(mutex_);
        if (!queue_.empty()) {
          auto val = queue_.front();
          queue_.pop_front();
          return val;
        } else {
          return init_tensor_;
        }
      }

      std::vector<at::Tensor> get_raw_queue() {
        std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
        return raw_queue;
      }

     private:
      std::deque<at::Tensor> queue_;
      std::mutex mutex_;
      at::Tensor init_tensor_;
    };

    // The torch binding code
    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          .def("push", &TensorQueue::push)
          .def("pop", &TensorQueue::pop)
          .def("get_raw_queue", &TensorQueue::get_raw_queue);
    }
**Step 1**: Add an ``__obj_flatten__`` method to the C++ custom class implementation:

.. code-block:: cpp

    // Thread-safe Tensor Queue
    struct TensorQueue : torch::CustomClassHolder {
      ...
      std::tuple<std::tuple<std::string, std::vector<at::Tensor>>, std::tuple<std::string, at::Tensor>> __obj_flatten__() {
        return std::tuple(std::tuple("queue", this->get_raw_queue()), std::tuple("init_tensor_", this->init_tensor_.clone()));
      }
      ...
    };

    TORCH_LIBRARY(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          .def(torch::init<at::Tensor>())
          ...
          .def("__obj_flatten__", &TensorQueue::__obj_flatten__);
    }
**Step 2a**: Register a fake class in Python that implements each method.

.. code-block:: python

    from typing import List

    import torch

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        def __init__(
            self,
            queue: List[torch.Tensor],
            init_tensor_: torch.Tensor
        ) -> None:
            self.queue = queue
            self.init_tensor_ = init_tensor_

        def push(self, tensor: torch.Tensor) -> None:
            self.queue.append(tensor)

        def pop(self) -> torch.Tensor:
            if len(self.queue) > 0:
                return self.queue.pop(0)
            return self.init_tensor_
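Ignoring tensors and the registration decorator, the fake methods implement ordinary FIFO semantics with a fallback value. A minimal plain-Python check (ints stand in for tensors; this is an illustration, not the registered fake class):

```python
# Minimal plain-Python check of the fake methods' semantics: FIFO order,
# with init_tensor_ returned once the queue is empty.
class FakeQueue:
    def __init__(self, queue, init_tensor_):
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, item):
        self.queue.append(item)

    def pop(self):
        if len(self.queue) > 0:
            return self.queue.pop(0)  # pop from the front (FIFO)
        return self.init_tensor_      # fallback when empty

q = FakeQueue([], init_tensor_=-1)
q.push(10)
q.push(20)
results = [q.pop(), q.pop(), q.pop()]
print(results)  # [10, 20, -1]
```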
**Step 2b**: Implement an ``__obj_unflatten__`` classmethod in Python.

.. code-block:: python

    # namespace::class_name
    @torch._library.register_fake_class("MyCustomClass::TensorQueue")
    class FakeTensorQueue:
        ...
        @classmethod
        def __obj_unflatten__(cls, flattened_tq):
            return cls(**dict(flattened_tq))
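The ``cls(**dict(flattened_tq))`` idiom works because the flattened state is a sequence of ``(attribute_name, value)`` pairs, which ``dict`` turns into keyword arguments (the values below are illustrative):

```python
# How the unflatten idiom consumes the flattened state.
flattened_tq = (("queue", [1, 2]), ("init_tensor_", -1))  # illustrative values

kwargs = dict(flattened_tq)
print(kwargs)  # {'queue': [1, 2], 'init_tensor_': -1}
```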
That's it! Now we can create a module that uses this object and run it with ``torch.compile`` or ``torch.export``.

.. code-block:: python

    import torch

    torch.classes.load_library("build/libcustom_class.so")
    tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))

    class Mod(torch.nn.Module):
        def forward(self, tq, x):
            tq.push(x.sin())
            tq.push(x.cos())
            popped_t = tq.pop()
            assert torch.allclose(popped_t, x.sin())
            return tq, popped_t

    tq, popped_t = torch.compile(Mod(), backend="eager", fullgraph=True)(tq, torch.randn(2, 3))
    # The binding above does not expose a ``size`` method, so check the
    # queue length through ``get_raw_queue`` instead.
    assert len(tq.get_raw_queue()) == 1

    exported_program = torch.export.export(Mod(), (tq, torch.randn(2, 3)), strict=False)
    exported_program.module()(tq, torch.randn(2, 3))
We can also implement custom ops that take custom classes as inputs. For
example, we could register a custom op ``for_each_add_(tq, tensor)``:

.. code-block:: cpp

    struct TensorQueue : torch::CustomClassHolder {
      ...
      void for_each_add_(at::Tensor inc) {
        for (auto& t : queue_) {
          t.add_(inc);
        }
      }
      ...
    };

    TORCH_LIBRARY_FRAGMENT(MyCustomClass, m) {
      m.class_<TensorQueue>("TensorQueue")
          ...
          .def("for_each_add_", &TensorQueue::for_each_add_);

      m.def(
          "for_each_add_(__torch__.torch.classes.MyCustomClass.TensorQueue foo, Tensor inc) -> ()");
    }

    void for_each_add_(c10::intrusive_ptr<TensorQueue> tq, at::Tensor inc) {
      tq->for_each_add_(inc);
    }

    TORCH_LIBRARY_IMPL(MyCustomClass, CPU, m) {
      m.impl("for_each_add_", for_each_add_);
    }
Since the fake class is implemented in Python, the fake implementation of the
custom op must also be registered in Python:

.. code-block:: python

    @torch.library.register_fake("MyCustomClass::for_each_add_")
    def fake_for_each_add_(tq, inc):
        tq.for_each_add_(inc)
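Semantically, ``for_each_add_`` adds ``inc`` in place to every element currently in the queue. With plain Python numbers standing in for tensors (an illustration only), the effect looks like:

```python
# Plain-Python sketch of for_each_add_'s in-place semantics:
# add `inc` to every element currently in the queue.
def for_each_add_(queue, inc):
    for i in range(len(queue)):
        queue[i] += inc

queue = [0, 0, 0]
for_each_add_(queue, 1)
print(queue)  # [1, 1, 1]
```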
After re-compilation, we can export the custom op with:

.. code-block:: python

    class ForEachAdd(torch.nn.Module):
        def forward(self, tq: torch.ScriptObject, a: torch.Tensor) -> torch.ScriptObject:
            torch.ops.MyCustomClass.for_each_add_(tq, a)
            return tq

    mod = ForEachAdd()
    tq = empty_tensor_queue()
    qlen = 10
    for i in range(qlen):
        tq.push(torch.zeros(1))

    ep = torch.export.export(mod, (tq, torch.ones(1)), strict=False)
Why do we need to make a Fake Class?
------------------------------------

Tracing with a real custom object has several major downsides:

1. Operators on real objects can be time-consuming, e.g. the custom object
   might be reading from the network or loading data from disk.

2. We don't want to mutate the real custom object or cause side effects in
   the environment while tracing.

3. It cannot support dynamic shapes.
However, it may be difficult for users to write a fake class, e.g. if the
original class uses some third-party library that determines the output shapes
of its methods, or is complicated and written by others. In such cases, users
can disable the fakification requirement by defining a ``tracing_mode`` method
that returns ``"real"``:

.. code-block:: cpp

    std::string tracing_mode() {
      return "real";
    }
A caveat of fakification concerns **tensor aliasing**. We assume that no
tensor within a torchbind object aliases a tensor outside of the torchbind
object. Therefore, mutating one of these tensors will result in undefined
behavior.
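The hazard can be illustrated with plain Python objects (an illustration only, not the real tracing machinery): when state held inside a container aliases data outside it, an "external" mutation silently changes the container's contents, which is exactly what fakification cannot account for.

```python
# Plain-Python illustration of the aliasing hazard: Holder stands in for a
# torchbind object, and the list stands in for a tensor it holds.
class Holder:
    def __init__(self, state):
        self.state = state  # keeps a reference, not a copy

external = [0.0, 0.0]
h = Holder(external)   # h.state aliases `external`
external[0] = 42.0     # mutation outside the object...
print(h.state[0])      # 42.0 -- ...is visible inside it
```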
