Skip to content

Commit

Permalink
Improved shallow and deep copy functions (#103)
Browse files Browse the repository at this point in the history
* Improved shallow and deep copy functions

* Handle cyclic references when deepcopying

The storage format allows a container to contain itself (although this shouldn't be used).
This fixes a stack overflow crash in deep_copy when a container contains itself.
deep_copy now throws an exception in that case instead.
  • Loading branch information
gentlegiantJGC authored Feb 6, 2025
1 parent d533d86 commit d316532
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 126 deletions.
135 changes: 63 additions & 72 deletions src/amulet_nbt/cpp/tag/copy.cpp
Original file line number Diff line number Diff line change
@@ -1,93 +1,84 @@
#include <vector>
#include <memory>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>
#include <type_traits>
#include <stdexcept>
#include <vector>

#include <amulet_nbt/export.hpp>
#include <amulet_nbt/tag/int.hpp>
#include <amulet_nbt/tag/array.hpp>
#include <amulet_nbt/tag/compound.hpp>
#include <amulet_nbt/tag/float.hpp>
#include <amulet_nbt/tag/string.hpp>
#include <amulet_nbt/tag/int.hpp>
#include <amulet_nbt/tag/list.hpp>
#include <amulet_nbt/tag/compound.hpp>
#include <amulet_nbt/tag/array.hpp>
#include <amulet_nbt/tag/string.hpp>

#include <amulet_nbt/tag/copy.hpp>


namespace AmuletNBT {
template <
typename T,
std::enable_if_t<
std::is_same_v<T, AmuletNBT::ListTagPtr> ||
std::is_same_v<T, AmuletNBT::CompoundTagPtr> ||
std::is_same_v<T, AmuletNBT::ByteArrayTagPtr> ||
std::is_same_v<T, AmuletNBT::IntArrayTagPtr> ||
std::is_same_v<T, AmuletNBT::LongArrayTagPtr>,
bool
> = true
>
AmuletNBT::ListTagPtr NBTTag_deep_copy_list_vector(const std::vector<T>&tag) {
AmuletNBT::ListTagPtr new_tag = std::make_shared<AmuletNBT::ListTag>(std::in_place_type<std::vector<T>>);
std::vector<T>& new_vector = std::get<std::vector<T>>(*new_tag);
for (T value : tag) {
if constexpr (std::is_same_v<T, AmuletNBT::ListTagPtr>) {
new_vector.push_back(NBTTag_deep_copy_list(*value));
}
else if constexpr (std::is_same_v<T, AmuletNBT::CompoundTagPtr>) {
new_vector.push_back(NBTTag_deep_copy_compound(*value));
}
else {
new_vector.push_back(NBTTag_copy<typename T::element_type>(*value));
}
}
return new_tag;

// Deep copy the elements of one list vector.
//
// Each element of `vec` is passed to deep_copy_2 together with `memo`, and the
// results are collected, in order, into a new vector of the same element type.
// That vector is returned converted to a ListTag.
//
// `memo` is only forwarded to deep_copy_2; this function does not read or
// modify it itself.
template <typename T>
AmuletNBT::ListTag deep_copy_list_vector(const std::vector<T>& vec, std::set<size_t>& memo)
{
    std::vector<T> copied;
    // Allocate once up front; the output has exactly as many elements as the input.
    copied.reserve(vec.size());
    for (auto it = vec.begin(); it != vec.end(); ++it) {
        copied.push_back(deep_copy_2(*it, memo));
    }
    return copied;
}

AmuletNBT::ListTagPtr NBTTag_deep_copy_list(const AmuletNBT::ListTag& tag) {
return std::visit([](auto&& list) {
AmuletNBT::ListTag deep_copy_2(const AmuletNBT::ListTag& tag, std::set<size_t>& memo)
{
auto ptr = reinterpret_cast<size_t>(&tag);
if (memo.contains(ptr)) {
throw std::runtime_error("ListTag cannot contain itself.");
}
memo.insert(ptr);
auto new_tag = std::visit(
[&memo](auto&& list) -> AmuletNBT::ListTag {
using T = std::decay_t<decltype(list)>;
if constexpr (std::is_same_v<T, std::monostate>) {
return std::make_shared<AmuletNBT::ListTag>();
}
else if constexpr (is_shared_ptr<typename T::value_type>::value) {
return NBTTag_deep_copy_list_vector(list);
}
else {
return std::make_shared<AmuletNBT::ListTag>(list);
return AmuletNBT::ListTag();
} else if constexpr (is_shared_ptr<typename T::value_type>::value) {
return deep_copy_list_vector(list, memo);
} else {
return list;
}
}, tag);
}
},
tag);
memo.erase(ptr);
return new_tag;
}

AmuletNBT::TagNode NBTTag_deep_copy_node(const AmuletNBT::TagNode& node) {
return std::visit([](auto&& tag) -> AmuletNBT::TagNode {
using T = std::decay_t<decltype(tag)>;
if constexpr (std::is_same_v<T, AmuletNBT::ListTagPtr>) {
return NBTTag_deep_copy_list(*tag);
}
else if constexpr (std::is_same_v<T, AmuletNBT::CompoundTagPtr>) {
return NBTTag_deep_copy_compound(*tag);
}
else if constexpr (
std::is_same_v<T, AmuletNBT::ByteArrayTagPtr> ||
std::is_same_v<T, AmuletNBT::IntArrayTagPtr> ||
std::is_same_v<T, AmuletNBT::LongArrayTagPtr>
) {
return NBTTag_copy(*tag);
}
else {
return tag;
}
}, node);
AmuletNBT::CompoundTag deep_copy_2(const AmuletNBT::CompoundTag& tag, std::set<size_t>& memo)
{
auto ptr = reinterpret_cast<size_t>(&tag);
if (memo.contains(ptr)) {
throw std::runtime_error("CompoundTag cannot contain itself.");
}

AmuletNBT::CompoundTagPtr NBTTag_deep_copy_compound(const AmuletNBT::CompoundTag& tag) {
auto new_tag = std::make_shared<AmuletNBT::CompoundTag>();
for (auto& [key, value] : tag) {
(*new_tag)[key] = NBTTag_deep_copy_node(value);
}
return new_tag;
memo.insert(ptr);
AmuletNBT::CompoundTag new_tag;
for (auto& [key, value] : tag) {
new_tag.emplace(key, deep_copy_2(value, memo));
}
memo.erase(ptr);
return new_tag;
}

// Deep copy a TagNode.
//
// Visits the node and forwards whichever alternative it currently holds to the
// matching deep_copy_2 overload, wrapping the result back into a TagNode.
// `memo` is passed through unchanged to that overload.
AmuletNBT::TagNode deep_copy_2(const AmuletNBT::TagNode& node, std::set<size_t>& memo)
{
    const auto copy_alternative = [&memo](auto&& held) -> AmuletNBT::TagNode {
        return deep_copy_2(held, memo);
    };
    return std::visit(copy_alternative, node);
}

// Deep copy a NamedTag.
//
// The result is built from the original name and a deep copy of the original
// tag node (made via the TagNode overload of deep_copy_2, sharing `memo`).
AmuletNBT::NamedTag deep_copy_2(const AmuletNBT::NamedTag& named_tag, std::set<size_t>& memo)
{
    AmuletNBT::NamedTag result { named_tag.name, deep_copy_2(named_tag.tag_node, memo) };
    return result;
}

} // namespace AmuletNBT
119 changes: 79 additions & 40 deletions src/amulet_nbt/include/amulet_nbt/tag/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,89 @@
#include <memory>
#include <type_traits>
#include <variant>
#include <set>

#include <amulet_nbt/export.hpp>
#include <amulet_nbt/common.hpp>
#include <amulet_nbt/tag/int.hpp>
#include <amulet_nbt/export.hpp>
#include <amulet_nbt/tag/array.hpp>
#include <amulet_nbt/tag/compound.hpp>
#include <amulet_nbt/tag/float.hpp>
#include <amulet_nbt/tag/string.hpp>
#include <amulet_nbt/tag/int.hpp>
#include <amulet_nbt/tag/list.hpp>
#include <amulet_nbt/tag/compound.hpp>
#include <amulet_nbt/tag/array.hpp>
#include <amulet_nbt/tag/named_tag.hpp>
#include <amulet_nbt/tag/string.hpp>

namespace AmuletNBT {
// Copy a tag by value.
// Only participates in overload resolution when T is one of ByteTag, ShortTag,
// IntTag, LongTag, FloatTag, DoubleTag or StringTag (enforced by the
// enable_if below). Returns a copy-constructed T.
template <
typename T,
std::enable_if_t<
std::is_same_v<T, AmuletNBT::ByteTag> ||
std::is_same_v<T, AmuletNBT::ShortTag> ||
std::is_same_v<T, AmuletNBT::IntTag> ||
std::is_same_v<T, AmuletNBT::LongTag> ||
std::is_same_v<T, AmuletNBT::FloatTag> ||
std::is_same_v<T, AmuletNBT::DoubleTag> ||
std::is_same_v<T, AmuletNBT::StringTag>,
bool
> = true
>
inline T NBTTag_copy(const T & tag) {
return tag;
}

// Copy a tag into a newly allocated shared_ptr.
// Only participates in overload resolution when T is one of ListTag,
// CompoundTag, ByteArrayTag, IntArrayTag or LongArrayTag (enforced by the
// enable_if below). The new object is copy-constructed from `tag`; nothing
// beyond T's own copy constructor is applied to its contents.
template <
typename T,
std::enable_if_t<
std::is_same_v<T, AmuletNBT::ListTag> ||
std::is_same_v<T, AmuletNBT::CompoundTag> ||
std::is_same_v<T, AmuletNBT::ByteArrayTag> ||
std::is_same_v<T, AmuletNBT::IntArrayTag> ||
std::is_same_v<T, AmuletNBT::LongArrayTag>,
bool
> = true
>
inline std::shared_ptr<T> NBTTag_copy(const T& tag){
return std::make_shared<T>(tag);
}

// Exported deep copy entry points. These are declarations only; the bodies are
// not defined in this header.
// The list and compound variants return a pointer type (ListTagPtr /
// CompoundTagPtr); the node variant returns a TagNode by value.
AMULET_NBT_EXPORT AmuletNBT::ListTagPtr NBTTag_deep_copy_list(const AmuletNBT::ListTag& tag);
AMULET_NBT_EXPORT AmuletNBT::TagNode NBTTag_deep_copy_node(const AmuletNBT::TagNode& tag);
AMULET_NBT_EXPORT AmuletNBT::CompoundTagPtr NBTTag_deep_copy_compound(const AmuletNBT::CompoundTag& tag);

// Shallow copy a tag held by value.
// Constrained to the tag types listed in the requires-clause below, plus
// TagNode and NamedTag. The result is simply a copy-constructed T: no
// recursion into the tag's contents is performed here, so whatever T's copy
// constructor shares with the original stays shared.
template <typename T>
requires std::is_same_v<T, AmuletNBT::ByteTag>
|| std::is_same_v<T, AmuletNBT::ShortTag>
|| std::is_same_v<T, AmuletNBT::IntTag>
|| std::is_same_v<T, AmuletNBT::LongTag>
|| std::is_same_v<T, AmuletNBT::FloatTag>
|| std::is_same_v<T, AmuletNBT::DoubleTag>
|| std::is_same_v<T, AmuletNBT::StringTag>
|| std::is_same_v<T, AmuletNBT::ListTag>
|| std::is_same_v<T, AmuletNBT::CompoundTag>
|| std::is_same_v<T, AmuletNBT::ByteArrayTag>
|| std::is_same_v<T, AmuletNBT::IntArrayTag>
|| std::is_same_v<T, AmuletNBT::LongArrayTag>
|| std::is_same_v<T, AmuletNBT::TagNode>
|| std::is_same_v<T, AmuletNBT::NamedTag>
T shallow_copy(const T& tag)
{
return tag;
}

// Shallow copy the object owned by a unique_ptr into a new unique_ptr.
//
// The pointee is copied with shallow_copy(*tag) and the result is placed in a
// freshly allocated T, so the returned pointer never aliases `tag`.
//
// An empty `tag` yields an empty pointer. Without this guard the function
// would dereference a null pointer, which is undefined behaviour.
template <typename T>
std::unique_ptr<T> shallow_copy(const std::unique_ptr<T>& tag)
{
    if (!tag) {
        return nullptr;
    }
    return std::make_unique<T>(shallow_copy(*tag));
}

// Shallow copy the object owned by a shared_ptr into a new shared_ptr.
//
// The pointee is copied with shallow_copy(*tag) and the result is placed in a
// freshly allocated T; the returned pointer does not share ownership with
// `tag`.
//
// An empty `tag` yields an empty pointer. Without this guard the function
// would dereference a null pointer, which is undefined behaviour.
template <typename T>
std::shared_ptr<T> shallow_copy(const std::shared_ptr<T>& tag)
{
    if (!tag) {
        return nullptr;
    }
    return std::make_shared<T>(shallow_copy(*tag));
}

// Deep copy for the tag types listed in the requires-clause below (numeric,
// string and array tags).
// For these types the deep copy is just a copy-constructed T. `memo` is
// accepted so the signature matches the other deep_copy_2 overloads, but it is
// not used here.
template <typename T>
requires std::is_same_v<T, AmuletNBT::ByteTag>
|| std::is_same_v<T, AmuletNBT::ShortTag>
|| std::is_same_v<T, AmuletNBT::IntTag>
|| std::is_same_v<T, AmuletNBT::LongTag>
|| std::is_same_v<T, AmuletNBT::FloatTag>
|| std::is_same_v<T, AmuletNBT::DoubleTag>
|| std::is_same_v<T, AmuletNBT::StringTag>
|| std::is_same_v<T, AmuletNBT::ByteArrayTag>
|| std::is_same_v<T, AmuletNBT::IntArrayTag>
|| std::is_same_v<T, AmuletNBT::LongArrayTag>
T deep_copy_2(const T& tag, std::set<size_t>& memo)
{
return tag;
}

// Exported deep_copy_2 overloads for the container-like types. These are
// declarations only; the bodies are not defined in this header.
// Each takes the object to copy and the `memo` set that is threaded through
// the whole recursive copy, and returns the copy by value.
AMULET_NBT_EXPORT AmuletNBT::ListTag deep_copy_2(const AmuletNBT::ListTag&, std::set<size_t>& memo);
AMULET_NBT_EXPORT AmuletNBT::CompoundTag deep_copy_2(const AmuletNBT::CompoundTag&, std::set<size_t>& memo);
AMULET_NBT_EXPORT AmuletNBT::TagNode deep_copy_2(const AmuletNBT::TagNode&, std::set<size_t>& memo);
AMULET_NBT_EXPORT AmuletNBT::NamedTag deep_copy_2(const AmuletNBT::NamedTag&, std::set<size_t>& memo);

// Deep copy the object owned by a unique_ptr into a new unique_ptr.
//
// The pointee is copied with deep_copy_2(*tag, memo) and the result is placed
// in a freshly allocated T. `memo` is forwarded unchanged to that call.
//
// An empty `tag` yields an empty pointer. Without this guard the function
// would dereference a null pointer, which is undefined behaviour.
template <typename T>
std::unique_ptr<T> deep_copy_2(const std::unique_ptr<T>& tag, std::set<size_t>& memo)
{
    if (!tag) {
        return nullptr;
    }
    return std::make_unique<T>(deep_copy_2(*tag, memo));
}

// Deep copy the object owned by a shared_ptr into a new shared_ptr.
//
// The pointee is copied with deep_copy_2(*tag, memo) and the result is placed
// in a freshly allocated T; the returned pointer does not share ownership with
// `tag`. `memo` is forwarded unchanged to that call.
//
// An empty `tag` yields an empty pointer. Without this guard the function
// would dereference a null pointer, which is undefined behaviour.
template <typename T>
std::shared_ptr<T> deep_copy_2(const std::shared_ptr<T>& tag, std::set<size_t>& memo)
{
    if (!tag) {
        return nullptr;
    }
    return std::make_shared<T>(deep_copy_2(*tag, memo));
}

// Public entry point for deep copying.
//
// Creates a fresh, empty memo set for this copy operation and hands it, along
// with `obj`, to the matching deep_copy_2 overload. The return type is
// whatever that overload returns.
template <typename T>
auto deep_copy(const T& obj)
{
    std::set<size_t> visited;
    return deep_copy_2(obj, visited);
}

} // namespace AmuletNBT
4 changes: 2 additions & 2 deletions src/amulet_nbt/pybind/tag/py_array_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ namespace py = pybind11;
CLSNAME.def(\
"__copy__",\
[](const AmuletNBT::CLSNAME& self){\
return NBTTag_copy<AmuletNBT::CLSNAME>(self);\
return shallow_copy(self);\
}\
);\
CLSNAME.def(\
"__deepcopy__",\
[](const AmuletNBT::CLSNAME& self, py::dict){\
return AmuletNBT::NBTTag_copy<AmuletNBT::CLSNAME>(self);\
return deep_copy(self);\
},\
py::arg("memo")\
);\
Expand Down
4 changes: 2 additions & 2 deletions src/amulet_nbt/pybind/tag/py_compound_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,13 @@ void init_compound(py::module& m) {
CompoundTag.def(
"__copy__",
[](const AmuletNBT::CompoundTag& self){
return NBTTag_copy<AmuletNBT::CompoundTag>(self);
return shallow_copy(self);
}
);
CompoundTag.def(
"__deepcopy__",
[](const AmuletNBT::CompoundTag& self, py::dict){
return AmuletNBT::NBTTag_deep_copy_compound(self);
return deep_copy(self);
},
py::arg("memo")
);
Expand Down
5 changes: 3 additions & 2 deletions src/amulet_nbt/pybind/tag/py_float_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <amulet_nbt/tag/abc.hpp>
#include <amulet_nbt/tag/float.hpp>
#include <amulet_nbt/tag/copy.hpp>
#include <amulet_nbt/nbt_encoding/binary.hpp>
#include <amulet_nbt/nbt_encoding/string.hpp>
#include <amulet_nbt/pybind/serialisation.hpp>
Expand Down Expand Up @@ -80,13 +81,13 @@ namespace py = pybind11;
CLSNAME.def(\
"__copy__",\
[](const AmuletNBT::CLSNAME& self){\
return self;\
return shallow_copy(self);\
}\
);\
CLSNAME.def(\
"__deepcopy__",\
[](const AmuletNBT::CLSNAME& self, py::dict){\
return self;\
return deep_copy(self);\
},\
py::arg("memo")\
);\
Expand Down
5 changes: 3 additions & 2 deletions src/amulet_nbt/pybind/tag/py_int_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <amulet_nbt/tag/abc.hpp>
#include <amulet_nbt/tag/int.hpp>
#include <amulet_nbt/tag/copy.hpp>
#include <amulet_nbt/nbt_encoding/binary.hpp>
#include <amulet_nbt/nbt_encoding/string.hpp>
#include <amulet_nbt/pybind/serialisation.hpp>
Expand Down Expand Up @@ -84,13 +85,13 @@ namespace py = pybind11;
CLSNAME.def(\
"__copy__",\
[](const AmuletNBT::CLSNAME& self){\
return self;\
return shallow_copy(self);\
}\
);\
CLSNAME.def(\
"__deepcopy__",\
[](const AmuletNBT::CLSNAME& self, py::dict){\
return self;\
return deep_copy(self);\
},\
py::arg("memo")\
);\
Expand Down
4 changes: 2 additions & 2 deletions src/amulet_nbt/pybind/tag/py_list_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,13 @@ void init_list(py::module& m) {
ListTag.def(
"__copy__",
[](const AmuletNBT::ListTag& self){
return NBTTag_copy<AmuletNBT::ListTag>(self);
return shallow_copy(self);
}
);
ListTag.def(
"__deepcopy__",
[](const AmuletNBT::ListTag& self, py::dict){
return AmuletNBT::NBTTag_deep_copy_list(self);
return deep_copy(self);
},
py::arg("memo")
);
Expand Down
4 changes: 2 additions & 2 deletions src/amulet_nbt/pybind/tag/py_named_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,13 @@ void init_named_tag(py::module& m) {
NamedTag.def(
"__copy__",
[](const AmuletNBT::NamedTag& self){
return self;
return shallow_copy(self);
}
);
NamedTag.def(
"__deepcopy__",
[](const AmuletNBT::NamedTag& self, py::dict){
return AmuletNBT::NamedTag(self.name, AmuletNBT::NBTTag_deep_copy_node(self.tag_node));
return deep_copy(self);
},
py::arg("memo")
);
Expand Down
Loading

0 comments on commit d316532

Please sign in to comment.