Skip to content

Commit febaf29

Browse files
author
zhangyuncong
committed
support deepcopy & copy & getstate & setstate & pickle
1 parent 1cfa795 commit febaf29

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

src/fastpb/template/module.jinjacc

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,23 @@ namespace {
203203
return PyString_FromStringAndSize(result.data(), result.length());
204204
}
205205

206+
PyObject *
207+
{{ message.name }}_Copy({{ message.name }}* self)
208+
{
209+
{{ message.name }}* cloned = NULL;
210+
Py_BEGIN_ALLOW_THREADS
211+
cloned = ({{ message.name }}*){{ message.name }}_new(&{{ message.name }}Type, NULL, NULL);
212+
cloned->protobuf->CopyFrom(*self->protobuf);
213+
Py_END_ALLOW_THREADS
214+
return (PyObject*)cloned;
215+
}
216+
217+
PyObject *
218+
{{ message.name }}_DeepCopy({{ message.name }}* self, PyObject* /*memo*/)
219+
{
220+
return {{ message.name }}_Copy(self);
221+
}
222+
206223

207224
PyObject *
208225
{{ message.name }}_ParseFromString({{ message.name }}* self, PyObject *value)
@@ -214,6 +231,18 @@ namespace {
214231
Py_RETURN_NONE;
215232
}
216233

234+
PyObject *
235+
{{ message.name }}_Reduce({{ message.name }}* self)
236+
{
237+
PyObject* ret = PyTuple_New(3);
238+
PyObject* type_object = (PyObject*)Py_TYPE(self);
239+
Py_INCREF(type_object);
240+
PyObject* state = {{ message.name }}_SerializeToString(self);
241+
PyTuple_SetItem(ret, 0, type_object);
242+
PyTuple_SetItem(ret, 1, PyTuple_New(0));
243+
PyTuple_SetItem(ret, 2, state);
244+
return ret;
245+
}
217246

218247
PyObject *
219248
{{ message.name }}_ParseFromLongString({{ message.name }}* self, PyObject *value)
@@ -658,10 +687,23 @@ namespace {
658687
{"ParseMany", (PyCFunction){{ message.name }}_ParseMany, METH_VARARGS | METH_CLASS,
659688
"Parses many protocol buffers of this type from a string."
660689
},
690+
{"__copy__", (PyCFunction){{ message.name }}_DeepCopy, METH_NOARGS,
691+
"copy a pb message."
692+
},
693+
{"__deepcopy__", (PyCFunction){{ message.name }}_DeepCopy, METH_O,
694+
"deep copy a pb message."
695+
},
696+
{"__getstate__", (PyCFunction){{ message.name }}_SerializeToString, METH_NOARGS,
697+
"support getstate"
698+
},
699+
{"__setstate__", (PyCFunction){{ message.name }}_ParseFromString, METH_O,
700+
"support setstate"
701+
},
702+
{"__reduce__", (PyCFunction){{ message.name }}_Reduce, METH_NOARGS,
703+
"support pickle"
704+
},
661705
{NULL} // Sentinel
662706
};
663-
664-
665707
PyTypeObject {{ message.name }}Type = {
666708
PyObject_HEAD_INIT(NULL)
667709
0, /*ob_size*/
@@ -688,9 +730,9 @@ namespace {
688730
0, /* tp_traverse */
689731
0, /* tp_clear */
690732
{{ message.name }}_richcompare, /* tp_richcompare */
691-
0, /* tp_weaklistoffset */
692-
0, /* tp_iter */
693-
0, /* tp_iternext */
733+
0, /* tp_weaklistoffset */
734+
0, /* tp_iter */
735+
0, /* tp_iternext */
694736
{{ message.name }}_methods, /* tp_methods */
695737
{{ message.name }}_members, /* tp_members */
696738
{{ message.name }}_getsetters, /* tp_getset */
@@ -711,7 +753,7 @@ static PyMethodDef module_methods[] = {
711753
{NULL} // Sentinel
712754
};
713755

714-
#ifndef PyMODINIT_FUNC // Declarations for DLL import/export.
756+
#ifndef PyMODINIT_FUNC // Declarations for DLL import/export.
715757
#define PyMODINIT_FUNC void
716758
#endif
717759
PyMODINIT_FUNC

src/fastpb/template/test.jinjapy

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""Auto-generated unit tests."""
33

44
import unittest
5+
import copy
6+
import cPickle
57

68
{% for file in files %}
79
import {{ file.package }}
@@ -54,6 +56,15 @@ class Test_{{ file.package.replace('.', '_') }}(unittest.TestCase):
5456
{% for field in message.field %}
5557
self.assertEquals(pb.{{ field.name }}, pb2.{{ field.name }})
5658
{% endfor %}
59+
60+
pb3 = copy.deepcopy(pb)
61+
pb4 = cPickle.loads(cPickle.dumps(pb3))
62+
63+
{% for field in message.field %}
64+
self.assertEquals(pb.{{ field.name }}, pb3.{{ field.name }})
65+
self.assertEquals(pb.{{ field.name }}, pb4.{{ field.name }})
66+
{% endfor %}
67+
5768
{% endfor %}
5869
{% endfor %}
5970

0 commit comments

Comments
 (0)