Skip to content

Commit

Permalink
Initialize recursive protobuf fields more efficiently.
Browse files Browse the repository at this point in the history
Currently, a field (F0) is not initialized if it has a chain of subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that a huge protobuf field with a recursive subfield deep in its definition is not initialized at all. Even when F0 is initialized later, F1 is not initialized, and so on. This can be very inefficient. To avoid this, we define "recursion breaker fields". In the example above, Fs becomes a recursion breaker: all fields up to Fs are initialized, and later, when Fs gets initialized, all of Fs -> ... -> Fn get initialized.

This CL consists of the following changes:
- IsProtoRecursive deals with infinite recursions only.
- IsFinitelyRecursive is replaced with IsRecursionBreaker which is implemented separately.

PiperOrigin-RevId: 722802824
  • Loading branch information
hadi88 authored and copybara-github committed Feb 25, 2025
1 parent c7651e4 commit b6c2f2c
Showing 1 changed file with 76 additions and 46 deletions.
122 changes: 76 additions & 46 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class ProtobufDomainUntypedImpl
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (IsCustomizedRecursivelyOnly()) {
if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) {
if (!MustBeSet(field) && IsRecursionBreaker(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
Expand Down Expand Up @@ -966,7 +966,7 @@ class ProtobufDomainUntypedImpl
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly()) {
if (IsFieldFinitelyRecursive(oneof->field(i))) continue;
if (IsRecursionBreaker(oneof->field(i))) continue;
if (MustBeUnset(oneof->field(i))) continue;
}
fields.push_back(i);
Expand Down Expand Up @@ -1704,21 +1704,12 @@ class ProtobufDomainUntypedImpl
return GetDomainForField<T, is_repeated>(field, /*use_policy=*/false);
}

// Analysis type for protobuf recursions.
enum class RecursionType {
// The proto contains a proto of type P, that must contain another P.
kInfinitelyRecursive,
// The proto contains a proto of type P, that can contain another P.
kFinitelyRecursive,
};

// Returns true if there are subprotos in the `descriptor` that form an
// infinite recursion.
bool IsInfinitelyRecursive(const Descriptor* descriptor) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(/*field=*/nullptr, parents,
RecursionType::kInfinitelyRecursive, descriptor);
return IsProtoRecursive(/*field=*/nullptr, parents, descriptor);
}

// Returns true if there are subfields in the `field` that form an
Expand All @@ -1727,14 +1718,6 @@ class ProtobufDomainUntypedImpl
// customized using `WithFieldsAlwaysSet`).
bool IsInfinitelyRecursive(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(field, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
Expand All @@ -1745,8 +1728,7 @@ class ProtobufDomainUntypedImpl
if (it != cache->end()) return it->second;
}
absl::flat_hash_set<const FieldDescriptor*> parents;
bool result =
IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive);
bool result = IsProtoRecursive(field, parents);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
return result;
Expand All @@ -1760,27 +1742,18 @@ class ProtobufDomainUntypedImpl
return index == kFieldCountIndex;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type) const {
bool IsOneofRecursive(
const OneofDescriptor* oneof,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
field->message_type() &&
IsProtoRecursive(field, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (field->message_type() &&
IsProtoRecursive(field, parents, recursion_type)) {
return true;
}
}
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
field->message_type() &&
IsProtoRecursive(field, parents);
if (!is_oneof_recursive) return false;
}
return is_oneof_recursive;
}
Expand Down Expand Up @@ -1829,7 +1802,6 @@ class ProtobufDomainUntypedImpl
// If `field` is nullptr, all fields of `descriptor` are checked.
bool IsProtoRecursive(const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type,
const Descriptor* descriptor = nullptr) const {
if (field != nullptr) {
if (parents.contains(field)) return true;
Expand All @@ -1841,7 +1813,7 @@ class ProtobufDomainUntypedImpl
}
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
if (IsOneofRecursive(oneof, parents)) {
if (field != nullptr) parents.erase(field);
return true;
}
Expand All @@ -1857,12 +1829,8 @@ class ProtobufDomainUntypedImpl
default_domain->Init(prng);
continue;
}
if (recursion_type == RecursionType::kInfinitelyRecursive) {
if (!MustBeSet(subfield)) continue;
} else {
if (MustBeUnset(subfield)) continue;
}
if (IsProtoRecursive(subfield, parents, recursion_type)) {
if (!MustBeSet(subfield)) continue;
if (IsProtoRecursive(subfield, parents)) {
if (field != nullptr) parents.erase(field);
return true;
}
Expand All @@ -1871,6 +1839,68 @@ class ProtobufDomainUntypedImpl
return false;
}

// A subset of proto types are considered as recursion breakers and won't
// get recursively initialized during domain initialization to avoid
// non-terminating initialization.
//
// Returns true if the `field` (F0) does not have to be set, and there are
// subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0
// and none of other Fi-s are marked as recursion breakers so far. In other
// words, this method computes recursion breakers and check membership of
// `field` in the set of recursion breakers.
bool IsRecursionBreaker(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsRecursionBreaker(/*root=*/field, field, parents);
}

// Recursive worker that decides whether `field` is a recursion breaker and
// records the decision in a cache keyed by (`policy_.id()`, field).
//
// Parameters:
//   root:    forwarded unchanged to every recursive call; this function does
//            not otherwise read it.
//   field:   the field being classified. Its `message_type()` is
//            dereferenced, so it must be a message-typed field.
//   parents: the fields on the current traversal path that can be left unset.
//            `field` is added before descending and removed again afterwards.
//
// Returns the cached decision for `field`. If no decision was recorded while
// visiting its subfields, `false` is recorded and returned.
bool IsRecursionBreaker(
const FieldDescriptor* root, const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents) const {
// Function-local static cache and the mutex guarding it. Every access below
// takes the lock, but the lock is not held across the recursive traversal.
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
{
// Fast path: reuse a decision that was already recorded for this field.
absl::MutexLock l(&mutex);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
// Cannot break the recursion for required fields.
bool can_be_unset = !MustBeSet(field);
if (field->containing_oneof() && !can_be_unset) { // oneof must be set
// We check whether `field` is infinitely recursive without considering
// other oneof fields. If it is, there's another field in the oneof that
// can be set.
// NOTE(review): `subfield_parents` is seeded with `field` itself, and the
// visible beginning of IsProtoRecursive returns true as soon as `parents`
// contains the field it is given. Confirm that this call does not
// trivially evaluate to true for every such oneof field.
absl::flat_hash_set<const FieldDescriptor*> subfield_parents;
subfield_parents.insert(field);
can_be_unset = IsProtoRecursive(field, subfield_parents);
}
if (can_be_unset) {
// Break recursion for deeply nested or recursive protos.
// "Deeply nested" means more than 20 fields are already on the path;
// "recursive" means `field` is already on the path.
if (parents.size() > 20 || parents.contains(field)) {
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, true});
return true;
}
// Only fields that can be unset are tracked on the path, because only
// those can become recursion breakers.
parents.insert(field);
}
for (const FieldDescriptor* subfield :
GetProtobufFields(field->message_type())) {
// Only message-typed subfields that may be set can continue a chain.
if (!subfield->message_type()) continue;
if (MustBeUnset(subfield)) continue;
// The return value is intentionally ignored: the call is made for its
// side effect of recording decisions in the cache.
IsRecursionBreaker(root, subfield, parents);
}
// Undo the path entry added above (it was added only when can_be_unset).
if (can_be_unset) parents.erase(field);
absl::MutexLock l(&mutex);
// If the result is computed while visiting the children, we shouldn't
// overwrite. For example, if we visit A -> B -> C -> A, we can return the
// result of the nested A for upper-level A.
// `insert` leaves an existing entry untouched, so `it` refers to either the
// earlier decision or the freshly inserted `false`. `inserted` is not used.
auto [it, inserted] = cache->insert({{policy_.id(), field}, false});
return it->second;
}

// Returns true if `field` is declared as required, or if
// `IsMapValueMessage(field)` holds for it.
bool IsRequired(const FieldDescriptor* field) const {
  if (field->is_required()) {
    return true;
  }
  return IsMapValueMessage(field);
}
Expand Down

0 comments on commit b6c2f2c

Please sign in to comment.