diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 6d44da33..3b21e1f9 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -484,7 +484,7 @@ class ProtobufDomainUntypedImpl } if (oneof_to_field[oneof->index()] != field->index()) continue; } else if (IsCustomizedRecursivelyOnly()) { - if (!MustBeSet(field) && IsRecursionBreaker(field)) { + if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization // may never terminate. If a proto has only non-required recursive @@ -966,7 +966,7 @@ class ProtobufDomainUntypedImpl OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; if (IsCustomizedRecursivelyOnly()) { - if (IsRecursionBreaker(oneof->field(i))) continue; + if (IsFieldFinitelyRecursive(oneof->field(i))) continue; if (MustBeUnset(oneof->field(i))) continue; } fields.push_back(i); @@ -1704,12 +1704,21 @@ class ProtobufDomainUntypedImpl return GetDomainForField(field, /*use_policy=*/false); } + // Analysis type for protobuf recursions. + enum class RecursionType { + // The proto contains a proto of type P, that must contain another P. + kInfinitelyRecursive, + // The proto contains a proto of type P, that can contain another P. + kFinitelyRecursive, + }; + // Returns true if there are subprotos in the `descriptor` that form an // infinite recursion. bool IsInfinitelyRecursive(const Descriptor* descriptor) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); absl::flat_hash_set parents; - return IsProtoRecursive(/*field=*/nullptr, parents, descriptor); + return IsProtoRecursive(/*field=*/nullptr, parents, + RecursionType::kInfinitelyRecursive, descriptor); } // Returns true if there are subfields in the `field` that form an @@ -1718,6 +1727,14 @@ class ProtobufDomainUntypedImpl // customized using `WithFieldsAlwaysSet`). bool IsInfinitelyRecursive(const FieldDescriptor* field) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + absl::flat_hash_set parents; + return IsProtoRecursive(field, parents, + RecursionType::kInfinitelyRecursive); + } + + bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + if (!field->message_type()) return false; ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); static absl::NoDestructor< absl::flat_hash_map, bool>> @@ -1728,7 +1745,8 @@ class ProtobufDomainUntypedImpl if (it != cache->end()) return it->second; } absl::flat_hash_set parents; - bool result = IsProtoRecursive(field, parents); + bool result = + IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive); absl::MutexLock l(&mutex); cache->insert({{policy_.id(), field}, result}); return result; @@ -1742,18 +1760,27 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive( - const OneofDescriptor* oneof, - absl::flat_hash_set& parents) const { + bool IsOneofRecursive(const OneofDescriptor* oneof, + absl::flat_hash_set& parents, + RecursionType recursion_type) const { bool is_oneof_recursive = false; for (int i = 0; i < oneof->field_count(); ++i) { const auto* field = oneof->field(i); const auto field_policy = policy_.GetOptionalPolicy(field); if (field_policy == OptionalPolicy::kAlwaysNull) continue; - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - field->message_type() && - IsProtoRecursive(field, parents); - if (!is_oneof_recursive) return false; + if (recursion_type == RecursionType::kInfinitelyRecursive) { + is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && + field->message_type() && + IsProtoRecursive(field, parents, recursion_type); + if (!is_oneof_recursive) { + return false; + } + } else { + if (field->message_type() && + IsProtoRecursive(field, parents, recursion_type)) { + return true; + } + } } return is_oneof_recursive; } @@ -1802,6 +1829,7 @@ class ProtobufDomainUntypedImpl // If `field` is nullptr, all fields of `descriptor` are checked. bool IsProtoRecursive(const FieldDescriptor* field, absl::flat_hash_set& parents, + RecursionType recursion_type, const Descriptor* descriptor = nullptr) const { if (field != nullptr) { if (parents.contains(field)) return true; @@ -1813,7 +1841,7 @@ class ProtobufDomainUntypedImpl } for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents)) { + if (IsOneofRecursive(oneof, parents, recursion_type)) { if (field != nullptr) parents.erase(field); return true; } @@ -1829,8 +1857,12 @@ class ProtobufDomainUntypedImpl default_domain->Init(prng); continue; } - if (!MustBeSet(subfield)) continue; - if (IsProtoRecursive(subfield, parents)) { + if (recursion_type == RecursionType::kInfinitelyRecursive) { + if (!MustBeSet(subfield)) continue; + } else { + if (MustBeUnset(subfield)) continue; + } + if (IsProtoRecursive(subfield, parents, recursion_type)) { if (field != nullptr) parents.erase(field); return true; } @@ -1839,68 +1871,6 @@ class ProtobufDomainUntypedImpl return false; } - // A subset of proto types are considered as recursion breakers and won't - // get recursively initialized during domain initialization to avoid - // non-terminating initialization. - // - // Returns true if the `field` (F0) does not have to be set, and there are - // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0 - // and none of other Fi-s are marked as recursion breakers so far. In other - // words, this method computes recursion breakers and check membership of - // `field` in the set of recursion breakers. - bool IsRecursionBreaker(const FieldDescriptor* field) { - FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); - if (!field->message_type()) return false; - absl::flat_hash_set parents; - return IsRecursionBreaker(/*root=*/field, field, parents); - } - - bool IsRecursionBreaker( - const FieldDescriptor* root, const FieldDescriptor* field, - absl::flat_hash_set& parents) const { - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - static absl::NoDestructor< - absl::flat_hash_map, bool>> - cache ABSL_GUARDED_BY(mutex); - { - absl::MutexLock l(&mutex); - auto it = cache->find({policy_.id(), field}); - if (it != cache->end()) return it->second; - } - // Cannot break the recursion for required fields. - bool can_be_unset = !MustBeSet(field); - if (field->containing_oneof() && !can_be_unset) { // oneof must be set - // We check whether `field` is infinitely recursive without considering - // other oneof fields. If it is, there's another field in the oneof that - // can be set. - absl::flat_hash_set subfield_parents; - subfield_parents.insert(field); - can_be_unset = IsProtoRecursive(field, subfield_parents); - } - if (can_be_unset) { - // Break recursion for deeply nested or recursive protos. - if (parents.size() > 20 || parents.contains(field)) { - absl::MutexLock l(&mutex); - cache->insert({{policy_.id(), field}, true}); - return true; - } - parents.insert(field); - } - for (const FieldDescriptor* subfield : - GetProtobufFields(field->message_type())) { - if (!subfield->message_type()) continue; - if (MustBeUnset(subfield)) continue; - IsRecursionBreaker(root, subfield, parents); - } - if (can_be_unset) parents.erase(field); - absl::MutexLock l(&mutex); - // If the result is computed while visiting the children, we shouldn't - // overwrite. For example, if we visit A -> B -> C -> A, we can return the - // result of the nested A for upper-level A. - auto [it, inserted] = cache->insert({{policy_.id(), field}, false}); - return it->second; - } - bool IsRequired(const FieldDescriptor* field) const { return field->is_required() || IsMapValueMessage(field); }