diff --git a/pyproject.toml b/pyproject.toml
index 2a11faf48d..2436a4ae43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,6 +136,7 @@ run = "scancodeio:combined_run"
 analyze_docker_image = "scanpipe.pipelines.analyze_docker:Docker"
 analyze_root_filesystem_or_vm_image = "scanpipe.pipelines.analyze_root_filesystem:RootFS"
 analyze_windows_docker_image = "scanpipe.pipelines.analyze_docker_windows:DockerWindows"
+analyze_symbols_reachability = "scanpipe.pipelines.collect_symbols_reachability:SymbolReachability"
 benchmark_purls = "scanpipe.pipelines.benchmark_purls:BenchmarkPurls"
 collect_strings_gettext = "scanpipe.pipelines.collect_strings_gettext:CollectStringsGettext"
 collect_symbols_ctags = "scanpipe.pipelines.collect_symbols_ctags:CollectSymbolsCtags"
diff --git a/scanpipe/pipelines/collect_symbols_reachability.py b/scanpipe/pipelines/collect_symbols_reachability.py
new file mode 100644
index 0000000000..15519fc661
--- /dev/null
+++ b/scanpipe/pipelines/collect_symbols_reachability.py
@@ -0,0 +1,35 @@
+#
+# Copyright (c) nexB Inc. and others. All rights reserved.
+# ScanCode is a trademark of nexB Inc.
+# SPDX-License-Identifier: Apache-2.0
+# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
+# See https://github.com/aboutcode-org/scancode.io for support or download.
+# See https://aboutcode.org for more information about nexB OSS projects.
+#
+
+from scanpipe.pipelines import Pipeline
+from scanpipe.pipes import reachability
+
+
+class SymbolReachability(Pipeline):
+    """
+    Analyze which codebase resources reach symbols changed by vulnerability patches.
+    """
+
+    download_inputs = False
+    is_addon = True
+    results_url = "/project/{slug}/resources/?extra_data=symbols_reachability"
+
+    @classmethod
+    def steps(cls):
+        return (cls.analyze_and_store_symbol_reachability,)
+
+    def analyze_and_store_symbol_reachability(self):
+        """
+        Compare each patch's vulnerable/fixed ASTs against the codebase resources.
+
+        Results are appended to each resource `extra_data["symbols_reachability"]`.
+        """
+        reachability.collect_and_store_symbol_reachability_results(
+            project=self.project, logger=self.log
+        )
diff --git a/scanpipe/pipes/reachability.py b/scanpipe/pipes/reachability.py
new file mode 100644
index 0000000000..5169eba94d
--- /dev/null
+++ b/scanpipe/pipes/reachability.py
@@ -0,0 +1,697 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# http://nexb.com and https://github.com/aboutcode-org/scancode.io
+# The ScanCode.io software is licensed under the Apache License version 2.0.
+# Data generated with ScanCode.io is provided as-is without warranties.
+# ScanCode is a trademark of nexB Inc.
+#
+# You may not use this software except in compliance with the License.
+# You may obtain a copy of the License at: http://apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+#
+# Data Generated with ScanCode.io is provided on an "AS IS" BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, either express or implied. No content created from
+# ScanCode.io should be considered or used as legal advice. Consult an Attorney
+# for any legal advice.
+#
+# ScanCode.io is a free software code scanning tool from nexB Inc. and others.
+# Visit https://github.com/aboutcode-org/scancode.io for support and download.
+ +import os +import shutil +import tempfile +from enum import Enum +from pathlib import Path + +from git import Repo +from git.diff import NULL_TREE +from git.exc import BadName +from matchcode_toolkit.fingerprinting import create_file_fingerprints +from scancode.api import get_file_info +from unidiff import PatchSet + +from scanpipe.pipes.symbols import TS_QUERIES +from scanpipe.pipes.symbols import _root_of +from scanpipe.pipes.symbols import collect_definitions +from scanpipe.pipes.symbols import extract_calls_in_node +from scanpipe.pipes.symbols import extract_definitions +from scanpipe.pipes.symbols import extract_symbols +from scanpipe.pipes.symbols import is_nested_function +from scanpipe.pipes.symbols import parse_code_to_ast +from scanpipe.pipes.symbols import qualified_name_from_index + +EMPTY_TREE_SHA = "4b825dc642cb6eb9a060e54bf8b8e6f9b79b4d2b" + + +class ReachabilityStatus(str, Enum): + REACHABLE = "REACHABLE" + POTENTIALLY_REACHABLE = "POTENTIALLY_REACHABLE" + NOT_REACHABLE = "NOT_REACHABLE" + + +def api_mocker(): + """ + TODO: Remove this once the API patch url is done + """ + return [ + { + "vcs_url": "https://github.com/pallets/flask", + "commit_hash": "089cb86dd22bff589a4eafb7ab8e42dc357623b4", + }, + ] + + +def clone_repo(vcs_url, commit_hash=None): + repo_path = tempfile.mkdtemp(prefix="symbol-reachability-") + + try: + repo = Repo.clone_from(vcs_url, repo_path) + + if commit_hash: + repo.git.checkout(commit_hash) + + return repo_path + + except BadName as exc: + cleanup_repo(repo_path) + raise ValueError(f"Commit {commit_hash} not found") from exc + + except Exception: + cleanup_repo(repo_path) + raise + + +def cleanup_repo(repo_path): + if repo_path and os.path.exists(repo_path): + shutil.rmtree(repo_path, ignore_errors=True) + + +def normalize_text(content): + if content is None: + return "" + + if isinstance(content, bytes): + return content.decode("utf-8", errors="replace") + + return str(content) + + +def is_supported_language(language): 
+ """A language is supported if we have tree-sitter queries for it.""" + return bool(language) and language in TS_QUERIES + + +def detect_language_with_scancode(file_path, content): + """ + Write `content` to a temp file preserving `file_path`'s basename + so the extension is meaningful, then ask ScanCode's `get_file_info` + to return the programming language. + """ + content = normalize_text(content) + + if not content: + return None + + tmp_dir = tempfile.mkdtemp(prefix="patch-lang-") + + try: + target = Path(tmp_dir) / Path(file_path).name + target.write_text(content, encoding="utf-8", errors="replace") + + info = get_file_info(location=str(target)) or {} + return info.get("programming_language") or None + + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +def get_commit_and_parent(repo, commit_hash): + commit = repo.commit(commit_hash) + parent = commit.parents[0] if commit.parents else None + return commit, parent + + +def get_commit_diff_text(repo, parent_commit, commit): + """Whole-commit unified diff (used to extract changed line numbers).""" + base = parent_commit.hexsha if parent_commit else EMPTY_TREE_SHA + return repo.git.diff(base, commit.hexsha, unified=0) + + +def get_changed_files(parent_commit, commit): + """ + Return: + { + file_path: { + "vulnerable_text": "...", + "fixed_text": "...", + } + } + + """ + diffs = ( + parent_commit.diff(commit, create_patch=False) + if parent_commit + else commit.diff(NULL_TREE, create_patch=False) + ) + + files = {} + for diff in diffs: + change_type = diff.change_type + old_path = diff.a_path if change_type in ("D", "M", "R") else None + new_path = diff.b_path if change_type in ("A", "M", "R") else None + path_key = new_path or old_path + + if not path_key: + continue + + entry = files.setdefault( + path_key, + { + "vulnerable_text": "", + "fixed_text": "", + }, + ) + + if old_path and parent_commit: + entry["vulnerable_text"] = ( + (parent_commit.tree / old_path) + .data_stream.read() + .decode("utf-8", 
errors="replace") + ) + + if new_path: + entry["fixed_text"] = ( + (commit.tree / new_path) + .data_stream.read() + .decode("utf-8", errors="replace") + ) + + return files + + +def get_changed_lines(diff_text, file_path): + """Return `(removed_lines, added_lines)` for one file from a unified diff.""" + removed = [] + added = [] + + if not diff_text: + return removed, added + + for patched_file in PatchSet.from_string(diff_text): + candidates = { + patched_file.path, + (patched_file.source_file or "").removeprefix("a/"), + (patched_file.target_file or "").removeprefix("b/"), + } + + if file_path not in candidates: + continue + + for hunk in patched_file: + for line in hunk: + if line.is_removed and line.source_line_no: + removed.append(line.source_line_no) + elif line.is_added and line.target_line_no: + added.append(line.target_line_no) + + return removed, added + + +def diff_changed_symbols(vuln_meta, fixed_meta): + """ + Keep only symbols whose body actually differs between vulnerable and fixed + versions. Pair by qualified name first. + """ + fixed_by_qn = { + metadata["qualified_name"]: metadata for metadata in fixed_meta.values() + } + + vuln_by_qn = { + metadata["qualified_name"]: metadata for metadata in vuln_meta.values() + } + + vuln_only = { + key: metadata + for key, metadata in vuln_meta.items() + if fixed_by_qn.get(metadata["qualified_name"], {}).get("text") + != metadata["text"] + } + + fixed_only = { + key: metadata + for key, metadata in fixed_meta.items() + if vuln_by_qn.get(metadata["qualified_name"], {}).get("text") + != metadata["text"] + } + + return vuln_only, fixed_only + + +def analyze_patched_file(vulnerable_text, fixed_text, diff_text, file_path): + """ + Return `(vuln_metadata, fixed_metadata, language)` for one changed file, + restricted to symbols actually touched by the patch. 
+ """ + vulnerable_text = normalize_text(vulnerable_text) + fixed_text = normalize_text(fixed_text) + + language = detect_language_with_scancode( + file_path, fixed_text + ) or detect_language_with_scancode(file_path, vulnerable_text) + + if not is_supported_language(language): + return {}, {}, language + + vuln_tree, _ = ( + parse_code_to_ast(vulnerable_text, language) + if vulnerable_text + else (None, None) + ) + + fixed_tree, _ = ( + parse_code_to_ast(fixed_text, language) if fixed_text else (None, None) + ) + + if vuln_tree is None and fixed_tree is None: + return {}, {}, language + + removed_lines, added_lines = get_changed_lines(diff_text, file_path) + + vuln_nodes = ( + extract_symbols(vuln_tree, removed_lines, language) if vuln_tree else [] + ) + + fixed_nodes = ( + extract_symbols(fixed_tree, added_lines, language) if fixed_tree else [] + ) + + vuln_meta, fixed_meta = diff_changed_symbols( + build_symbol_metadata(vuln_nodes, language), + build_symbol_metadata(fixed_nodes, language), + ) + + return vuln_meta, fixed_meta, language + + +def collect_patch_symbols(repo, commit_hash): + """ + Return: + { + language: { + "vulnerable": { + "file_path::symbol_key": metadata, + ... + }, + "fixed": { + "file_path::symbol_key": metadata, + ... + }, + }, + ... + } + + Symbols are bucketed by language so resources are only matched against + patch symbols extracted from the same language. 
+
+    """
+    commit, parent = get_commit_and_parent(repo, commit_hash)
+    diff_text = get_commit_diff_text(repo, parent, commit)
+    changed = get_changed_files(parent, commit)
+
+    by_language = {}
+    for file_path, texts in changed.items():
+        vulnerable_text = texts["vulnerable_text"]
+        fixed_text = texts["fixed_text"]
+        vuln_meta, fixed_meta, language = analyze_patched_file(
+            vulnerable_text=vulnerable_text,
+            fixed_text=fixed_text,
+            diff_text=diff_text,
+            file_path=file_path,
+        )
+
+        if not language or not (vuln_meta or fixed_meta):
+            continue
+
+        language_bucket = by_language.setdefault(
+            language,
+            {
+                "vulnerable": {},
+                "fixed": {},
+            },
+        )
+
+        language_bucket["vulnerable"].update(
+            {f"{file_path}::{key}": metadata for key, metadata in vuln_meta.items()}
+        )
+
+        language_bucket["fixed"].update(
+            {f"{file_path}::{key}": metadata for key, metadata in fixed_meta.items()}
+        )
+
+    return by_language
+
+
+def append_symbol_reachability_result(resource, result):
+    """
+    Append one symbol reachability result to the resource extra_data without
+    overwriting previous results.
+    """
+    extra_data = resource.extra_data or {}
+    existing_results = extra_data.get("symbols_reachability", [])
+
+    if not isinstance(existing_results, list):
+        existing_results = [existing_results]
+
+    existing_results.append(result)
+
+    resource.update_extra_data(
+        {
+            "symbols_reachability": existing_results,
+        }
+    )
+
+
+def collect_and_store_symbol_reachability_results(project, logger=None):
+    """
+    For each known patch commit, determine whether each project codebase
+    resource can reach the vulnerable code by comparing tree-sitter ASTs
+    of the patch versus the resource.
+
+    Result classification: REACHABLE, POTENTIALLY_REACHABLE or NOT_REACHABLE.
+
+    A failing patch is reported through `logger` (a callable taking a message)
+    and skipped; without a logger the failure is re-raised.
+    """
+    candidate_resources = project.codebaseresources.files().filter(
+        is_binary=False,
+        is_archive=False,
+        is_media=False,
+    )
+
+    for patch in api_mocker():
+        vcs_url = patch["vcs_url"]
+        commit_hash = patch["commit_hash"]
+        repo_path = None
+        try:
+            repo_path = clone_repo(vcs_url, commit_hash)
+            repo = Repo(repo_path)
+
+            patch_symbols_by_language = collect_patch_symbols(repo, commit_hash)
+
+            if not patch_symbols_by_language:
+                continue
+
+            for resource in candidate_resources:
+                resource_language = resource.programming_language
+
+                if resource_language not in patch_symbols_by_language:
+                    continue
+
+                resource_text = normalize_text(resource.file_content)
+
+                if not resource_text:
+                    continue
+
+                patch_symbols = patch_symbols_by_language[resource_language]
+                vuln_metadata = patch_symbols["vulnerable"]
+                fixed_metadata = patch_symbols["fixed"]
+
+                resource_index = build_resource_index(resource_text, resource_language)
+
+                if not resource_index:
+                    continue
+
+                vuln_match_symbols = match_symbols_against_resource(
+                    vuln_metadata,
+                    resource_index,
+                )
+
+                fixed_match_symbols = match_symbols_against_resource(
+                    fixed_metadata,
+                    resource_index,
+                )
+
+                if not vuln_match_symbols and not fixed_match_symbols:
+                    continue
+
+                result = {
+                    "reachability_status": classify_reachability(vuln_match_symbols),
+                    "summary": {
+                        "vulnerable_symbols": sorted(vuln_match_symbols),
+                        "fixed_symbols": sorted(fixed_match_symbols),
+                        "call_paths": {
+                            qn: ev.get("reachable_from", [])
+                            for qn, ev in vuln_match_symbols.items()
+                            if ev.get("called")
+                        },
+                    },
+                    "evidence": vuln_match_symbols,
+                    "patch": {
+                        "vcs_url": vcs_url,
+                        "commit_hash": commit_hash,
+                    },
+                }
+
+                append_symbol_reachability_result(resource, result)
+
+        except Exception as e:
+            # `logger` defaults to None: re-raise instead of failing on `None(...)`.
+            if logger is None:
+                raise
+            logger(
+                "Failed to collect symbol reachability for "
+                f"{vcs_url}@{commit_hash}: {e}"
+            )
+        finally:
+            cleanup_repo(repo_path)
+
+
+def 
compute_reachable_symbols(call_graph, target_simple_names): + """ + Find all symbols that can transitively reach any of ``target_simple_names``. + + Reachability is matched on *simple* names (the call graph records callee + tokens, not fully-qualified names), so distinct symbols sharing a name are + treated as equivalent. This can over-approximate reachability. + + Returns: + (reachable_callers, has_direct_call) + reachable_callers: qualified names of all transitive callers + has_direct_call: whether any symbol calls a target directly + + """ + if not call_graph or not target_simple_names: + return set(), False + + edges = call_graph["edges"] + targets = set(target_simple_names) + + callers_of: dict[str, set[str]] = {} + for caller_qn, callees in edges.items(): + for callee_simple in callees: + callers_of.setdefault(callee_simple, set()).add(caller_qn) + + direct_callers: set[str] = set() + for target in targets: + direct_callers |= callers_of.get(target, set()) + + has_direct_call = bool(direct_callers) + + qn_to_simple = {qn: meta["simple_name"] for qn, meta in call_graph["nodes"].items()} + + reachable = set(direct_callers) + frontier = list(direct_callers) + + while frontier: + current_qn = frontier.pop() + current_simple = qn_to_simple.get(current_qn) + if not current_simple: + continue + for parent_qn in callers_of.get(current_simple, ()): + if parent_qn not in reachable: + reachable.add(parent_qn) + frontier.append(parent_qn) + + return reachable, has_direct_call + + +def build_resource_index(resource_text, language): + resource_text = normalize_text(resource_text) + + if not is_supported_language(language) or not resource_text: + return None + + tree, _ = parse_code_to_ast(resource_text, language) + + if tree is None: + return None + + call_graph = build_call_graph(tree, language) + + meta = ( + call_graph["nodes"] + if call_graph + else build_symbol_metadata( + extract_definitions(tree, language), + language, + ) + ) + + return { + "definitions": 
{metadata["qualified_name"] for metadata in meta.values()}, + "fingerprints": { + metadata["fingerprint"] + for metadata in meta.values() + if metadata["fingerprint"] + }, + "call_graph": call_graph, + } + + +def match_symbols_against_resource(symbols, resource_index): + if not symbols or not resource_index: + return {} + + call_graph = resource_index.get("call_graph") + target_simple_names = {metadata["simple_name"] for metadata in symbols.values()} + + reachable_callers, _has_direct = compute_reachable_symbols( + call_graph, + target_simple_names, + ) + + called_simple_names = set() + if call_graph: + for callees in call_graph["edges"].values(): + called_simple_names |= callees + + matched = {} + for metadata in symbols.values(): + qualified_name = metadata["qualified_name"] + simple_name = metadata["simple_name"] + fingerprint = metadata["fingerprint"] + + defined = qualified_name in resource_index["definitions"] + fingerprint_hit = bool( + fingerprint and fingerprint in resource_index["fingerprints"] + ) + called = simple_name in called_simple_names + + if not (defined or fingerprint_hit or called): + continue + + entry = matched.setdefault( + qualified_name, + { + "defined": False, + "called": False, + "reachable_from": [], + }, + ) + + entry["defined"] = entry["defined"] or defined + entry["called"] = entry["called"] or called + + if fingerprint_hit: + entry["exact_match_fingerprint"] = fingerprint + + if called: + entry["reachable_from"] = sorted(reachable_callers) + + return matched + + +def classify_reachability(evidence): + if not evidence: + return ReachabilityStatus.NOT_REACHABLE + + SEVERITY_RANK = { + ReachabilityStatus.NOT_REACHABLE: 0, + ReachabilityStatus.POTENTIALLY_REACHABLE: 1, + ReachabilityStatus.REACHABLE: 2, + } + + highest_status = ReachabilityStatus.NOT_REACHABLE + + for item in evidence.values(): + is_called = bool(item.get("called")) + has_path = bool(item.get("reachable_from")) + is_exact = "exact_match_fingerprint" in item + is_defined 
= bool(item.get("defined")) + + if is_called or (has_path and is_exact): + return ReachabilityStatus.REACHABLE + + elif has_path or is_exact or is_defined: + current_item_status = ReachabilityStatus.POTENTIALLY_REACHABLE + + else: + current_item_status = ReachabilityStatus.NOT_REACHABLE + + if SEVERITY_RANK[current_item_status] > SEVERITY_RANK[highest_status]: + highest_status = current_item_status + return highest_status + + +def build_symbol_metadata(nodes, language, index=None): + if index is None and nodes: + index = collect_definitions(_root_of(nodes[0]), language) + + metadata = {} + for node in nodes: + if is_nested_function(node, language): + continue + + qualified_name = qualified_name_from_index(node, index) + if not qualified_name: + continue + + body_text = node.text.decode("utf-8", errors="replace") + fingerprints = create_file_fingerprints(content=body_text) or {} + + key = qualified_name + suffix = 1 + while key in metadata: + suffix += 1 + key = f"{qualified_name}#{suffix}" + + metadata[key] = { + "qualified_name": qualified_name, + "simple_name": qualified_name.rsplit(".", 1)[-1], + "text": body_text, + "fingerprint": fingerprints.get("halo1"), + "start_line": node.start_point[0] + 1, + "end_line": node.end_point[0] + 1, + "node_type": node.type, + } + return metadata + + +def build_call_graph(tree, language): + if tree is None or not is_supported_language(language): + return None + + index = collect_definitions(tree.root_node, language) + definition_nodes = [d["node"] for d in index.values()] + metadata = build_symbol_metadata(definition_nodes, language, index=index) + + qualified_name_to_node = {} + for node in definition_nodes: + qualified_name = qualified_name_from_index(node, index) + if qualified_name: + qualified_name_to_node.setdefault(qualified_name, node) + + edges = {} + by_simple_name = {} + for qualified_name, meta in metadata.items(): + canonical = meta["qualified_name"] + node = qualified_name_to_node.get(canonical) + if node is 
None: + continue + edges.setdefault(canonical, set()).update(extract_calls_in_node(node, language)) + by_simple_name.setdefault(meta["simple_name"], set()).add(canonical) + + return {"nodes": metadata, "edges": edges, "by_simple_name": by_simple_name} diff --git a/scanpipe/pipes/symbols.py b/scanpipe/pipes/symbols.py index 76493d8dac..cac76562e3 100644 --- a/scanpipe/pipes/symbols.py +++ b/scanpipe/pipes/symbols.py @@ -20,8 +20,20 @@ # ScanCode.io is a free software code scanning tool from nexB Inc. and others. # Visit https://github.com/aboutcode-org/scancode.io for support and download. +import importlib +from functools import cache + from django.db.models import Q +from source_inspector import symbols_ctags +from source_inspector import symbols_pygments +from source_inspector import symbols_tree_sitter +from source_inspector.symbols_tree_sitter import TS_LANGUAGE_WHEELS +from source_inspector.symbols_tree_sitter import TreeSitterWheelNotInstalled +from tree_sitter import Language +from tree_sitter import Parser +from tree_sitter import Query + from aboutcode.pipeline import LoopProgress @@ -171,3 +183,182 @@ def _collect_and_store_tree_sitter_symbols_and_strings(resource): "source_strings": result.get("source_strings"), } ) + + +SYMBOLS_TYPE_SUPPORTED = { + "ctags": symbols_ctags.get_symbols, + "tree_sitter": symbols_tree_sitter.get_treesitter_symbols, + "pygments": symbols_pygments.get_pygments_symbols, +} + +# https://github.com/Aider-AI/aider/tree/5dc9490bb35f9729ef2c95d00a19ccd30c26339c/aider/queries/tree-sitter-language-pack +TS_QUERIES = { + "Python": { + "functions": """ + (function_definition name: (identifier) @name) @function + """, + "classes": """ + (class_definition name: (identifier) @name) @class + """, + "calls": """ + (call function: (identifier) @callee) + (call function: (attribute attribute: (identifier) @callee)) + """, + }, +} + + +@cache +def load_language(language: str) -> Language: + if language not in TS_LANGUAGE_WHEELS: + raise 
ValueError(f"Unsupported language: {language}") + + wheel = TS_LANGUAGE_WHEELS[language]["wheel"] + try: + grammar = importlib.import_module(wheel) + except ModuleNotFoundError as exc: + raise TreeSitterWheelNotInstalled( + f"Grammar wheel '{wheel}' is not installed." + ) from exc + return Language(grammar.language()) + + +@cache +def get_query(language: str, kind: str) -> Query | None: + source = TS_QUERIES.get(language, {}).get(kind, "").strip() + if not source: + return None + return Query(load_language(language), source) + + +def parse_code_to_ast(code_text: str, language: str): + if not code_text or not language or language not in TS_LANGUAGE_WHEELS: + return None, None + + ts_language = load_language(language) + parser = Parser(language=ts_language) + return parser.parse(code_text.encode("utf-8")), TS_LANGUAGE_WHEELS[language] + + +def run_query(query: Query, root_node): + """Yield ``(definition_node, name)`` pairs for function/class queries.""" + if query is None: + return + + for _pattern_index, captures in query.matches(root_node): + def_nodes = captures.get("function") or captures.get("class") or [] + if not def_nodes: + continue + + name_nodes = captures.get("name") or [] + name = ( + name_nodes[0].text.decode("utf-8", errors="replace") if name_nodes else None + ) + yield def_nodes[0], name + + +def query_captures(language, kind, node): + """Re-run a definition query on the root of node's tree.""" + query = get_query(language, kind) + return list(run_query(query, _root_of(node))) + + +def _root_of(node): + while node.parent is not None: + node = node.parent + return node + + +def is_nested_function(node, language): + function_nodes = { + captured_node + for captured_node, _ in query_captures(language, "functions", node) + } + class_nodes = { + captured_node for captured_node, _ in query_captures(language, "classes", node) + } + + if node not in function_nodes: + return False + + function_types = {captured_node.type for captured_node in function_nodes} + 
class_types = {captured_node.type for captured_node in class_nodes} + + parent = node.parent + + while parent is not None: + if parent.type in function_types: + return True + + if parent.type in class_types: + return False + + parent = parent.parent + + return False + + +def extract_calls_in_node(node, language: str): + query = get_query(language, "calls") + if query is None or node is None: + return set() + + names = set() + for _pattern_index, captures in query.matches(node): + for callee_node in captures.get("callee", []): + name = callee_node.text.decode("utf-8", errors="replace") + if name: + names.add(name) + return names + + +def collect_definitions(root_node, language: str): + index: dict[int, dict] = {} + for kind in ("functions", "classes"): + query = get_query(language, kind) + for node, name in run_query(query, root_node): + index[node.id] = {"node": node, "name": name, "kind": kind} + return index + + +def extract_definitions(tree, language: str, kinds=("functions", "classes")): + if tree is None: + return [] + index = collect_definitions(tree.root_node, language) + return [d["node"] for d in index.values() if d["kind"] in kinds] + + +def extract_symbols(tree, changed_lines: list[int], language: str): + if tree is None or not changed_lines: + return [] + + definition_ids = set(collect_definitions(tree.root_node, language).keys()) + if not definition_ids: + return [] + + seen = set() + enclosing = [] + + for line in changed_lines: + row = max(0, line - 1) + node = tree.root_node.descendant_for_point_range((row, 0), (row, 0)) + + while node is not None: + if node.id in definition_ids and node.id not in seen: + seen.add(node.id) + enclosing.append(node) + break + node = node.parent + + return enclosing + + +def qualified_name_from_index(node, index): + parts = [] + curr = node + while curr is not None: + definition = index.get(curr.id) + if definition is not None and definition["name"]: + parts.append(definition["name"]) + curr = curr.parent + return 
".".join(reversed(parts)) diff --git a/scanpipe/tests/data/reachability/app.py b/scanpipe/tests/data/reachability/app.py new file mode 100644 index 0000000000..c64ae7d9d1 --- /dev/null +++ b/scanpipe/tests/data/reachability/app.py @@ -0,0 +1,35 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # VULNERABLE: Direct concatenation allows Path Traversal + # An attacker passing "../../etc/passwd" could read system files. + return os.path.join(generator.base_dir, filename) + + if not requested_file: + return "Error: No file specified" + + target_path = build_file_path(requested_file) + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/diff-app.patch b/scanpipe/tests/data/reachability/diff-app.patch new file mode 100644 index 0000000000..ccb86953a8 --- /dev/null +++ b/scanpipe/tests/data/reachability/diff-app.patch @@ -0,0 +1,39 @@ +From 8f7b1c3d9a4e2b6f5d8c1a2e3f4b5c6d7e8f9a0b Mon Sep 17 00:00:00 2001 +From: Security Team +Date: Tue, 2 Jun 2026 10:00:00 +0000 +Subject: [PATCH] Fix path traversal vulnerability in report generator + +- Validates that target paths stay within the designated base_dir. +- Catches ValueError on invalid path resolution. 
+--- + app.py | 12 +++++++++--- + 1 file changed, 9 insertions(+), 3 deletions(-) + +diff --git a/app.py b/app.py +index a1b2c3d..e4f5g6h 100644 +--- a/app.py ++++ b/app.py +@@ -15,13 +15,19 @@ def serve_report(request_payload): + # Helper function nested inside serve_report + def build_file_path(filename): +- # VULNERABLE: Direct concatenation allows Path Traversal +- # An attacker passing "../../etc/passwd" could read system files. +- return os.path.join(generator.base_dir, filename) ++ # FIXED: Validate that the resolved path stays within the base_dir ++ base = os.path.abspath(generator.base_dir) ++ target = os.path.abspath(os.path.join(base, filename)) ++ if not target.startswith(base): ++ raise ValueError("Path Traversal Detected") ++ return target + + if not requested_file: + return "Error: No file specified" + +- target_path = build_file_path(requested_file) ++ try: ++ target_path = build_file_path(requested_file) ++ except ValueError: ++ return "Error: Invalid path" + + if os.path.exists(target_path): + return f"Serving content of {target_path}" \ No newline at end of file diff --git a/scanpipe/tests/data/reachability/fixed-app.py b/scanpipe/tests/data/reachability/fixed-app.py new file mode 100644 index 0000000000..3296bb843e --- /dev/null +++ b/scanpipe/tests/data/reachability/fixed-app.py @@ -0,0 +1,41 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # FIXED: Validate that the resolved path stays within the base_dir + base = os.path.abspath(generator.base_dir) + target = os.path.abspath(os.path.join(base, filename)) + if not target.startswith(base): + raise ValueError("Path Traversal 
Detected") + return target + + if not requested_file: + return "Error: No file specified" + + try: + target_path = build_file_path(requested_file) + except ValueError: + return "Error: Invalid path" + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." diff --git a/scanpipe/tests/data/reachability/vuln-app.py b/scanpipe/tests/data/reachability/vuln-app.py new file mode 100644 index 0000000000..c64ae7d9d1 --- /dev/null +++ b/scanpipe/tests/data/reachability/vuln-app.py @@ -0,0 +1,35 @@ +import os + + +class ReportGenerator: + """A dummy class to test AST class method parsing.""" + + def __init__(self, base_dir): + self.base_dir = base_dir + + +def serve_report(request_payload): + """Top-level function handling a request.""" + generator = ReportGenerator("/var/reports") + requested_file = request_payload.get("file") + + # Helper function nested inside serve_report + def build_file_path(filename): + # VULNERABLE: Direct concatenation allows Path Traversal + # An attacker passing "../../etc/passwd" could read system files. + return os.path.join(generator.base_dir, filename) + + if not requested_file: + return "Error: No file specified" + + target_path = build_file_path(requested_file) + + if os.path.exists(target_path): + return f"Serving content of {target_path}" + + return "Error: File not found" + + +def unrelated_top_level_function(): + """An extra function to test AST node boundaries.""" + return "I am just here to add AST complexity." 
diff --git a/scanpipe/tests/pipes/test_symbols_reachability.py b/scanpipe/tests/pipes/test_symbols_reachability.py new file mode 100644 index 0000000000..1a5fd7b3ad --- /dev/null +++ b/scanpipe/tests/pipes/test_symbols_reachability.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# http://nexb.com and https://github.com/nexB/scancode.io +# The ScanCode.io software is licensed under the Apache License version 2.0. +# Data generated with ScanCode.io is provided as-is without warranties. +# ScanCode is a trademark of nexB Inc. +# +# You may not use this software except in compliance with the License. +# You may obtain a copy of the License at: http://apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# +# Data Generated with ScanCode.io is provided on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. No content created from +# ScanCode.io should be considered or used as legal advice. Consult an Attorney +# for any legal advice. +# +# ScanCode.io is a free software code scanning tool from nexB Inc. and others. +# Visit https://github.com/nexB/scancode.io for support and download. 
from pathlib import Path
from unittest.mock import patch

from django.test import TestCase

from scanpipe.models import Project
from scanpipe.pipes import collect_and_create_codebase_resources
from scanpipe.pipes.reachability import ReachabilityStatus
from scanpipe.pipes.reachability import analyze_patched_file
from scanpipe.pipes.reachability import build_call_graph
from scanpipe.pipes.reachability import build_symbol_metadata
from scanpipe.pipes.reachability import classify_reachability
from scanpipe.pipes.reachability import collect_and_store_symbol_reachability_results
from scanpipe.pipes.reachability import diff_changed_symbols
from scanpipe.pipes.reachability import get_changed_lines
from scanpipe.pipes.symbols import collect_definitions
from scanpipe.pipes.symbols import extract_definitions
from scanpipe.pipes.symbols import extract_symbols
from scanpipe.pipes.symbols import parse_code_to_ast
from scanpipe.pipes.symbols import qualified_name_from_index


class SymbolReachabilityPipesTest(TestCase):
    """Tests for the symbol reachability pipes and the symbol helpers they use."""

    # Directory holding the fixtures: app.py, vuln-app.py, fixed-app.py and
    # diff-app.patch.
    data = Path(__file__).parent.parent / "data" / "reachability"
    # Expected values below are large nested structures: show full diffs on failure.
    maxDiff = None

    def setUp(self):
        """Create a project with an existing, empty codebase directory."""
        self.project1 = Project.objects.create(name="Analysis")
        self.project1.codebase_path.mkdir(parents=True, exist_ok=True)

    # NOTE: ``patch`` decorators are applied bottom-up, so the mock arguments are
    # received in the reverse order of the decorators.
    @patch("scanpipe.pipes.reachability.Repo")
    @patch("scanpipe.pipes.reachability.clone_repo")
    @patch("scanpipe.pipes.reachability.api_mocker")
    @patch("scanpipe.pipes.reachability.collect_patch_symbols")
    def test_collect_and_store_symbol_reachability_results(
        self, mock_collect_symbols, mock_api, mock_clone_repo, mock_repo
    ):
        """Check the results stored in ``extra_data`` with the patch sources mocked."""
        app_text = (self.data / "app.py").read_text(encoding="utf-8")
        vuln_text = (self.data / "vuln-app.py").read_text(encoding="utf-8")
        fixed_text = (self.data / "fixed-app.py").read_text(encoding="utf-8")
        diff_text = (self.data / "diff-app.patch").read_text(encoding="utf-8")

        vuln_meta, fixed_meta, lang = analyze_patched_file(
            vulnerable_text=vuln_text,
            fixed_text=fixed_text,
            diff_text=diff_text,
            file_path="app.py",
        )

        self.assertTrue(lang)
        self.assertTrue(vuln_meta or fixed_meta)

        patch_data = {
            "vcs_url": "https://github.com/aboutcode-org/test",
            "commit_hash": "07ec0de1964b14bf085a1c9a27ece2b61ab6105c",
        }
        mock_api.return_value = [dict(patch_data)]
        mock_clone_repo.return_value = str(self.project1.codebase_path)
        mock_collect_symbols.return_value = {
            lang: {
                "vulnerable": {
                    f"app.py::{key}": metadata for key, metadata in vuln_meta.items()
                },
                "fixed": {
                    f"app.py::{key}": metadata for key, metadata in fixed_meta.items()
                },
            }
        }

        resource_file = self.project1.codebase_path / "app.py"
        resource_file.write_text(app_text)
        collect_and_create_codebase_resources(self.project1)

        resource = self.project1.codebaseresources.get(path="app.py")
        resource.programming_language = lang
        resource.save()

        collect_and_store_symbol_reachability_results(self.project1)

        resource.refresh_from_db()
        results = resource.extra_data.get("symbols_reachability")

        expected = [
            {
                "patch": patch_data,
                "summary": {
                    "call_paths": {},
                    "fixed_symbols": ["serve_report"],
                    "vulnerable_symbols": ["serve_report"],
                },
                "evidence": {
                    "serve_report": {
                        "called": False,
                        "defined": True,
                        "reachable_from": [],
                        "exact_match_fingerprint": "000000556d322a47595af353274b000aa324e014",
                    }
                },
                "reachability_status": "POTENTIALLY_REACHABLE",
            }
        ]
        self.assertEqual(expected, results)

    def test_build_call_graph(self):
        """Check the nodes, edges and simple-name index built for three functions."""
        source_code = """
def calculate_total(price, tax):
    return price + get_tax_amount(price, tax)

def get_tax_amount(price, tax):
    return price * tax

def process_order():
    total = calculate_total(100, 0.05)
    print("Done")
"""
        tree, _ = parse_code_to_ast(source_code, "Python")
        result = build_call_graph(tree, "Python")

        expected = {
            "nodes": {
                "calculate_total": {
                    "qualified_name": "calculate_total",
                    "simple_name": "calculate_total",
                    "text": "def calculate_total(price, tax):\n    return price + get_tax_amount(price, tax)",
                    "fingerprint": "00000008060105fd3624134884412006ce880936",
                    "start_line": 2,
                    "end_line": 3,
                    "node_type": "function_definition",
                },
                "get_tax_amount": {
                    "qualified_name": "get_tax_amount",
                    "simple_name": "get_tax_amount",
                    "text": "def get_tax_amount(price, tax):\n    return price * tax",
                    "fingerprint": "000000058f0ee87d9669f20b1f473137b665bb20",
                    "start_line": 5,
                    "end_line": 6,
                    "node_type": "function_definition",
                },
                "process_order": {
                    "qualified_name": "process_order",
                    "simple_name": "process_order",
                    "text": 'def process_order():\n    total = calculate_total(100, 0.05)\n    print("Done")',
                    "fingerprint": "000000071c3e6902da5c2b322386eff29068e3e2",
                    "start_line": 8,
                    "end_line": 10,
                    "node_type": "function_definition",
                },
            },
            "edges": {
                "calculate_total": {"get_tax_amount"},
                "get_tax_amount": set(),
                "process_order": {"print", "calculate_total"},
            },
            "by_simple_name": {
                "calculate_total": {"calculate_total"},
                "get_tax_amount": {"get_tax_amount"},
                "process_order": {"process_order"},
            },
        }
        self.assertEqual(expected, result)

    def test_extract_definitions(self):
        """Check that functions and classes are extracted separately by ``kinds``."""
        source_code = """
class OrderManager:
    def __init__(self, order_id):
        self.order_id = order_id

    def process_payment(self):
        print("Processing...")

def calculate_discount(price):
    return price * 0.10

class InventoryItem:
    pass
"""
        tree, _ = parse_code_to_ast(source_code, "Python")

        functions = extract_definitions(tree, "Python", kinds=("functions",))
        # '__init__', 'process_payment', and 'calculate_discount'
        self.assertEqual(3, len(functions))
        self.assertEqual("function_definition", functions[0].type)
        first_func_text = functions[0].text.decode("utf-8")
        self.assertIn("def __init__", first_func_text)

        classes = extract_definitions(tree, "Python", kinds=("classes",))
        # 'OrderManager' and 'InventoryItem'
        self.assertEqual(2, len(classes))
        second_class_text = classes[1].text.decode("utf-8")
        self.assertIn("class InventoryItem", second_class_text)

    def test_extract_definitions_empty(self):
        """Check that an empty source and a missing tree both yield no definitions."""
        tree, _ = parse_code_to_ast("", "Python")
        # Cover every tree/kinds combination; the previous version asserted the
        # same two combinations twice and never exercised the other two.
        for label, candidate in (("empty", tree), ("none", None)):
            for kinds in (("functions",), ("classes",)):
                with self.subTest(tree=label, kinds=kinds):
                    self.assertEqual(
                        [], extract_definitions(candidate, "Python", kinds=kinds)
                    )

    def test_get_qualified_name_functions(self):
        """Check qualified names of a nested method and of a top-level function."""
        source_code = """
class CoreService:
    class Validator:
        def validate_payload(self, data):
            return True

def global_utility():
    pass
        """

        tree, _ = parse_code_to_ast(source_code, "Python")
        index = collect_definitions(tree.root_node, "Python")

        functions = extract_definitions(tree, "Python", kinds=("functions",))
        self.assertEqual(2, len(functions))

        nested_method_name = qualified_name_from_index(functions[0], index)
        top_level_function_name = qualified_name_from_index(functions[1], index)

        self.assertEqual("CoreService.Validator.validate_payload", nested_method_name)
        self.assertEqual("global_utility", top_level_function_name)

    def test_get_qualified_classes(self):
        """Check qualified names of an outer class and of the class nested in it."""
        source_code = """
class FleetManagement:
    class DroneController:
        pass
        """
        tree, _ = parse_code_to_ast(source_code, "Python")
        index = collect_definitions(tree.root_node, "Python")

        classes = extract_definitions(tree, "Python", kinds=("classes",))
        self.assertEqual(2, len(classes))

        outer_class_name = qualified_name_from_index(classes[0], index)
        inner_class_name = qualified_name_from_index(classes[1], index)

        self.assertEqual("FleetManagement", outer_class_name)
        self.assertEqual("FleetManagement.DroneController", inner_class_name)

    def test_classify_reachability(self):
        """Check the status returned for missing, called and defined-only evidence."""
        self.assertEqual(ReachabilityStatus.NOT_REACHABLE, classify_reachability(None))
        self.assertEqual(ReachabilityStatus.NOT_REACHABLE, classify_reachability({}))

        self.assertEqual(
            ReachabilityStatus.REACHABLE,
            classify_reachability(
                {"sym1": {"exact_match_fingerprint": "hash123", "called": True}}
            ),
        )
        self.assertEqual(
            ReachabilityStatus.REACHABLE,
            classify_reachability(
                {
                    "sym1": {
                        "called": True,
                        "reachable_from": ["main_function", "api_handler"],
                    }
                }
            ),
        )
        self.assertEqual(
            ReachabilityStatus.POTENTIALLY_REACHABLE,
            classify_reachability({"sym1": {"defined": True, "called": False}}),
        )
        self.assertEqual(
            ReachabilityStatus.POTENTIALLY_REACHABLE,
            classify_reachability(
                {"sym1": {"exact_match_fingerprint": "hash123", "called": False}}
            ),
        )
        self.assertEqual(
            ReachabilityStatus.NOT_REACHABLE,
            classify_reachability({"sym1": {"file_path": "src/vulnerable.py"}}),
        )

    def test_get_changed_lines(self):
        """Check the removed and added line numbers reported for the fixture patch."""
        diff_text = (self.data / "diff-app.patch").read_text(encoding="utf-8")

        removed, added = get_changed_lines(diff_text, "app.py")
        self.assertEqual([17, 18, 19, 24], removed)
        self.assertEqual([17, 18, 19, 20, 21, 22, 27, 28, 29, 30], added)

    def test_build_symbol_metadata_processing(self):
        """Check the metadata built for a method and a same-named top-level function."""
        source_code = """
class Controller:
    def process_data(payload):
        def inner_helper():
            return True
        return payload.strip()

if True:
    def process_data(payload):
        return payload
"""
        tree, _ = parse_code_to_ast(source_code, "Python")
        nodes = extract_definitions(tree, "Python", kinds=("functions",))

        metadata = build_symbol_metadata(nodes, "Python")
        expected = {
            "Controller.process_data": {
                "qualified_name": "Controller.process_data",
                "simple_name": "process_data",
                "text": "def process_data(payload):\n        def inner_helper():\n            return True\n        return payload.strip()",
                "fingerprint": "0000000888014a04b037189a42b238a2c50f218c",
                "start_line": 3,
                "end_line": 6,
                "node_type": "function_definition",
            },
            "process_data": {
                "qualified_name": "process_data",
                "simple_name": "process_data",
                "text": "def process_data(payload):\n        return payload",
                "fingerprint": "000000022020300e882a900807880d0300010000",
                "start_line": 9,
                "end_line": 10,
                "node_type": "function_definition",
            },
        }
        self.assertEqual(expected, metadata)

    def test_diff_changed_symbols(self):
        """Check that symbols with identical text on both sides are dropped."""
        vulnerable_serve_report = {
            "qualified_name": "app.serve_report",
            "text": "def serve_report():\n    return os.path.join(base, filename)",
        }
        fixed_serve_report = {
            "qualified_name": "app.serve_report",
            "text": "def serve_report():\n    if not target.startswith(base): raise ValueError\n    return target",
        }
        deprecated_logger = {
            "qualified_name": "app.deprecated_logger",
            "text": "def deprecated_logger():\n    print('legacy')",
        }
        audit_trail = {
            "qualified_name": "app.audit_trail",
            "text": "def audit_trail():\n    log.info('action')",
        }

        vuln_meta = {
            "serve_report": dict(vulnerable_serve_report),
            # Unchanged between both sides.
            "sanitize_input": {
                "qualified_name": "app.sanitize_input",
                "text": "def sanitize_input(x):\n    return x.strip()",
            },
            "deprecated_logger": dict(deprecated_logger),
        }
        fixed_meta = {
            "serve_report": dict(fixed_serve_report),
            # Unchanged between both sides.
            "sanitize_input": {
                "qualified_name": "app.sanitize_input",
                "text": "def sanitize_input(x):\n    return x.strip()",
            },
            "audit_trail": dict(audit_trail),
        }

        vuln_only, fixed_only = diff_changed_symbols(vuln_meta, fixed_meta)

        self.assertEqual(
            {
                "serve_report": vulnerable_serve_report,
                "deprecated_logger": deprecated_logger,
            },
            vuln_only,
        )
        self.assertEqual(
            {
                "serve_report": fixed_serve_report,
                "audit_trail": audit_trail,
            },
            fixed_only,
        )

    def test_analyze_patched_file(self):
        """Check the symbol metadata computed for both sides of the fixture patch."""
        vuln_text = (self.data / "vuln-app.py").read_text(encoding="utf-8")
        fixed_text = (self.data / "fixed-app.py").read_text(encoding="utf-8")
        diff_text = (self.data / "diff-app.patch").read_text(encoding="utf-8")

        vuln_meta, fixed_meta, lang = analyze_patched_file(
            vulnerable_text=vuln_text,
            fixed_text=fixed_text,
            diff_text=diff_text,
            file_path="app.py",
        )

        expected_vuln_meta = {
            "serve_report": {
                "qualified_name": "serve_report",
                "simple_name": "serve_report",
                "text": 'def serve_report(request_payload):\n    """Top-level function handling a request."""\n    generator = ReportGenerator("/var/reports")\n    requested_file = request_payload.get("file")\n\n    # Helper function nested inside serve_report\n    def build_file_path(filename):\n        # VULNERABLE: Direct concatenation allows Path Traversal\n        # An attacker passing "../../etc/passwd" could read system files.\n        return os.path.join(generator.base_dir, filename)\n\n    if not requested_file:\n        return "Error: No file specified"\n\n    target_path = build_file_path(requested_file)\n\n    if os.path.exists(target_path):\n        return f"Serving content of {target_path}"\n\n    return "Error: File not found"',
                "fingerprint": "000000556d322a47595af353274b000aa324e014",
                "start_line": 11,
                "end_line": 30,
                "node_type": "function_definition",
            }
        }
        self.assertEqual(expected_vuln_meta, vuln_meta)

        expected_fixed_meta = {
            "serve_report": {
                "qualified_name": "serve_report",
                "simple_name": "serve_report",
                "text": 'def serve_report(request_payload):\n    """Top-level function handling a request."""\n    generator = ReportGenerator("/var/reports")\n    requested_file = request_payload.get("file")\n\n    # Helper function nested inside serve_report\n    def build_file_path(filename):\n        # FIXED: Validate that the resolved path stays within the base_dir\n        base = os.path.abspath(generator.base_dir)\n        target = os.path.abspath(os.path.join(base, filename))\n        if not target.startswith(base):\n            raise ValueError("Path Traversal Detected")\n        return target\n\n    if not requested_file:\n        return "Error: No file specified"\n\n    try:\n        target_path = build_file_path(requested_file)\n    except ValueError:\n        return "Error: Invalid path"\n\n    if os.path.exists(target_path):\n        return f"Serving content of {target_path}"\n\n    return "Error: File not found"',
                "fingerprint": "0000006cceea8aedf1da91830f67b64927086d24",
                "start_line": 11,
                "end_line": 36,
                "node_type": "function_definition",
            }
        }
        self.assertEqual(expected_fixed_meta, fixed_meta)

    def test_extract_symbols(self):
        """Check that a changed line maps to its innermost enclosing function."""
        source_code = (
            "def serve_report(request):\n"  # Line 1 (Row 0)
            "    # Some processing here\n"  # Line 2 (Row 1)
            "    def build_path(filename):\n"  # Line 3 (Row 2)
            "        return filename.strip()\n"  # Line 4 (Row 3) <- Targeted Change
            "    return build_path(request)\n"  # Line 5 (Row 4)
        )

        tree, _ = parse_code_to_ast(source_code, "Python")

        changed_lines = [4]
        enclosing_symbols = extract_symbols(tree, changed_lines, "Python")

        self.assertEqual(1, len(enclosing_symbols))
        target_node = enclosing_symbols[0]
        self.assertEqual("function_definition", target_node.type)

        node_text = target_node.text.decode("utf-8")
        self.assertIn("def build_path", node_text)
        self.assertNotIn("def serve_report", node_text)

    def test_extract_symbols_deduplication(self):
        """Check that two changed lines in one function yield a single symbol."""
        source_code = (
            "def calculate_total(price, tax):\n"
            "    amount = price * tax\n"  # Line 2 -> Changed
            "    return price + amount\n"  # Line 3 -> Changed
        )

        tree, _ = parse_code_to_ast(source_code, "Python")
        changed_lines = [2, 3]

        enclosing_symbols = extract_symbols(tree, changed_lines, "Python")
        self.assertEqual(1, len(enclosing_symbols))
        self.assertEqual("function_definition", enclosing_symbols[0].type)