Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions vulnerabilities/pipelines/v2_improvers/collect_patch_texts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

import logging

import requests
from aboutcode.pipeline import LoopProgress
from django.db.models import Q

from vulnerabilities.models import Patch
from vulnerabilities.pipelines import VulnerableCodePipeline


class CollectPatchTextsPipeline(VulnerableCodePipeline):
"""
Improver pipeline to collect missing patch texts for Patch objects that have a patch_url.
"""

pipeline_id = "collect_patch_texts_v2"
license_expression = None

@classmethod
def steps(cls):
return (cls.collect_and_store_patch_texts,)

def collect_and_store_patch_texts(self):
patches_without_text = Patch.objects.filter(
Q(patch_url__isnull=False) & ~Q(patch_url=""),
Q(patch_text__isnull=True) | Q(patch_text=""),
)

self.log(f"Processing {patches_without_text.count():,d} patches to collect text.")

updated_patch_count = 0
progress = LoopProgress(total_iterations=patches_without_text.count(), logger=self.log)

for patch in progress.iter(patches_without_text.iterator(chunk_size=500)):
raw_url = get_raw_patch_url(patch.patch_url)
if not raw_url:
continue

try:
response = requests.get(raw_url, timeout=10)
if response.status_code == 200:
patch.patch_text = response.text
patch.save()
updated_patch_count += 1
else:
self.log(
f"Failed to fetch patch from {raw_url}: Status {response.status_code}",
level=logging.WARNING if response.status_code < 500 else logging.ERROR,
)
except requests.RequestException as e:
self.log(f"Error fetching patch from {raw_url}: {e}", level=logging.ERROR)

self.log(f"Successfully collected text for {updated_patch_count:,d} Patch entries.")


def get_raw_patch_url(url):
"""
Return a fetchable raw patch URL from common VCS hosting URLs,
or the URL itself if it already points to a .patch or .diff file.
Return None if the URL type is not recognized.
"""
if not url:
return None

url = url.strip()

if "github.com" in url and "/commit/" in url and not url.endswith(".patch"):
return f"{url}.patch"

if "github.com" in url and "/pull/" in url and not url.endswith(".patch"):
return f"{url}.patch"

if "gitlab.com" in url and "/commit/" in url and not url.endswith(".patch"):
return f"{url}.patch"

if "gitlab.com" in url and "/merge_requests/" in url and not url.endswith(".patch"):
return f"{url}.patch"

if url.endswith(".patch") or url.endswith(".diff"):
return url

return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

import unittest
from unittest.mock import MagicMock
from unittest.mock import patch as mock_patch

from vulnerabilities.pipelines.v2_improvers.collect_patch_texts import CollectPatchTextsPipeline
from vulnerabilities.pipelines.v2_improvers.collect_patch_texts import get_raw_patch_url


class TestCollectPatchTextsPipeline(unittest.TestCase):
def setUp(self):
self.pipeline = CollectPatchTextsPipeline()

def test_get_raw_patch_url(self):
url = "https://github.com/user/repo/commit/abc1234567890"
expected = "https://github.com/user/repo/commit/abc1234567890.patch"
self.assertEqual(get_raw_patch_url(url), expected)

url = "https://github.com/user/repo/pull/123"
expected = "https://github.com/user/repo/pull/123.patch"
self.assertEqual(get_raw_patch_url(url), expected)

url = "https://gitlab.com/user/repo/-/commit/abc1234567890"
expected = "https://gitlab.com/user/repo/-/commit/abc1234567890.patch"
self.assertEqual(get_raw_patch_url(url), expected)

url = "https://gitlab.com/user/repo/-/merge_requests/123"
expected = "https://gitlab.com/user/repo/-/merge_requests/123.patch"
self.assertEqual(get_raw_patch_url(url), expected)

url = "https://example.com/fix.patch"
self.assertEqual(get_raw_patch_url(url), url)

url = "https://example.com/some/article"
self.assertIsNone(get_raw_patch_url(url))

@mock_patch("vulnerabilities.pipelines.v2_improvers.collect_patch_texts.Patch")
@mock_patch("requests.get")
def test_collect_and_store_patch_texts(self, mock_get, mock_patch_model):
p1 = MagicMock(patch_url="https://github.com/u/r/commit/c1", patch_text=None)
p2 = MagicMock(patch_url="https://github.com/u/r/pull/1", patch_text="")
p3 = MagicMock(patch_url="https://example.com/no-patch", patch_text=None)
p4 = MagicMock(patch_url="https://example.com/fix.patch", patch_text=None)

mock_qs = MagicMock()
mock_qs.count.return_value = 4
mock_qs.iterator.return_value = [p1, p2, p3, p4]

mock_patch_model.objects.filter.return_value = mock_qs

def side_effect(url, timeout=10):
mock_resp = MagicMock()
mock_resp.status_code = 404
if url == "https://github.com/u/r/commit/c1.patch":
mock_resp.status_code = 200
mock_resp.text = "diff --git a/file b/file\n+code"
elif url == "https://github.com/u/r/pull/1.patch":
mock_resp.status_code = 200
mock_resp.text = "diff --git a/pr b/pr\n+pr_code"
elif url == "https://example.com/fix.patch":
mock_resp.status_code = 200
mock_resp.text = "diff --git a/direct b/direct\n+direct_code"
return mock_resp

mock_get.side_effect = side_effect

self.pipeline.collect_and_store_patch_texts()

self.assertEqual(p1.patch_text, "diff --git a/file b/file\n+code")
p1.save.assert_called_once()

self.assertEqual(p2.patch_text, "diff --git a/pr b/pr\n+pr_code")
p2.save.assert_called_once()

p3.save.assert_not_called()

self.assertEqual(p4.patch_text, "diff --git a/direct b/direct\n+direct_code")
p4.save.assert_called_once()