Skip to content

Commit 27ca863

Browse files
authored
CM-46055 - Fix scan parameters (#288)
1 parent f1f7c63 commit 27ca863

File tree

4 files changed

+72
-65
lines changed

4 files changed

+72
-65
lines changed

cycode/cli/commands/scan/code_scanner.py

+67-59
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
def scan_sca_pre_commit(context: click.Context) -> None:
4747
scan_type = context.obj['scan_type']
48-
scan_parameters = get_default_scan_parameters(context)
48+
scan_parameters = get_scan_parameters(context)
4949
git_head_documents, pre_committed_documents = get_pre_commit_modified_documents(
5050
context.obj['progress_bar'], ScanProgressBarSection.PREPARE_LOCAL_FILES
5151
)
@@ -80,14 +80,13 @@ def scan_sca_commit_range(context: click.Context, path: str, commit_range: str)
8080

8181

8282
def scan_disk_files(context: click.Context, paths: Tuple[str]) -> None:
83-
scan_parameters = get_scan_parameters(context, paths)
8483
scan_type = context.obj['scan_type']
8584
progress_bar = context.obj['progress_bar']
8685

8786
try:
8887
documents = get_relevant_documents(progress_bar, ScanProgressBarSection.PREPARE_LOCAL_FILES, scan_type, paths)
8988
perform_pre_scan_documents_actions(context, scan_type, documents)
90-
scan_documents(context, documents, scan_parameters=scan_parameters)
89+
scan_documents(context, documents, get_scan_parameters(context, paths))
9190
except Exception as e:
9291
handle_scan_exception(context, e)
9392

@@ -151,14 +150,12 @@ def _enrich_scan_result_with_data_from_detection_rules(
151150

152151
def _get_scan_documents_thread_func(
153152
context: click.Context, is_git_diff: bool, is_commit_range: bool, scan_parameters: dict
154-
) -> Tuple[Callable[[List[Document]], Tuple[str, CliError, LocalScanResult]], str]:
153+
) -> Callable[[List[Document]], Tuple[str, CliError, LocalScanResult]]:
155154
cycode_client = context.obj['client']
156155
scan_type = context.obj['scan_type']
157156
severity_threshold = context.obj['severity_threshold']
158157
sync_option = context.obj['sync']
159158
command_scan_type = context.info_name
160-
aggregation_id = str(_generate_unique_id())
161-
scan_parameters['aggregation_id'] = aggregation_id
162159

163160
def _scan_batch_thread_func(batch: List[Document]) -> Tuple[str, CliError, LocalScanResult]:
164161
local_scan_result = error = error_message = None
@@ -227,7 +224,7 @@ def _scan_batch_thread_func(batch: List[Document]) -> Tuple[str, CliError, Local
227224

228225
return scan_id, error, local_scan_result
229226

230-
return _scan_batch_thread_func, aggregation_id
227+
return _scan_batch_thread_func
231228

232229

233230
def scan_commit_range(
@@ -287,20 +284,19 @@ def scan_commit_range(
287284
logger.debug('List of commit ids to scan, %s', {'commit_ids': commit_ids_to_scan})
288285
logger.debug('Starting to scan commit range (it may take a few minutes)')
289286

290-
scan_documents(context, documents_to_scan, is_git_diff=True, is_commit_range=True)
287+
scan_documents(
288+
context, documents_to_scan, get_scan_parameters(context, (path,)), is_git_diff=True, is_commit_range=True
289+
)
291290
return None
292291

293292

294293
def scan_documents(
295294
context: click.Context,
296295
documents_to_scan: List[Document],
296+
scan_parameters: dict,
297297
is_git_diff: bool = False,
298298
is_commit_range: bool = False,
299-
scan_parameters: Optional[dict] = None,
300299
) -> None:
301-
if not scan_parameters:
302-
scan_parameters = get_default_scan_parameters(context)
303-
304300
scan_type = context.obj['scan_type']
305301
progress_bar = context.obj['progress_bar']
306302

@@ -315,19 +311,15 @@ def scan_documents(
315311
)
316312
return
317313

318-
scan_batch_thread_func, aggregation_id = _get_scan_documents_thread_func(
319-
context, is_git_diff, is_commit_range, scan_parameters
320-
)
314+
scan_batch_thread_func = _get_scan_documents_thread_func(context, is_git_diff, is_commit_range, scan_parameters)
321315
errors, local_scan_results = run_parallel_batched_scan(
322316
scan_batch_thread_func, scan_type, documents_to_scan, progress_bar=progress_bar
323317
)
324318

325-
if len(local_scan_results) > 1:
326-
# if we used more than one batch, we need to fetch aggregate report url
327-
aggregation_report_url = _try_get_aggregation_report_url_if_needed(
328-
scan_parameters, context.obj['client'], scan_type
329-
)
330-
set_aggregation_report_url(context, aggregation_report_url)
319+
aggregation_report_url = _try_get_aggregation_report_url_if_needed(
320+
scan_parameters, context.obj['client'], scan_type
321+
)
322+
_set_aggregation_report_url(context, aggregation_report_url)
331323

332324
progress_bar.set_section_length(ScanProgressBarSection.GENERATE_REPORT, 1)
333325
progress_bar.update(ScanProgressBarSection.GENERATE_REPORT)
@@ -337,25 +329,6 @@ def scan_documents(
337329
print_results(context, local_scan_results, errors)
338330

339331

340-
def set_aggregation_report_url(context: click.Context, aggregation_report_url: Optional[str] = None) -> None:
341-
context.obj['aggregation_report_url'] = aggregation_report_url
342-
343-
344-
def _try_get_aggregation_report_url_if_needed(
345-
scan_parameters: dict, cycode_client: 'ScanClient', scan_type: str
346-
) -> Optional[str]:
347-
aggregation_id = scan_parameters.get('aggregation_id')
348-
if not scan_parameters.get('report'):
349-
return None
350-
if aggregation_id is None:
351-
return None
352-
try:
353-
report_url_response = cycode_client.get_scan_aggregation_report_url(aggregation_id, scan_type)
354-
return report_url_response.report_url
355-
except Exception as e:
356-
logger.debug('Failed to get aggregation report url: %s', str(e))
357-
358-
359332
def scan_commit_range_documents(
360333
context: click.Context,
361334
from_documents_to_scan: List[Document],
@@ -380,7 +353,7 @@ def scan_commit_range_documents(
380353
try:
381354
progress_bar.set_section_length(ScanProgressBarSection.SCAN, 1)
382355

383-
scan_result = init_default_scan_result(cycode_client, scan_id, scan_type)
356+
scan_result = init_default_scan_result(scan_id)
384357
if should_scan_documents(from_documents_to_scan, to_documents_to_scan):
385358
logger.debug('Preparing from-commit zip')
386359
from_commit_zipped_documents = zip_documents(scan_type, from_documents_to_scan)
@@ -518,7 +491,7 @@ def perform_scan_async(
518491
cycode_client,
519492
scan_async_result.scan_id,
520493
scan_type,
521-
scan_parameters.get('report'),
494+
scan_parameters,
522495
)
523496

524497

@@ -553,16 +526,14 @@ def perform_commit_range_scan_async(
553526
logger.debug(
554527
'Async commit range scan request has been triggered successfully, %s', {'scan_id': scan_async_result.scan_id}
555528
)
556-
return poll_scan_results(
557-
cycode_client, scan_async_result.scan_id, scan_type, scan_parameters.get('report'), timeout
558-
)
529+
return poll_scan_results(cycode_client, scan_async_result.scan_id, scan_type, scan_parameters, timeout)
559530

560531

561532
def poll_scan_results(
562533
cycode_client: 'ScanClient',
563534
scan_id: str,
564535
scan_type: str,
565-
should_get_report: bool = False,
536+
scan_parameters: dict,
566537
polling_timeout: Optional[int] = None,
567538
) -> ZippedFileScanResult:
568539
if polling_timeout is None:
@@ -579,7 +550,7 @@ def poll_scan_results(
579550
print_debug_scan_details(scan_details)
580551

581552
if scan_details.scan_status == consts.SCAN_STATUS_COMPLETED:
582-
return _get_scan_result(cycode_client, scan_type, scan_id, scan_details, should_get_report)
553+
return _get_scan_result(cycode_client, scan_type, scan_id, scan_details, scan_parameters)
583554

584555
if scan_details.scan_status == consts.SCAN_STATUS_ERROR:
585556
raise custom_exceptions.ScanAsyncError(
@@ -671,18 +642,19 @@ def parse_pre_receive_input() -> str:
671642
return pre_receive_input.splitlines()[0]
672643

673644

674-
def get_default_scan_parameters(context: click.Context) -> dict:
645+
def _get_default_scan_parameters(context: click.Context) -> dict:
675646
return {
676647
'monitor': context.obj.get('monitor'),
677648
'report': context.obj.get('report'),
678649
'package_vulnerabilities': context.obj.get('package-vulnerabilities'),
679650
'license_compliance': context.obj.get('license-compliance'),
680651
'command_type': context.info_name,
652+
'aggregation_id': str(_generate_unique_id()),
681653
}
682654

683655

684-
def get_scan_parameters(context: click.Context, paths: Tuple[str]) -> dict:
685-
scan_parameters = get_default_scan_parameters(context)
656+
def get_scan_parameters(context: click.Context, paths: Optional[Tuple[str]] = None) -> dict:
657+
scan_parameters = _get_default_scan_parameters(context)
686658

687659
if not paths:
688660
return scan_parameters
@@ -890,36 +862,51 @@ def _get_scan_result(
890862
scan_type: str,
891863
scan_id: str,
892864
scan_details: 'ScanDetailsResponse',
893-
should_get_report: bool = False,
865+
scan_parameters: dict,
894866
) -> ZippedFileScanResult:
895867
if not scan_details.detections_count:
896-
return init_default_scan_result(cycode_client, scan_id, scan_type, should_get_report)
868+
return init_default_scan_result(scan_id)
897869

898870
scan_raw_detections = cycode_client.get_scan_raw_detections(scan_type, scan_id)
899871

900872
return ZippedFileScanResult(
901873
did_detect=True,
902874
detections_per_file=_map_detections_per_file_and_commit_id(scan_type, scan_raw_detections),
903875
scan_id=scan_id,
904-
report_url=_try_get_report_url_if_needed(cycode_client, should_get_report, scan_id, scan_type),
876+
report_url=_try_get_any_report_url_if_needed(cycode_client, scan_id, scan_type, scan_parameters),
905877
)
906878

907879

908-
def init_default_scan_result(
909-
cycode_client: 'ScanClient', scan_id: str, scan_type: str, should_get_report: bool = False
910-
) -> ZippedFileScanResult:
880+
def init_default_scan_result(scan_id: str) -> ZippedFileScanResult:
911881
return ZippedFileScanResult(
912882
did_detect=False,
913883
detections_per_file=[],
914884
scan_id=scan_id,
915-
report_url=_try_get_report_url_if_needed(cycode_client, should_get_report, scan_id, scan_type),
916885
)
917886

918887

888+
def _try_get_any_report_url_if_needed(
889+
cycode_client: 'ScanClient',
890+
scan_id: str,
891+
scan_type: str,
892+
scan_parameters: dict,
893+
) -> Optional[str]:
894+
"""Tries to get aggregation report URL if needed, otherwise tries to get report URL."""
895+
aggregation_report_url = None
896+
if scan_parameters:
897+
_try_get_report_url_if_needed(cycode_client, scan_id, scan_type, scan_parameters)
898+
aggregation_report_url = _try_get_aggregation_report_url_if_needed(scan_parameters, cycode_client, scan_type)
899+
900+
if aggregation_report_url:
901+
return aggregation_report_url
902+
903+
return _try_get_report_url_if_needed(cycode_client, scan_id, scan_type, scan_parameters)
904+
905+
919906
def _try_get_report_url_if_needed(
920-
cycode_client: 'ScanClient', should_get_report: bool, scan_id: str, scan_type: str
907+
cycode_client: 'ScanClient', scan_id: str, scan_type: str, scan_parameters: dict
921908
) -> Optional[str]:
922-
if not should_get_report:
909+
if not scan_parameters.get('report', False):
923910
return None
924911

925912
try:
@@ -929,6 +916,27 @@ def _try_get_report_url_if_needed(
929916
logger.debug('Failed to get report URL', exc_info=e)
930917

931918

919+
def _set_aggregation_report_url(context: click.Context, aggregation_report_url: Optional[str] = None) -> None:
920+
context.obj['aggregation_report_url'] = aggregation_report_url
921+
922+
923+
def _try_get_aggregation_report_url_if_needed(
924+
scan_parameters: dict, cycode_client: 'ScanClient', scan_type: str
925+
) -> Optional[str]:
926+
if not scan_parameters.get('report', False):
927+
return None
928+
929+
aggregation_id = scan_parameters.get('aggregation_id')
930+
if aggregation_id is None:
931+
return None
932+
933+
try:
934+
report_url_response = cycode_client.get_scan_aggregation_report_url(aggregation_id, scan_type)
935+
return report_url_response.report_url
936+
except Exception as e:
937+
logger.debug('Failed to get aggregation report url: %s', str(e))
938+
939+
932940
def _map_detections_per_file_and_commit_id(scan_type: str, raw_detections: List[dict]) -> List[DetectionsPerFile]:
933941
"""Converts list of detections (async flow) to list of DetectionsPerFile objects (sync flow).
934942

cycode/cli/commands/scan/pre_commit/pre_commit_command.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import click
55

66
from cycode.cli import consts
7-
from cycode.cli.commands.scan.code_scanner import scan_documents, scan_sca_pre_commit
7+
from cycode.cli.commands.scan.code_scanner import get_scan_parameters, scan_documents, scan_sca_pre_commit
88
from cycode.cli.files_collector.excluder import exclude_irrelevant_documents_to_scan
99
from cycode.cli.files_collector.repository_documents import (
1010
get_diff_file_content,
@@ -44,4 +44,4 @@ def pre_commit_command(context: click.Context, ignored_args: List[str]) -> None:
4444
documents_to_scan.append(Document(get_path_by_os(get_diff_file_path(file)), get_diff_file_content(file)))
4545

4646
documents_to_scan = exclude_irrelevant_documents_to_scan(scan_type, documents_to_scan)
47-
scan_documents(context, documents_to_scan, is_git_diff=True)
47+
scan_documents(context, documents_to_scan, get_scan_parameters(context), is_git_diff=True)

cycode/cli/commands/scan/repository/repository_command.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def repository_command(context: click.Context, path: str, branch: str) -> None:
6363
perform_pre_scan_documents_actions(context, scan_type, documents_to_scan)
6464

6565
logger.debug('Found all relevant files for scanning %s', {'path': path, 'branch': branch})
66-
scan_parameters = get_scan_parameters(context, (path,))
67-
scan_documents(context, documents_to_scan, scan_parameters=scan_parameters)
66+
scan_documents(context, documents_to_scan, get_scan_parameters(context, (path,)))
6867
except Exception as e:
6968
handle_scan_exception(context, e)

tests/test_code_scanner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_is_relevant_file_to_scan_sca() -> None:
2929
@pytest.mark.parametrize('scan_type', config['scans']['supported_scans'])
3030
def test_try_get_report_url_if_needed_return_none(scan_type: str, scan_client: ScanClient) -> None:
3131
scan_id = uuid4().hex
32-
result = _try_get_report_url_if_needed(scan_client, False, scan_id, consts.SECRET_SCAN_TYPE)
32+
result = _try_get_report_url_if_needed(scan_client, scan_id, consts.SECRET_SCAN_TYPE, scan_parameters={})
3333
assert result is None
3434

3535

@@ -44,7 +44,7 @@ def test_try_get_report_url_if_needed_return_result(
4444
responses.add(get_scan_report_url_response(url, scan_id))
4545

4646
scan_report_url_response = scan_client.get_scan_report_url(str(scan_id), scan_type)
47-
result = _try_get_report_url_if_needed(scan_client, True, str(scan_id), scan_type)
47+
result = _try_get_report_url_if_needed(scan_client, str(scan_id), scan_type, scan_parameters={'report': True})
4848
assert result == scan_report_url_response.report_url
4949

5050

0 commit comments

Comments
 (0)