diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 3b21e1f9..6d44da33 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -484,7 +484,7 @@ class ProtobufDomainUntypedImpl } if (oneof_to_field[oneof->index()] != field->index()) continue; } else if (IsCustomizedRecursivelyOnly()) { - if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) { + if (!MustBeSet(field) && IsRecursionBreaker(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization // may never terminate. If a proto has only non-required recursive @@ -966,7 +966,7 @@ class ProtobufDomainUntypedImpl OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; if (IsCustomizedRecursivelyOnly()) { - if (IsFieldFinitelyRecursive(oneof->field(i))) continue; + if (IsRecursionBreaker(oneof->field(i))) continue; if (MustBeUnset(oneof->field(i))) continue; } fields.push_back(i); @@ -1704,21 +1704,12 @@ class ProtobufDomainUntypedImpl return GetDomainForField(field, /*use_policy=*/false); } - // Analysis type for protobuf recursions. - enum class RecursionType { - // The proto contains a proto of type P, that must contain another P. - kInfinitelyRecursive, - // The proto contains a proto of type P, that can contain another P. - kFinitelyRecursive, - }; - // Returns true if there are subprotos in the `descriptor` that form an // infinite recursion. bool IsInfinitelyRecursive(const Descriptor* descriptor) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); absl::flat_hash_set parents; - return IsProtoRecursive(/*field=*/nullptr, parents, - RecursionType::kInfinitelyRecursive, descriptor); + return IsProtoRecursive(/*field=*/nullptr, parents, descriptor); } // Returns true if there are subfields in the `field` that form an @@ -1727,14 +1718,6 @@ class ProtobufDomainUntypedImpl // customized using `WithFieldsAlwaysSet`). bool IsInfinitelyRecursive(const FieldDescriptor* field) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); - absl::flat_hash_set parents; - return IsProtoRecursive(field, parents, - RecursionType::kInfinitelyRecursive); - } - - bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { - FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); - if (!field->message_type()) return false; ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); static absl::NoDestructor< absl::flat_hash_map, bool>> @@ -1745,8 +1728,7 @@ class ProtobufDomainUntypedImpl if (it != cache->end()) return it->second; } absl::flat_hash_set parents; - bool result = - IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive); + bool result = IsProtoRecursive(field, parents); absl::MutexLock l(&mutex); cache->insert({{policy_.id(), field}, result}); return result; @@ -1760,27 +1742,18 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive(const OneofDescriptor* oneof, - absl::flat_hash_set& parents, - RecursionType recursion_type) const { + bool IsOneofRecursive( + const OneofDescriptor* oneof, + absl::flat_hash_set& parents) const { bool is_oneof_recursive = false; for (int i = 0; i < oneof->field_count(); ++i) { const auto* field = oneof->field(i); const auto field_policy = policy_.GetOptionalPolicy(field); if (field_policy == OptionalPolicy::kAlwaysNull) continue; - if (recursion_type == RecursionType::kInfinitelyRecursive) { - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - field->message_type() && - IsProtoRecursive(field, parents, recursion_type); - if (!is_oneof_recursive) { - return false; - } - } else { - if (field->message_type() && - IsProtoRecursive(field, parents, recursion_type)) { - return true; - } - } + is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && + field->message_type() && + IsProtoRecursive(field, parents); + if (!is_oneof_recursive) return false; } return is_oneof_recursive; } @@ -1829,7 +1802,6 @@ class ProtobufDomainUntypedImpl // If `field` is nullptr, all fields of `descriptor` are checked. bool IsProtoRecursive(const FieldDescriptor* field, absl::flat_hash_set& parents, - RecursionType recursion_type, const Descriptor* descriptor = nullptr) const { if (field != nullptr) { if (parents.contains(field)) return true; @@ -1841,7 +1813,7 @@ class ProtobufDomainUntypedImpl } for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, recursion_type)) { + if (IsOneofRecursive(oneof, parents)) { if (field != nullptr) parents.erase(field); return true; } @@ -1857,12 +1829,8 @@ class ProtobufDomainUntypedImpl default_domain->Init(prng); continue; } - if (recursion_type == RecursionType::kInfinitelyRecursive) { - if (!MustBeSet(subfield)) continue; - } else { - if (MustBeUnset(subfield)) continue; - } - if (IsProtoRecursive(subfield, parents, recursion_type)) { + if (!MustBeSet(subfield)) continue; + if (IsProtoRecursive(subfield, parents)) { if (field != nullptr) parents.erase(field); return true; } @@ -1871,6 +1839,68 @@ class ProtobufDomainUntypedImpl return false; } + // A subset of proto types are considered as recursion breakers and won't + // get recursively initialized during domain initialization to avoid + // non-terminating initialization. + // + // Returns true if the `field` (F0) does not have to be set, and there are + // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0 + // and none of other Fi-s are marked as recursion breakers so far. In other + // words, this method computes recursion breakers and check membership of + // `field` in the set of recursion breakers. + bool IsRecursionBreaker(const FieldDescriptor* field) { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + if (!field->message_type()) return false; + absl::flat_hash_set parents; + return IsRecursionBreaker(/*root=*/field, field, parents); + } + + bool IsRecursionBreaker( + const FieldDescriptor* root, const FieldDescriptor* field, + absl::flat_hash_set& parents) const { + ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); + static absl::NoDestructor< + absl::flat_hash_map, bool>> + cache ABSL_GUARDED_BY(mutex); + { + absl::MutexLock l(&mutex); + auto it = cache->find({policy_.id(), field}); + if (it != cache->end()) return it->second; + } + // Cannot break the recursion for required fields. + bool can_be_unset = !MustBeSet(field); + if (field->containing_oneof() && !can_be_unset) { // oneof must be set + // We check whether `field` is infinitely recursive without considering + // other oneof fields. If it is, there's another field in the oneof that + // can be set. + absl::flat_hash_set subfield_parents; + subfield_parents.insert(field); + can_be_unset = IsProtoRecursive(field, subfield_parents); + } + if (can_be_unset) { + // Break recursion for deeply nested or recursive protos. + if (parents.size() > 20 || parents.contains(field)) { + absl::MutexLock l(&mutex); + cache->insert({{policy_.id(), field}, true}); + return true; + } + parents.insert(field); + } + for (const FieldDescriptor* subfield : + GetProtobufFields(field->message_type())) { + if (!subfield->message_type()) continue; + if (MustBeUnset(subfield)) continue; + IsRecursionBreaker(root, subfield, parents); + } + if (can_be_unset) parents.erase(field); + absl::MutexLock l(&mutex); + // If the result is computed while visiting the children, we shouldn't + // overwrite. For example, if we visit A -> B -> C -> A, we can return the + // result of the nested A for upper-level A. + auto [it, inserted] = cache->insert({{policy_.id(), field}, false}); + return it->second; + } + bool IsRequired(const FieldDescriptor* field) const { return field->is_required() || IsMapValueMessage(field); }