Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/protobuf/unittest.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ message TestAllTypes {
// a local variable named "b" in one of the generated methods. Doh.
// This file needs to compile in proto1 to test backwards-compatibility.
int32 bb = 1;
int32 cc = 2;
}

enum NestedEnum {
Expand Down
16 changes: 13 additions & 3 deletions src/google/protobuf/util/field_mask_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
#include "absl/log/absl_log.h"
#include "absl/log/die_if_null.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"

// Must be included last.
Expand Down Expand Up @@ -584,9 +586,17 @@ bool FieldMaskTree::TrimMessage(const Node* node, Message* message) {
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
Node* child = it->second.get();
if (!child->children.empty() && reflection->HasField(*message, field)) {
bool nestedMessageChanged =
TrimMessage(child, reflection->MutableMessage(message, field));
modified = nestedMessageChanged || modified;
if (field->is_repeated()) {
for (int i = 0; i < reflection->FieldSize(*message, field); ++i) {
bool nestedMessageChanged = TrimMessage(
child, reflection->MutableRepeatedMessage(message, field, i));
modified = nestedMessageChanged || modified;
}
} else {
bool nestedMessageChanged =
TrimMessage(child, reflection->MutableMessage(message, field));
modified = nestedMessageChanged || modified;
}
}
}
}
Expand Down
89 changes: 65 additions & 24 deletions src/google/protobuf/util/field_mask_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "google/protobuf/field_mask.pb.h"
#include "net/proto2/contrib/parse_proto/parse_text_proto.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/base/log_severity.h"
#include "google/protobuf/test_util.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/util/field_mask_util.h"

namespace google {
namespace protobuf {
Expand Down Expand Up @@ -90,11 +95,13 @@ TEST_F(SnakeCaseCamelCaseTest, RoundTripTest) {
} while (std::next_permutation(name.begin(), name.end()));
}

using contrib::parse_proto::ParseTextProtoOrDie;
using google::protobuf::FieldMask;
using proto2_unittest::NestedTestAllTypes;
using proto2_unittest::TestAllTypes;
using proto2_unittest::TestRequired;
using proto2_unittest::TestRequiredMessage;
using testing::EqualsProto;

TEST(FieldMaskUtilTest, StringFormat) {
FieldMask mask;
Expand Down Expand Up @@ -198,8 +205,9 @@ TEST(FieldMaskUtilTest, TestIsValidFieldMask) {
TEST(FieldMaskUtilTest, TestGetFieldMaskForAllFields) {
FieldMask mask;
mask = FieldMaskUtil::GetFieldMaskForAllFields<TestAllTypes::NestedMessage>();
EXPECT_EQ(1, mask.paths_size());
EXPECT_EQ(2, mask.paths_size());
EXPECT_TRUE(FieldMaskUtil::IsPathInFieldMask("bb", mask));
EXPECT_TRUE(FieldMaskUtil::IsPathInFieldMask("cc", mask));

mask = FieldMaskUtil::GetFieldMaskForAllFields<TestAllTypes>();
EXPECT_EQ(80, mask.paths_size());
Expand Down Expand Up @@ -356,7 +364,8 @@ TEST(FieldMaskUtilTest, TestSubtract) {

FieldMaskUtil::Subtract<TestAllTypes>(mask1, mask2, &out);
EXPECT_EQ(
"optional_foreign_message.d,optional_uint64,repeated_foreign_message.c",
"optional_foreign_message.d,optional_nested_message.cc,optional_uint64,"
"repeated_foreign_message.c",
FieldMaskUtil::ToString(out));

// mask1 is empty.
Expand Down Expand Up @@ -581,16 +590,16 @@ TEST(FieldMaskUtilTest, TrimMessage) {
TEST_TRIM_ONE_PRIMITIVE_FIELD(optional_import_enum)
#undef TEST_TRIM_ONE_PRIMITIVE_FIELD

#define TEST_TRIM_ONE_FIELD(field_name) \
{ \
TestAllTypes msg; \
TestUtil::SetAllFields(&msg); \
TestAllTypes tmp; \
*tmp.mutable_##field_name() = msg.field_name(); \
FieldMask mask; \
mask.add_paths(#field_name); \
FieldMaskUtil::TrimMessage(mask, &msg); \
EXPECT_EQ(tmp.DebugString(), msg.DebugString()); \
#define TEST_TRIM_ONE_FIELD(field_name) \
{ \
TestAllTypes msg; \
TestUtil::SetAllFields(&msg); \
TestAllTypes tmp; \
*tmp.mutable_##field_name() = msg.field_name(); \
FieldMask mask; \
mask.add_paths(#field_name); \
EXPECT_TRUE(FieldMaskUtil::TrimMessage(mask, &msg)); \
EXPECT_EQ(tmp.DebugString(), msg.DebugString()); \
}
TEST_TRIM_ONE_FIELD(optional_nested_message)
TEST_TRIM_ONE_FIELD(optional_foreign_message)
Expand Down Expand Up @@ -629,25 +638,25 @@ TEST(FieldMaskUtilTest, TrimMessage) {
NestedTestAllTypes trimmed_msg(nested_msg);
FieldMask mask;
FieldMaskUtil::FromString("child.payload", &mask);
FieldMaskUtil::TrimMessage(mask, &trimmed_msg);
EXPECT_TRUE(FieldMaskUtil::TrimMessage(mask, &trimmed_msg));
EXPECT_EQ(1234, trimmed_msg.child().payload().optional_int32());
EXPECT_EQ(0, trimmed_msg.child().child().payload().optional_int32());

trimmed_msg = nested_msg;
FieldMaskUtil::FromString("child.child.payload", &mask);
FieldMaskUtil::TrimMessage(mask, &trimmed_msg);
EXPECT_TRUE(FieldMaskUtil::TrimMessage(mask, &trimmed_msg));
EXPECT_EQ(0, trimmed_msg.child().payload().optional_int32());
EXPECT_EQ(5678, trimmed_msg.child().child().payload().optional_int32());

trimmed_msg = nested_msg;
FieldMaskUtil::FromString("child", &mask);
FieldMaskUtil::TrimMessage(mask, &trimmed_msg);
EXPECT_FALSE(FieldMaskUtil::TrimMessage(mask, &trimmed_msg));
EXPECT_EQ(1234, trimmed_msg.child().payload().optional_int32());
EXPECT_EQ(5678, trimmed_msg.child().child().payload().optional_int32());

trimmed_msg = nested_msg;
FieldMaskUtil::FromString("child.child", &mask);
FieldMaskUtil::TrimMessage(mask, &trimmed_msg);
EXPECT_TRUE(FieldMaskUtil::TrimMessage(mask, &trimmed_msg));
EXPECT_EQ(0, trimmed_msg.child().payload().optional_int32());
EXPECT_EQ(5678, trimmed_msg.child().child().payload().optional_int32());

Expand All @@ -656,7 +665,7 @@ TEST(FieldMaskUtilTest, TrimMessage) {
TestUtil::SetAllFields(&all_types_msg);
TestAllTypes trimmed_all_types(all_types_msg);
FieldMask empty_mask;
FieldMaskUtil::TrimMessage(empty_mask, &trimmed_all_types);
EXPECT_FALSE(FieldMaskUtil::TrimMessage(empty_mask, &trimmed_all_types));
EXPECT_EQ(trimmed_all_types.DebugString(), all_types_msg.DebugString());

// Test trim required fields with keep_required_fields is set true.
Expand All @@ -668,15 +677,17 @@ TEST(FieldMaskUtilTest, TrimMessage) {
TestRequired trimmed_required_msg_1(required_msg_1);
FieldMaskUtil::FromString("dummy2", &mask);
options.set_keep_required_fields(true);
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_1, options);
EXPECT_FALSE(
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_1, options));
EXPECT_EQ(trimmed_required_msg_1.DebugString(), required_msg_1.DebugString());

// Test trim required fields with keep_required_fields is set false.
required_msg_1.clear_a();
required_msg_1.clear_b();
required_msg_1.clear_c();
options.set_keep_required_fields(false);
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_1, options);
EXPECT_TRUE(
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_1, options));
EXPECT_EQ(trimmed_required_msg_1.DebugString(), required_msg_1.DebugString());

// Test trim required message with keep_required_fields is set true.
Expand All @@ -697,14 +708,16 @@ TEST(FieldMaskUtilTest, TrimMessage) {
options.set_keep_required_fields(true);
required_msg_2.clear_repeated_message();
required_msg_2.mutable_required_message()->clear_dummy2();
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options);
EXPECT_TRUE(
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options));
EXPECT_EQ(trimmed_required_msg_2.DebugString(), required_msg_2.DebugString());

FieldMaskUtil::FromString("required_message", &mask);
required_msg_2.mutable_required_message()->set_dummy2(7890);
trimmed_required_msg_2.mutable_required_message()->set_dummy2(7890);
required_msg_2.clear_optional_message();
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options);
EXPECT_TRUE(
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options));
EXPECT_EQ(trimmed_required_msg_2.DebugString(), required_msg_2.DebugString());

// Test trim required message with keep_required_fields is set false.
Expand All @@ -713,15 +726,16 @@ TEST(FieldMaskUtilTest, TrimMessage) {
required_msg_2.mutable_required_message()->clear_b();
required_msg_2.mutable_required_message()->clear_c();
options.set_keep_required_fields(false);
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options);
EXPECT_TRUE(
FieldMaskUtil::TrimMessage(mask, &trimmed_required_msg_2, options));
EXPECT_EQ(trimmed_required_msg_2.DebugString(), required_msg_2.DebugString());

// Verify that trimming an empty message has no effect. In particular, fields
// mentioned in the field mask should not be created or changed.
TestAllTypes empty_msg;
FieldMaskUtil::FromString(
"optional_int32,optional_bytes,optional_nested_message.bb", &mask);
FieldMaskUtil::TrimMessage(mask, &empty_msg);
EXPECT_FALSE(FieldMaskUtil::TrimMessage(mask, &empty_msg));
EXPECT_FALSE(empty_msg.has_optional_int32());
EXPECT_FALSE(empty_msg.has_optional_bytes());
EXPECT_FALSE(empty_msg.has_optional_nested_message());
Expand All @@ -731,10 +745,37 @@ TEST(FieldMaskUtilTest, TrimMessage) {
TestAllTypes oneof_msg;
oneof_msg.set_oneof_uint32(11);
FieldMaskUtil::FromString("oneof_uint32,oneof_nested_message.bb", &mask);
FieldMaskUtil::TrimMessage(mask, &oneof_msg);
EXPECT_FALSE(FieldMaskUtil::TrimMessage(mask, &oneof_msg));
EXPECT_EQ(11, oneof_msg.oneof_uint32());
}

TEST(FieldMaskUtilTest, TrimMessageRepeatedField) {
FieldMask bb_mask;
FieldMaskUtil::FromString("repeated_nested_message.bb", &bb_mask);
{
TestAllTypes repeated_nested_msg;
TestUtil::SetAllFields(&repeated_nested_msg);
repeated_nested_msg.add_repeated_nested_message()->set_bb(1234);
TestAllTypes trimmed_repeated_nested_msg;
*trimmed_repeated_nested_msg.mutable_repeated_nested_message() =
repeated_nested_msg.repeated_nested_message();
EXPECT_TRUE(FieldMaskUtil::TrimMessage(bb_mask, &repeated_nested_msg));
EXPECT_THAT(repeated_nested_msg, EqualsProto(trimmed_repeated_nested_msg));
}
{
// Repeated field has multiple elements
TestAllTypes msg = ParseTextProtoOrDie(R"pb(
repeated_nested_message { bb: 1234 }
repeated_nested_message { bb: 5678 cc: 9012 }
)pb");
EXPECT_TRUE(FieldMaskUtil::TrimMessage(bb_mask, &msg));
EXPECT_THAT(msg, EqualsProto(R"pb(
repeated_nested_message { bb: 1234 }
repeated_nested_message { bb: 5678 }
)pb"));
}
}

TEST(FieldMaskUtilTest, TrimMessageReturnValue) {
FieldMask mask;
TestAllTypes trimmed_msg;
Expand Down
Loading