Skip to content

Commit

Permalink
Avoid initializing infinitely recursive sub-fields.
Browse files Browse the repository at this point in the history
Given that sub-fields may not get initialized in the smoke-test, it could be unexpected for users to encounter this issue during fuzzing. Instead of producing a failure, we produce a warning and avoid initializing them.

In addition, the recursion loops are detected over fields and not protos because the customizations are done over fields and two fields of same type could have different customizations, which can affect the recursion analysis.

PiperOrigin-RevId: 726556282
  • Loading branch information
hadi88 authored and copybara-github committed Feb 13, 2025
1 parent eb4c69a commit b35469f
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 49 deletions.
15 changes: 15 additions & 0 deletions domain_tests/arbitrary_domains_protobuf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ TEST(ProtocolBufferWithRequiredFields, ShrinkingNeverRemovesRequiredFields) {
}
}

TEST(ProtocolBufferWithRecursiveFields, InfiniteleyRecursiveFieldsAreNotSet) {
auto domain = Arbitrary<internal::TestProtobufWithRepeatedRecursionSubproto>()
.WithRepeatedFieldsAlwaysSet();
absl::BitGen bitgen;
Value val(domain, bitgen);

ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;

for (int i = 0; i < 1000; ++i) {
val.Mutate(domain, bitgen, {}, false);
ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;
ASSERT_FALSE(val.user_value.has_list()) << val.user_value;
}
}

TEST(ProtocolBuffer, CanUsePerFieldDomains) {
Domain<TestProtobuf> domain =
Arbitrary<TestProtobuf>()
Expand Down
1 change: 1 addition & 0 deletions fuzztest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
Expand Down
1 change: 1 addition & 0 deletions fuzztest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ fuzztest_cc_library(
absl::status
absl::statusor
absl::strings
absl::str_format
absl::synchronization
absl::span
)
Expand Down
135 changes: 86 additions & 49 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -466,10 +467,11 @@ class ProtobufDomainUntypedImpl

corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(),
"Cannot set recursive fields by default.");
const auto* descriptor = prototype_.Get()->GetDescriptor();
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(descriptor),
absl::StrCat("Cannot set recursive fields for ",
descriptor->full_name(), " by default."));
corpus_type val;
absl::flat_hash_map<int, int> oneof_to_field;

Expand All @@ -481,15 +483,18 @@ class ProtobufDomainUntypedImpl
SelectAFieldIndexInOneof(oneof, prng);
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates the
// assumption on domain Init. However, such cases should be extremely
// rare and breaking the assumption would not have severe consequences.
continue;
} else if (IsCustomizedRecursivelyOnly()) {
if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates
// the assumption on domain Init. However, such cases should be
// extremely rare and breaking the assumption would not have severe
// consequences.
continue;
}
if (MustBeUnset(field)) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
}
Expand Down Expand Up @@ -609,6 +614,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++total_weight;

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
Expand Down Expand Up @@ -646,6 +652,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++field_counter;
if (field_counter == selected_field_index) {
VisitProtobufField(
Expand Down Expand Up @@ -958,9 +965,9 @@ class ProtobufDomainUntypedImpl
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(oneof->field(i))) {
continue;
if (IsCustomizedRecursivelyOnly()) {
if (IsFieldFinitelyRecursive(oneof->field(i))) continue;
if (MustBeUnset(oneof->field(i))) continue;
}
fields.push_back(i);
}
Expand Down Expand Up @@ -1705,31 +1712,43 @@ class ProtobufDomainUntypedImpl
kFinitelyRecursive,
};

bool IsInfinitelyRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
// Returns true if there are subprotos in the `descriptor` that form an
// infinite recursion.
bool IsInfinitelyRecursive(const Descriptor* descriptor) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(/*field=*/nullptr, parents,
RecursionType::kInfinitelyRecursive, descriptor);
}

// Returns true if there are subfields in the `field` that form an
// infinite recursion of the form: F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs,
// because all Fi-s have to be set (e.g., Fi is a required field, or is
// customized using `WithFieldsAlwaysSet`).
bool IsInfinitelyRecursive(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(field, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
bool can_use_cache = IsCustomizedRecursivelyOnly();
if (can_use_cache) {
{
absl::MutexLock l(&mutex);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
absl::flat_hash_set<decltype(field->message_type())> parents;
bool result = IsProtoRecursive(field->message_type(), parents,
RecursionType::kFinitelyRecursive);
if (can_use_cache) {
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
}
absl::flat_hash_set<const FieldDescriptor*> parents;
bool result =
IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
return result;
}

Expand All @@ -1742,23 +1761,23 @@ class ProtobufDomainUntypedImpl
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
child &&
IsProtoRecursive(child, parents, recursion_type);
field->message_type() &&
IsProtoRecursive(field, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, recursion_type)) {
if (field->message_type() &&
IsProtoRecursive(field, parents, recursion_type)) {
return true;
}
}
Expand All @@ -1767,6 +1786,7 @@ class ProtobufDomainUntypedImpl
}

bool MustBeSet(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (IsRequired(field)) {
return true;
} else if (field->containing_oneof()) {
Expand All @@ -1783,6 +1803,14 @@ class ProtobufDomainUntypedImpl
}

bool MustBeUnset(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (field->message_type() && IsInfinitelyRecursive(field)) {
absl::FPrintF(
GetStderr(),
"[!] Infinite recursion detected for %s and it remains unset.\n",
field->full_name());
return true;
}
if (IsRequired(field)) {
return false;
} else if (field->containing_oneof()) {
Expand All @@ -1798,39 +1826,48 @@ class ProtobufDomainUntypedImpl
return false;
}

template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
if (parents.contains(descriptor)) return true;
parents.insert(descriptor);
// If `field` is nullptr, all fields of `descriptor` are checked.
bool IsProtoRecursive(const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type,
const Descriptor* descriptor = nullptr) const {
if (field != nullptr) {
if (parents.contains(field)) return true;
parents.insert(field);
descriptor = field->message_type();
} else {
FUZZTEST_INTERNAL_CHECK(descriptor,
"one of field or descriptor must be non-null!");
}
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
parents.erase(descriptor);
if (field != nullptr) parents.erase(field);
return true;
}
}
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (field->containing_oneof()) continue;
const auto* child = field->message_type();
if (!child) continue;
if (policy_.GetDefaultDomainForProtobufs(field) != std::nullopt) {
for (const FieldDescriptor* subfield : GetProtobufFields(descriptor)) {
if (subfield->containing_oneof()) continue;
if (!subfield->message_type()) continue;
if (auto default_domain = policy_.GetDefaultDomainForProtobufs(subfield);
default_domain != std::nullopt) { // For handling WithProtobufFields.
// If this field is recursive, it will be detected when initializing
// its default domain. Otherwise, this field can always be set safely.
absl::BitGen prng;
default_domain->Init(prng);
continue;
}
if (recursion_type == RecursionType::kInfinitelyRecursive) {
if (!MustBeSet(field)) continue;
if (!MustBeSet(subfield)) continue;
} else {
if (MustBeUnset(field)) continue;
if (MustBeUnset(subfield)) continue;
}
if (IsProtoRecursive(child, parents, recursion_type)) {
parents.erase(descriptor);
if (IsProtoRecursive(subfield, parents, recursion_type)) {
if (field != nullptr) parents.erase(field);
return true;
}
}
parents.erase(descriptor);
if (field != nullptr) parents.erase(field);
return false;
}

Expand Down
8 changes: 8 additions & 0 deletions fuzztest/internal/test_protobuf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ message TestProtobufWithRecursion {
optional TestProtobufWithExtension ext = 4;
}

message TestProtobufWithRepeatedRecursion {
repeated TestProtobufWithRepeatedRecursion items = 1;
}

message TestProtobufWithRepeatedRecursionSubproto {
optional TestProtobufWithRepeatedRecursion list = 1;
}

message MessageWithGroup {
optional group GroupField = 1 {
optional int64 field1 = 2;
Expand Down

0 comments on commit b35469f

Please sign in to comment.