Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions dojo/finding/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,3 +1216,40 @@ def match_finding_to_existing_findings(finding, product=None, engagement=None, t
deduplicationLogger.debug(qs.query)

return qs


def hashcode_values_writer(model_type, batch, fields):
    """
    mass_model_updater ``writer`` for the hash-recompute paths (dedupe command + tuner
    async tasks).

    The hash fields are text columns, so the whole batch is written with one
    ``UPDATE t SET f = v.f FROM (VALUES (pk, f...), ...) WHERE t.pk = v.pk`` instead of
    bulk_update's per-row CASE/WHEN. Values are bound as parameters and cast to text
    (which also resolves the type of an all-NULL column). PostgreSQL only; falls back
    to bulk_update on other backends.

    Args:
        model_type: model class whose table is updated.
        batch: model instances to persist; an empty batch is a no-op.
        fields: names of the (text) model fields to write.

    Raises:
        ValueError: if ``fields`` is empty while ``batch`` is not (same contract as
            ``bulk_update``).

    """
    if not batch:
        return
    if not fields:
        # Without this guard the PostgreSQL path would build an UPDATE with an empty
        # SET clause and fail with a SQL syntax error; bulk_update raises ValueError here.
        msg = "Field names must be given to hashcode_values_writer()."
        raise ValueError(msg)

    # Imported only once there is real work to do.
    from django.db import connection  # noqa: PLC0415

    if connection.vendor != "postgresql":
        model_type.objects.bulk_update(batch, fields)
        return

    meta = model_type._meta
    quote = connection.ops.quote_name
    columns = [meta.get_field(name).column for name in fields]
    # Every VALUES row is (pk, field_1, ..., field_n) and is exposed as v(c0, c1, ..., cn).
    width = 1 + len(columns)
    row_placeholder = "(" + ",".join(["%s"] * width) + ")"
    placeholders = ",".join([row_placeholder] * len(batch))
    params = []
    for obj in batch:
        params.append(obj.pk)
        params.extend(getattr(obj, name) for name in fields)
    value_cols = ", ".join(f"c{idx}" for idx in range(width))
    # c0 is the primary key, so the data columns start at c1.
    set_clause = ", ".join(
        f"{quote(col)} = v.c{idx}::text" for idx, col in enumerate(columns, start=1)
    )
    sql = (
        f"UPDATE {quote(meta.db_table)} AS t "
        f"SET {set_clause} "
        f"FROM (VALUES {placeholders}) AS v({value_cols}) "
        f"WHERE t.{quote(meta.pk.column)} = v.c0"
    )
    with connection.cursor() as cursor:
        cursor.execute(sql, params)
12 changes: 11 additions & 1 deletion dojo/management/commands/dedupe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
do_dedupe_finding_task,
do_dedupe_finding_task_internal,
get_finding_models_for_deduplication,
hashcode_values_writer,
)
from dojo.models import Finding, Product
from dojo.utils import (
Expand Down Expand Up @@ -96,6 +97,10 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup
"test", "test__engagement", "test__engagement__product", "test__test_type",
).prefetch_related(
"locations",
# vulnerability_id_set feeds hash_code computation for parsers whose
# HASHCODE_FIELDS_PER_SCANNER includes vulnerability_ids; prefetch to avoid
# a per-finding query in get_vulnerability_ids().
"vulnerability_id_set",
Prefetch(
"original_finding",
queryset=Finding.objects.only("id", "duplicate_finding_id").order_by("-id"),
Expand All @@ -108,6 +113,10 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup
"test", "test__engagement", "test__engagement__product", "test__test_type",
).prefetch_related(
"endpoints",
# vulnerability_id_set feeds hash_code computation for parsers whose
# HASHCODE_FIELDS_PER_SCANNER includes vulnerability_ids; prefetch to avoid
# a per-finding query in get_vulnerability_ids().
"vulnerability_id_set",
Prefetch(
"original_finding",
queryset=Finding.objects.only("id", "duplicate_finding_id").order_by("-id"),
Expand All @@ -118,7 +127,8 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup
if not dedupe_only:
logger.info("######## Start Updating Hashcodes (foreground) ########")

mass_model_updater(Finding, findings, generate_hash_code, fields=["hash_code"], order="asc", log_prefix="hash_code computation ")
hash_code_writer = hashcode_values_writer if settings.MASS_HASH_CODE_USE_SQL_WRITER else None
mass_model_updater(Finding, findings, generate_hash_code, fields=["hash_code"], order="asc", log_prefix="hash_code computation ", writer=hash_code_writer)

logger.info("######## Done Updating Hashcodes########")

Expand Down
9 changes: 5 additions & 4 deletions dojo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2990,10 +2990,11 @@ def _get_unsaved_vulnerability_ids(finding) -> str:

def _get_saved_vulnerability_ids(finding) -> str:
if finding.id is not None:
vulnerability_ids = Vulnerability_Id.objects.filter(finding=finding)
deduplicationLogger.debug("get_vulnerability_ids after the finding was saved. Vulnerability references count: " + str(vulnerability_ids.count()))
# convert list of vulnerability_ids to the list of their canonical representation
vulnerability_id_str_list = [str(vulnerability_id) for vulnerability_id in vulnerability_ids.all()]
# Use the reverse relation (vulnerability_id_set) rather than a fresh
# Vulnerability_Id.objects.filter(...) so prefetch_related("vulnerability_id_set")
# is honored — avoids an N+1 (COUNT + SELECT per finding) during dedupe/hashcode.
vulnerability_id_str_list = [str(vulnerability_id) for vulnerability_id in finding.vulnerability_id_set.all()]
deduplicationLogger.debug("get_vulnerability_ids after the finding was saved. Vulnerability references count: " + str(len(vulnerability_id_str_list)))
# sort vulnerability_ids strings
return "".join(sorted(vulnerability_id_str_list))
return ""
Expand Down
4 changes: 4 additions & 0 deletions dojo/settings/settings.dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@
DD_HASHCODE_FIELDS_PER_SCANNER=(str, ""),
# Set deduplication algorithms per parser, via an env variable that contains a JSON string
DD_DEDUPLICATION_ALGORITHM_PER_PARSER=(str, ""),
# When True, hash_code mass updates use the PostgreSQL VALUES-join fast writer
# (falls back to bulk_update on other backends). Set False to always use bulk_update.
DD_MASS_HASH_CODE_USE_SQL_WRITER=(bool, True),
# Specifies whether the "first seen" date of a given report should be used over the "last seen" date
DD_USE_FIRST_SEEN=(bool, False),
# When set to True, use the older version of the qualys parser that is more heavy-handed in setting severity
Expand Down Expand Up @@ -1693,6 +1696,7 @@ def generate_url(scheme, double_slashes, user, password, host, port, path, param

USE_FIRST_SEEN = env("DD_USE_FIRST_SEEN")
USE_QUALYS_LEGACY_SEVERITY_PARSING = env("DD_QUALYS_LEGACY_SEVERITY_PARSING")
MASS_HASH_CODE_USE_SQL_WRITER = env("DD_MASS_HASH_CODE_USE_SQL_WRITER")

# ------------------------------------------------------------------------------
# Notifications
Expand Down
57 changes: 49 additions & 8 deletions dojo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,21 @@ def add_field_errors_to_response(form):
add_error_message_to_response(error)


def mass_model_updater(model_type, models, function, fields, page_size=1000, order="asc", log_prefix=""):
def default_mass_model_writer(model_type, batch, fields):
    """
    Fallback ``writer`` for mass_model_updater.

    Saves ``fields`` of every object in ``batch`` with a single ``bulk_update`` call on
    ``model_type.objects``; an empty batch results in no call at all.
    """
    if batch:
        model_type.objects.bulk_update(batch, fields)


def _flush_mass_update(model_type, batch, fields, writer):
"""Persist a batch via the supplied writer, else the default writer."""
if not batch:
return
(writer or default_mass_model_writer)(model_type, batch, fields)


def mass_model_updater(model_type, models, function, fields, page_size=1000, order="asc", log_prefix="", *, skip_unchanged=True, writer=None):
"""
Using the default for model in queryset can be slow for large querysets. Even
when using paging as LIMIT and OFFSET are slow on database. In some cases we can optimize
Expand All @@ -1586,6 +1600,14 @@ def mass_model_updater(model_type, models, function, fields, page_size=1000, ord
was processed and continue from there on the next page. This is fast because
it results in an index seek instead of executing the whole query again and skipping
the first X items.

When ``fields`` is given:
- skip_unchanged (default True): rows whose tracked ``fields`` were not changed by
``function`` are not written (compared against the values loaded from the page
query; deferred fields are read from ``__dict__`` so no extra query is issued).
- writer (optional): a callable ``writer(model_type, batch, fields)`` used to persist
each batch instead of Django's ``bulk_update`` (e.g. a backend-specific fast path).
Defaults to ``bulk_update``.
"""
# force ordering by id to make our paging work
last_id = 0
Expand All @@ -1608,6 +1630,7 @@ def mass_model_updater(model_type, models, function, fields, page_size=1000, ord
logger.debug("%s found %d models for mass update:", log_prefix, total_count)

i = 0
written = 0
batch = []
total_pages = (total_count // page_size) + 2
# logger.debug("pages to process: %d", total_pages)
Expand All @@ -1623,22 +1646,40 @@ def mass_model_updater(model_type, models, function, fields, page_size=1000, ord
i += 1
last_id = model.id

# snapshot tracked fields before mutation (read from __dict__ to avoid
# triggering a deferred-field load); used to skip no-op writes
before = None
if fields and skip_unchanged:
before = [model.__dict__.get(f) for f in fields]

function(model)

batch.append(model)
if fields and skip_unchanged and before is not None and all(
model.__dict__.get(f) == old for f, old in zip(fields, before, strict=True)
):
# nothing changed for this row -> no write needed
pass
else:
batch.append(model)

if (i > 0 and i % page_size == 0):
if fields:
model_type.objects.bulk_update(batch, fields)
if fields and len(batch) >= page_size:
_flush_mass_update(model_type, batch, fields, writer)
written += len(batch)
batch = []
elif not fields and len(batch) >= page_size:
# function has side effects only; keep memory bounded
batch = []

if i > 0 and i % page_size == 0:
logger.debug("%s%s out of %s models processed ...", log_prefix, i, total_count)

logger.info("%s%s out of %s models processed ...", log_prefix, i, total_count)

if fields:
model_type.objects.bulk_update(batch, fields)
if fields and batch:
_flush_mass_update(model_type, batch, fields, writer)
written += len(batch)
batch = []
logger.info("%s%s out of %s models processed ...", log_prefix, i, total_count)
logger.info("%s%s out of %s models processed (%s written) ...", log_prefix, i, total_count, written)


def to_str_typed(obj):
Expand Down
165 changes: 165 additions & 0 deletions unittests/test_mass_model_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
Tests for mass_model_updater performance optimizations:
- skip-unchanged: rows whose tracked fields did not change are not written.
- fast VALUES UPDATE: when hashcode_values_writer is used as the ``writer`` on PostgreSQL,
a batch is written with a single `UPDATE ... FROM (VALUES ...)` join instead of
bulk_update's CASE/WHEN.
Both must remain byte-for-byte correct (incl. NULLs); hashcode_values_writer falls back to
bulk_update on non-postgres backends.
"""

from django.db import connection
from django.test.utils import CaptureQueriesContext

from dojo.finding.deduplication import hashcode_values_writer
from dojo.models import Finding
from dojo.utils import mass_model_updater

from .dojo_test_case import DojoTestCase, versioned_fixtures


def _updates(captured):
return [q["sql"] for q in captured if q["sql"].lstrip().upper().startswith("UPDATE")]


@versioned_fixtures
class TestMassModelUpdater(DojoTestCase):
    """Checks for mass_model_updater (skip-unchanged, writer hook) and hashcode_values_writer."""

    fixtures = ["dojo_testdata.json"]

    def _finding_ids(self, n=5):
        """Return the ids of at most ``n`` findings, leaving out those flagged as duplicate."""
        non_duplicates = Finding.objects.exclude(duplicate=True)
        return list(non_duplicates.values_list("id", flat=True)[:n])

    def test_writes_changed_values(self):
        """A value assigned by the callback is persisted for every processed finding."""
        finding_ids = self._finding_ids()

        def assign_changed(finding):
            finding.hash_code = f"changed-{finding.id}"

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_changed,
            fields=["hash_code"],
            order="asc",
        )

        for finding_id in finding_ids:
            stored = Finding.objects.get(id=finding_id).hash_code
            self.assertEqual(
                stored,
                f"changed-{finding_id}",
                msg=f"finding {finding_id} hash_code not persisted",
            )

    def test_skips_unchanged_rows(self):
        """Re-running with values equal to the stored ones must issue no UPDATE statement."""
        finding_ids = self._finding_ids()

        def assign_v(finding):
            finding.hash_code = f"v-{finding.id}"

        # First pass stores known values.
        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_v,
            fields=["hash_code"],
            order="asc",
        )
        # Second pass computes the very same values; only this pass is captured.
        with CaptureQueriesContext(connection) as captured:
            mass_model_updater(
                Finding,
                Finding.objects.filter(id__in=finding_ids),
                assign_v,
                fields=["hash_code"],
                order="asc",
            )
        update_statements = _updates(captured.captured_queries)
        self.assertEqual(
            len(update_statements),
            0,
            msg=f"expected 0 UPDATEs for unchanged rows, got: {update_statements}",
        )

    def test_handles_null_values(self):
        """None and string values assigned by the callback are both stored as given."""
        finding_ids = self._finding_ids()

        def assign_alternating(finding):
            # even ids get None, odd ids get a string
            if finding.id % 2 == 0:
                finding.hash_code = None
            else:
                finding.hash_code = f"h-{finding.id}"

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_alternating,
            fields=["hash_code"],
            order="asc",
        )

        for finding_id in finding_ids:
            wanted = None if finding_id % 2 == 0 else f"h-{finding_id}"
            stored = Finding.objects.get(id=finding_id).hash_code
            self.assertEqual(
                stored,
                wanted,
                msg=f"finding {finding_id} hash_code mismatch (null handling)",
            )

    def test_writer_hook_is_used_for_changed_rows(self):
        """A caller-supplied writer receives every changed row and its writes are persisted."""
        finding_ids = self._finding_ids()
        recorded = []

        def recording_writer(model_type, batch, fields):
            recorded.append((model_type, [m.id for m in batch], list(fields)))
            model_type.objects.bulk_update(batch, fields)

        def assign_w(finding):
            finding.hash_code = f"w-{finding.id}"

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_w,
            fields=["hash_code"],
            order="asc",
            writer=recording_writer,
        )

        written_ids = []
        for _model, batch_ids, _fields in recorded:
            written_ids.extend(batch_ids)
        written_ids.sort()
        self.assertEqual(
            written_ids,
            sorted(finding_ids),
            msg=f"writer not called for all changed rows: {recorded}",
        )
        for finding_id in finding_ids:
            self.assertEqual(Finding.objects.get(id=finding_id).hash_code, f"w-{finding_id}")

    def test_writer_hook_not_called_when_nothing_changed(self):
        """The writer is never invoked when the callback leaves every row unchanged."""
        finding_ids = self._finding_ids()

        def assign_s(finding):
            finding.hash_code = f"s-{finding.id}"

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_s,
            fields=["hash_code"],
            order="asc",
        )

        called = []

        def recorder(*args):
            called.append(args)

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_s,
            fields=["hash_code"],
            order="asc",
            writer=recorder,
        )
        self.assertEqual(called, [], msg="writer must not be called when no row changed")

    def test_skip_unchanged_can_be_disabled(self):
        """With skip_unchanged=False an UPDATE is issued even though the values are unchanged."""
        finding_ids = self._finding_ids()

        def assign_x(finding):
            finding.hash_code = f"x-{finding.id}"

        mass_model_updater(
            Finding,
            Finding.objects.filter(id__in=finding_ids),
            assign_x,
            fields=["hash_code"],
            order="asc",
        )
        with CaptureQueriesContext(connection) as captured:
            mass_model_updater(
                Finding,
                Finding.objects.filter(id__in=finding_ids),
                assign_x,
                fields=["hash_code"],
                order="asc",
                skip_unchanged=False,
            )
        update_count = len(_updates(captured.captured_queries))
        self.assertGreater(
            update_count,
            0,
            msg="skip_unchanged=False must still write unchanged rows",
        )

    def test_hashcode_values_writer_uses_values_sql_on_postgres(self):
        """On PostgreSQL the writer emits a VALUES-join UPDATE and no CASE WHEN."""
        if connection.vendor != "postgresql":
            self.skipTest("VALUES fast path is postgres-only")
        findings = list(Finding.objects.filter(id__in=self._finding_ids()))
        for finding in findings:
            finding.hash_code = f"vw-{finding.id}"
        with CaptureQueriesContext(connection) as captured:
            hashcode_values_writer(Finding, findings, ["hash_code"])
        ups = _updates(captured.captured_queries)
        lowered = [statement.lower() for statement in ups]
        self.assertTrue(
            any("from (values" in statement for statement in lowered),
            msg=f"expected VALUES update, got: {ups}",
        )
        self.assertFalse(
            any("case when" in statement for statement in lowered),
            msg=f"unexpected CASE WHEN: {ups}",
        )
        for finding in findings:
            self.assertEqual(Finding.objects.get(id=finding.id).hash_code, f"vw-{finding.id}")

    def test_hashcode_values_writer_handles_null(self):
        """The writer stores None and string values as given (every other object gets None)."""
        findings = list(Finding.objects.filter(id__in=self._finding_ids()))
        expected_by_id = {}
        for position, finding in enumerate(findings):
            finding.hash_code = None if position % 2 == 0 else f"vn-{finding.id}"
            expected_by_id[finding.id] = finding.hash_code
        hashcode_values_writer(Finding, findings, ["hash_code"])
        for finding_id, expected in expected_by_id.items():
            self.assertEqual(Finding.objects.get(id=finding_id).hash_code, expected)

    def test_fields_none_calls_function_without_writing(self):
        """With fields=None the callback runs for every model and no UPDATE is issued."""
        finding_ids = self._finding_ids()
        seen = []

        def remember(finding):
            seen.append(finding.id)

        with CaptureQueriesContext(connection) as captured:
            mass_model_updater(
                Finding,
                Finding.objects.filter(id__in=finding_ids),
                remember,
                fields=None,
                order="asc",
            )
        self.assertEqual(sorted(seen), sorted(finding_ids), msg="function must run for every model")
        self.assertEqual(len(_updates(captured.captured_queries)), 0, msg="fields=None must not write")
Loading