Skip to content

Commit

Permalink
Automated rollback of commit 313a522.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731092850
  • Loading branch information
FuzzTest Team authored and copybara-github committed Feb 26, 2025
1 parent 313a522 commit d2018ac
Showing 1 changed file with 46 additions and 76 deletions.
122 changes: 46 additions & 76 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class ProtobufDomainUntypedImpl
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (IsCustomizedRecursivelyOnly()) {
if (!MustBeSet(field) && IsRecursionBreaker(field)) {
if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
Expand Down Expand Up @@ -966,7 +966,7 @@ class ProtobufDomainUntypedImpl
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly()) {
if (IsRecursionBreaker(oneof->field(i))) continue;
if (IsFieldFinitelyRecursive(oneof->field(i))) continue;
if (MustBeUnset(oneof->field(i))) continue;
}
fields.push_back(i);
Expand Down Expand Up @@ -1704,12 +1704,21 @@ class ProtobufDomainUntypedImpl
return GetDomainForField<T, is_repeated>(field, /*use_policy=*/false);
}

// Analysis type for protobuf recursions.
enum class RecursionType {
// The proto contains a proto of type P, that must contain another P.
kInfinitelyRecursive,
// The proto contains a proto of type P, that can contain another P.
kFinitelyRecursive,
};

// Returns true if there are subprotos in the `descriptor` that form an
// infinite recursion.
bool IsInfinitelyRecursive(const Descriptor* descriptor) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(/*field=*/nullptr, parents, descriptor);
return IsProtoRecursive(/*field=*/nullptr, parents,
RecursionType::kInfinitelyRecursive, descriptor);
}

// Returns true if there are subfields in the `field` that form an
Expand All @@ -1718,6 +1727,14 @@ class ProtobufDomainUntypedImpl
// customized using `WithFieldsAlwaysSet`).
bool IsInfinitelyRecursive(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(field, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
Expand All @@ -1728,7 +1745,8 @@ class ProtobufDomainUntypedImpl
if (it != cache->end()) return it->second;
}
absl::flat_hash_set<const FieldDescriptor*> parents;
bool result = IsProtoRecursive(field, parents);
bool result =
IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
return result;
Expand All @@ -1742,18 +1760,27 @@ class ProtobufDomainUntypedImpl
return index == kFieldCountIndex;
}

bool IsOneofRecursive(
const OneofDescriptor* oneof,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
field->message_type() &&
IsProtoRecursive(field, parents);
if (!is_oneof_recursive) return false;
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
field->message_type() &&
IsProtoRecursive(field, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (field->message_type() &&
IsProtoRecursive(field, parents, recursion_type)) {
return true;
}
}
}
return is_oneof_recursive;
}
Expand Down Expand Up @@ -1802,6 +1829,7 @@ class ProtobufDomainUntypedImpl
// If `field` is nullptr, all fields of `descriptor` are checked.
bool IsProtoRecursive(const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type,
const Descriptor* descriptor = nullptr) const {
if (field != nullptr) {
if (parents.contains(field)) return true;
Expand All @@ -1813,7 +1841,7 @@ class ProtobufDomainUntypedImpl
}
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents)) {
if (IsOneofRecursive(oneof, parents, recursion_type)) {
if (field != nullptr) parents.erase(field);
return true;
}
Expand All @@ -1829,8 +1857,12 @@ class ProtobufDomainUntypedImpl
default_domain->Init(prng);
continue;
}
if (!MustBeSet(subfield)) continue;
if (IsProtoRecursive(subfield, parents)) {
if (recursion_type == RecursionType::kInfinitelyRecursive) {
if (!MustBeSet(subfield)) continue;
} else {
if (MustBeUnset(subfield)) continue;
}
if (IsProtoRecursive(subfield, parents, recursion_type)) {
if (field != nullptr) parents.erase(field);
return true;
}
Expand All @@ -1839,68 +1871,6 @@ class ProtobufDomainUntypedImpl
return false;
}

// A subset of proto types are considered as recursion breakers and won't
// get recursively initialized during domain initialization to avoid
// non-terminating initialization.
//
// Returns true if the `field` (F0) does not have to be set, and there are
// subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0
// and none of other Fi-s are marked as recursion breakers so far. In other
// words, this method computes recursion breakers and check membership of
// `field` in the set of recursion breakers.
bool IsRecursionBreaker(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsRecursionBreaker(/*root=*/field, field, parents);
}

bool IsRecursionBreaker(
const FieldDescriptor* root, const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
{
absl::MutexLock l(&mutex);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
// Cannot break the recursion for required fields.
bool can_be_unset = !MustBeSet(field);
if (field->containing_oneof() && !can_be_unset) { // oneof must be set
// We check whether `field` is infinitely recursive without considering
// other oneof fields. If it is, there's another field in the oneof that
// can be set.
absl::flat_hash_set<const FieldDescriptor*> subfield_parents;
subfield_parents.insert(field);
can_be_unset = IsProtoRecursive(field, subfield_parents);
}
if (can_be_unset) {
// Break recursion for deeply nested or recursive protos.
if (parents.size() > 20 || parents.contains(field)) {
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, true});
return true;
}
parents.insert(field);
}
for (const FieldDescriptor* subfield :
GetProtobufFields(field->message_type())) {
if (!subfield->message_type()) continue;
if (MustBeUnset(subfield)) continue;
IsRecursionBreaker(root, subfield, parents);
}
if (can_be_unset) parents.erase(field);
absl::MutexLock l(&mutex);
// If the result is computed while visiting the children, we shouldn't
// overwrite. For example, if we visit A -> B -> C -> A, we can return the
// result of the nested A for upper-level A.
auto [it, inserted] = cache->insert({{policy_.id(), field}, false});
return it->second;
}

bool IsRequired(const FieldDescriptor* field) const {
return field->is_required() || IsMapValueMessage(field);
}
Expand Down

0 comments on commit d2018ac

Please sign in to comment.