Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/protobuf/compiler/cpp/file_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ TEST(FileTest, TopologicallyOrderedDescriptors) {
"TestCommentInjectionMessage",
"TestChildExtensionData.NestedTestAllExtensionsData."
"NestedDynamicExtensions",
"TestAllTypesAsExtension",
"TestAllTypes.RepeatedGroup",
"TestAllTypes.OptionalGroup",
"TestAllTypes.NestedMessage",
Expand Down
32 changes: 32 additions & 0 deletions src/google/protobuf/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,23 @@ class PROTOBUF_EXPORT ExtensionSet {
}
#endif

// Moves an extension from one ExtensionSet to another.
//
// If the source extension does not exist, then destination extension is
// cleared.
//
// If the destination extension already exists, it is overwritten otherwise
// it is created and then moved.
bool MoveExtension(Arena* arena, int dst_number, ExtensionSet& src,
int src_number);

bool IsLazy(int number) const {
const Extension* extension = FindOrNull(number);
return extension != nullptr && extension->is_lazy;
}

LazyField* TryGetLazyField(Arena* arena, int number, FieldType type);

private:
template <typename Type>
friend class PrimitiveTypeTraits;
Expand Down Expand Up @@ -727,6 +744,8 @@ class PROTOBUF_EXPORT ExtensionSet {
io::EpsCopyOutputStream* stream) const = 0;


virtual LazyField* GetUnderlyingField() = 0;

private:
virtual void UnusedKeyMethod(); // Dummy key method to avoid weak vtable.
};
Expand Down Expand Up @@ -1853,6 +1872,19 @@ class ExtensionIdentifier {
typename TypeTraits::InitType default_value_;
};

template <typename ExtendeeType, typename TypeTraitsType,
internal::FieldType field_type, bool is_packed>
auto TryGetLazyMessageFromExtensionSet(
Arena* arena,
const google::protobuf::internal::ExtensionIdentifier<
ExtendeeType, TypeTraitsType, field_type, is_packed>& extension,
ExtensionSet& set) {
static_assert(std::is_base_of_v<
MessageLite,
std::decay_t<typename TypeTraitsType::Singular::ConstType>>);
return set.TryGetLazyField(arena, extension.number(), field_type);
}

// -------------------------------------------------------------------
// Generated accessors

Expand Down
33 changes: 33 additions & 0 deletions src/google/protobuf/extension_set_heavy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <string>
#include <utility>
#include <variant>
#include <vector>

Expand Down Expand Up @@ -311,6 +313,37 @@ bool ExtensionSet::FindExtension(int wire_type, uint32_t field,
}


bool ExtensionSet::MoveExtension(Arena* arena, int dst_number,
ExtensionSet& src, int src_number) {
// Find the source extension & return if it doesn't exist.
Extension* src_ext = src.FindOrNull(src_number);
if (src_ext == nullptr) {
ClearExtension(dst_number);
return true;
}

if (src_ext->descriptor != nullptr) {
return false;
}

// Get or create the destination extension.
auto [dst_ext, is_new] = Insert(arena, dst_number);
if (!is_new) {
// If an extension already exists at dst_number, free it if not using an
// arena.
if (arena == nullptr) {
dst_ext->Free();
}
}

// Move the extension from the source to the destination.
*dst_ext = std::move(*src_ext);

// Erase the extension from the source.
src.Erase(src_number);
return true;
}

const char* ExtensionSet::ParseField(uint64_t tag, const char* ptr,
const Message* extendee,
internal::InternalMetadata* metadata,
Expand Down
262 changes: 262 additions & 0 deletions src/google/protobuf/extension_set_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "google/protobuf/cpp_features.pb.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/generated_message_util.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message_lite.h"
Expand Down Expand Up @@ -1545,6 +1546,267 @@ TEST(ExtensionSetTest, Descriptor) {
EXPECT_NE(GetExtensionReflection(pb::cpp), nullptr);
}

TEST(ExtensionSetTest, MoveExtension) {
unittest::TestAllExtensions src;
src.SetExtension(unittest::optional_int32_extension, 101);
src.SetExtension(unittest::optional_string_extension, "123");
*src.MutableExtension(unittest::optional_foreign_message_extension) =
unittest::ForeignMessage();
src.AddExtension(unittest::repeated_int32_extension, 201);
src.AddExtension(unittest::repeated_int32_extension, 202);

unittest::TestAllExtensions dst;
ExtensionSet& set1 = PrivateAccess::GetExtensionSet(src);
ExtensionSet& set2 = PrivateAccess::GetExtensionSet(dst);

// Move fields from set1 to set2
EXPECT_TRUE(
set2.MoveExtension(nullptr, unittest::optional_int32_extension.number(),
set1, unittest::optional_int32_extension.number()));
EXPECT_TRUE(
set2.MoveExtension(nullptr, unittest::optional_string_extension.number(),
set1, unittest::optional_string_extension.number()));
EXPECT_TRUE(set2.MoveExtension(
nullptr, unittest::optional_foreign_message_extension.number(), set1,
unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(
set2.MoveExtension(nullptr, unittest::repeated_int32_extension.number(),
set1, unittest::repeated_int32_extension.number()));

// Verify int32 extension
EXPECT_FALSE(set1.Has(unittest::optional_int32_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_int32_extension.number()));
EXPECT_TRUE(dst.HasExtension(unittest::optional_int32_extension));
EXPECT_EQ(dst.GetExtension(unittest::optional_int32_extension), 101);

// Verify string extension
EXPECT_FALSE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(dst.HasExtension(unittest::optional_string_extension));
EXPECT_EQ(dst.GetExtension(unittest::optional_string_extension), "123");

// Verify foreign message extension
EXPECT_FALSE(set1.Has(unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(dst.HasExtension(unittest::optional_foreign_message_extension));
EXPECT_EQ(dst.GetExtension(unittest::optional_foreign_message_extension).c(),
0); // Default value

// Verify repeated int32 extension
EXPECT_EQ(set1.ExtensionSize(unittest::repeated_int32_extension.number()), 0);
EXPECT_EQ(set2.ExtensionSize(unittest::repeated_int32_extension.number()), 2);
EXPECT_THAT(dst.GetRepeatedExtension(unittest::repeated_int32_extension),
testing::ElementsAre(201, 202));

// Test moving a non-existent field.
int non_existent_number = 99999;
EXPECT_TRUE(set2.MoveExtension(nullptr, non_existent_number, set1,
non_existent_number));
EXPECT_FALSE(set1.Has(non_existent_number));
EXPECT_FALSE(set2.Has(non_existent_number));

// Test moving to an existing field.
// Set a different value in set1 for the string extension.
src.SetExtension(unittest::optional_string_extension, "999");
EXPECT_TRUE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_EQ(src.GetExtension(unittest::optional_string_extension), "999");

// Move the original value from set2 back to set1.
EXPECT_TRUE(
set1.MoveExtension(nullptr, unittest::optional_string_extension.number(),
set2, unittest::optional_string_extension.number()));
EXPECT_FALSE(set2.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_EQ(src.GetExtension(unittest::optional_string_extension), "123");
}

TEST(ExtensionSetTest, MoveExtensionWithArena) {
Arena arena;
auto* src = Arena::Create<unittest::TestAllExtensions>(&arena);
src->SetExtension(unittest::optional_int32_extension, 101);
src->SetExtension(unittest::optional_string_extension, "123");
*src->MutableExtension(unittest::optional_foreign_message_extension) =
unittest::ForeignMessage();
src->AddExtension(unittest::repeated_int32_extension, 201);
src->AddExtension(unittest::repeated_int32_extension, 202);

auto* dst = Arena::Create<unittest::TestAllExtensions>(&arena);
ExtensionSet& set1 = PrivateAccess::GetExtensionSet(*src);
ExtensionSet& set2 = PrivateAccess::GetExtensionSet(*dst);
EXPECT_TRUE(
set2.MoveExtension(&arena, unittest::optional_int32_extension.number(),
set1, unittest::optional_int32_extension.number()));
EXPECT_TRUE(
set2.MoveExtension(&arena, unittest::optional_string_extension.number(),
set1, unittest::optional_string_extension.number()));
EXPECT_TRUE(set2.MoveExtension(
&arena, unittest::optional_foreign_message_extension.number(), set1,
unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(
set2.MoveExtension(&arena, unittest::repeated_int32_extension.number(),
set1, unittest::repeated_int32_extension.number()));

// Verify int32 extension
EXPECT_FALSE(set1.Has(unittest::optional_int32_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_int32_extension.number()));
EXPECT_TRUE(dst->HasExtension(unittest::optional_int32_extension));
EXPECT_EQ(dst->GetExtension(unittest::optional_int32_extension), 101);

// Verify string extension
EXPECT_FALSE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(dst->HasExtension(unittest::optional_string_extension));
EXPECT_EQ(dst->GetExtension(unittest::optional_string_extension), "123");

// Verify foreign message extension
EXPECT_FALSE(set1.Has(unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(set2.Has(unittest::optional_foreign_message_extension.number()));
EXPECT_TRUE(dst->HasExtension(unittest::optional_foreign_message_extension));
EXPECT_EQ(dst->GetExtension(unittest::optional_foreign_message_extension).c(),
0); // Default value

// Verify repeated int32 extension
EXPECT_EQ(set1.ExtensionSize(unittest::repeated_int32_extension.number()), 0);
EXPECT_EQ(set2.ExtensionSize(unittest::repeated_int32_extension.number()), 2);
EXPECT_THAT(dst->GetRepeatedExtension(unittest::repeated_int32_extension),
testing::ElementsAre(201, 202));

// Test moving a non-existent field.
int non_existent_number = 99999;
EXPECT_TRUE(set2.MoveExtension(&arena, non_existent_number, set1,
non_existent_number));
EXPECT_FALSE(set1.Has(non_existent_number));
EXPECT_FALSE(set2.Has(non_existent_number));

// Test moving to an existing field.
// Set a different value in set1 for the string extension.
src->SetExtension(unittest::optional_string_extension, "999");
EXPECT_TRUE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_EQ(src->GetExtension(unittest::optional_string_extension), "999");

// Move the original value from set2 back to set1.
EXPECT_TRUE(
set1.MoveExtension(&arena, unittest::optional_string_extension.number(),
set2, unittest::optional_string_extension.number()));
EXPECT_FALSE(set2.Has(unittest::optional_string_extension.number()));
EXPECT_TRUE(set1.Has(unittest::optional_string_extension.number()));
EXPECT_EQ(src->GetExtension(unittest::optional_string_extension), "123");
}

auto MakeExtensionWithLazyRep(int value) {
unittest::TestAllExtensions tmp;
tmp.SetExtension(unittest::optional_int32_extension, value);

unittest::TestAllExtensions out;
out.ParseFromString(tmp.SerializeAsString());
return out;
}

TEST(ExtensionSetTest, MoveLazyMessageExtension) {
proto2_unittest::TestAllExtensions src = MakeExtensionWithLazyRep(1234);
proto2_unittest::TestAllExtensions dst = MakeExtensionWithLazyRep(5678);

ExtensionSet& set1 = PrivateAccess::GetExtensionSet(src);
ExtensionSet& set2 = PrivateAccess::GetExtensionSet(dst);

EXPECT_TRUE(
set2.MoveExtension(nullptr, unittest::optional_int32_extension.number(),
set1, unittest::optional_int32_extension.number()));

// Source no longer has the field at all.
EXPECT_FALSE(src.HasExtension(unittest::optional_int32_extension));

// Dest must have the field as set.
EXPECT_TRUE(dst.HasExtension(unittest::optional_int32_extension));

// Dest is lazy.
set2.IsLazy(unittest::optional_int32_extension.number());

// Finally, force the parse just to verify.
EXPECT_EQ(dst.GetExtension(unittest::optional_int32_extension), 1234);
}

TEST(ExtensionSetTest, MoveExtensionWithGeneratedDescriptor) {
unittest::TestAllExtensions src;
const FieldDescriptor* fd =
GetExtensionReflection(unittest::optional_int32_extension);
src.GetReflection()->SetInt32(&src, fd, 101);

unittest::TestAllExtensions dst;
ExtensionSet& set1 = PrivateAccess::GetExtensionSet(src);
ExtensionSet& set2 = PrivateAccess::GetExtensionSet(dst);

EXPECT_FALSE(
set2.MoveExtension(nullptr, unittest::optional_int32_extension.number(),
set1, unittest::optional_int32_extension.number()));

EXPECT_TRUE(src.HasExtension(unittest::optional_int32_extension));
EXPECT_EQ(src.GetExtension(unittest::optional_int32_extension), 101);
EXPECT_FALSE(dst.HasExtension(unittest::optional_int32_extension));
}

TEST(ExtensionSetTest, MoveExtensionWithDynamicDescriptor) {
// Define a dynamic extension.
FileDescriptorProto file_descriptor_proto;
file_descriptor_proto.set_name("my_dynamic_extensions.proto");
file_descriptor_proto.set_package("my_dynamic_package");
file_descriptor_proto.add_dependency(
unittest::TestAllExtensions::descriptor()->file()->name());

FieldDescriptorProto* extension = file_descriptor_proto.add_extension();
extension->set_name("my_dynamic_int_extension");
extension->set_extendee(
unittest::TestAllExtensions::descriptor()->full_name());
extension->set_number(5000); // Use a unique extension number.
extension->set_label(FieldDescriptorProto::LABEL_OPTIONAL);
extension->set_type(FieldDescriptorProto::TYPE_INT32);

FieldDescriptorProto* extension2 = file_descriptor_proto.add_extension();
extension2->set_name("my_other_dynamic_int_extension");
extension2->set_extendee(
unittest::TestAllExtensions::descriptor()->full_name());
extension2->set_number(5001); // Use a unique extension number.
extension2->set_label(FieldDescriptorProto::LABEL_OPTIONAL);
extension2->set_type(FieldDescriptorProto::TYPE_INT32);

DescriptorPool dynamic_pool(DescriptorPool::generated_pool());
const FileDescriptor* file = dynamic_pool.BuildFile(file_descriptor_proto);
ASSERT_TRUE(file != nullptr);
DynamicMessageFactory dynamic_factory(&dynamic_pool);
dynamic_factory.SetDelegateToGeneratedFactory(true);

const FieldDescriptor* dynamic_ext_fd =
file->FindExtensionByName("my_dynamic_int_extension");
ASSERT_TRUE(dynamic_ext_fd != nullptr);
const FieldDescriptor* dynamic_ext_fd2 =
file->FindExtensionByName("my_other_dynamic_int_extension");
ASSERT_TRUE(dynamic_ext_fd2 != nullptr);

// Set dynamic extension on src message via reflection.
auto* prototype =
dynamic_factory.GetPrototype(unittest::TestAllExtensions::descriptor());
std::unique_ptr<Message> src_msg(prototype->New());
src_msg->GetReflection()->SetInt32(src_msg.get(), dynamic_ext_fd, 12345);
std::unique_ptr<Message> dst_msg(prototype->New());
auto& src = *DynamicCastMessage<unittest::TestAllExtensions>(src_msg.get());
auto& dst = *DynamicCastMessage<unittest::TestAllExtensions>(dst_msg.get());
ExtensionSet& set1 = PrivateAccess::GetExtensionSet(src);
ExtensionSet& set2 = PrivateAccess::GetExtensionSet(dst);

// Move dynamic extension from src to dst.
EXPECT_FALSE(set2.MoveExtension(nullptr, dynamic_ext_fd2->number(), set1,
dynamic_ext_fd->number()));

EXPECT_TRUE(src.GetReflection()->HasField(src, dynamic_ext_fd));
EXPECT_FALSE(dst.GetReflection()->HasField(dst, dynamic_ext_fd2));

EXPECT_EQ(src.GetReflection()->GetInt32(src, dynamic_ext_fd), 12345);
std::vector<const FieldDescriptor*> fields;
src.GetReflection()->ListFields(src, &fields);
ASSERT_EQ(fields.size(), 1);
EXPECT_EQ(fields[0], dynamic_ext_fd);
}


TEST_P(FindExtensionTest,
FindExtensionInfoFromFieldNumber_FindExistingExtension) {
Expand Down
Loading
Loading