diff --git a/src/tocode/analysis.py b/src/tocode/analysis.py index e3b5c3f..4490385 100644 --- a/src/tocode/analysis.py +++ b/src/tocode/analysis.py @@ -69,7 +69,7 @@ def collect(self) -> ProgramAnalysis: or f"{self.session.backend_label} auto-analysis" ) self.progress.log(f"Analyzing with {label}") - with self.progress.bar(total=15, desc="analyze", unit="step") as bar: + with self.progress.bar(total=16, desc="analyze", unit="step") as bar: self.session.analyze() bar.update(1) info = self.session.info() @@ -92,6 +92,8 @@ def collect(self) -> ProgramAnalysis: bar.update(1) flags = [self._flag(row) for row in self.session.flags()] bar.update(1) + data_xrefs = self._data_xrefs(symbols, relocations, strings, flags) + bar.update(1) routines = self._routines(self.session.functions(), imports, segments) bar.update(1) callees, callers, import_calls = self._call_graph(routines, imports) @@ -128,6 +130,7 @@ def collect(self) -> ProgramAnalysis: import_calls=import_calls, roots=roots, thunks=thunks, + data_xrefs=data_xrefs, ) self.analysis = analysis self.analysis_seconds = time.monotonic() - started @@ -249,6 +252,28 @@ def _routines( ) return result + def _data_xrefs( + self, + symbols: list[SymbolEntry], + relocations: list[RelocationEntry], + strings: list[StringEntry], + flags: list[FlagEntry], + ) -> dict[int, list[tuple[int, bool]]]: + collect = getattr(self.session, "data_xrefs", None) + if not callable(collect): + return {} + addresses: set[int] = set() + addresses.update(item.vaddr for item in strings) + addresses.update(item.vaddr for item in symbols if not item.imported) + addresses.update(item.vaddr for item in relocations) + addresses.update(item.offset for item in flags) + if not addresses: + return {} + try: + return collect(addresses) + except Exception: # noqa: BLE001 + return {} + def _call_graph( self, routines: dict[int, Routine], diff --git a/src/tocode/backends/base.py b/src/tocode/backends/base.py index 5a9bcbf..e836cbc 100644 --- 
a/src/tocode/backends/base.py +++ b/src/tocode/backends/base.py @@ -59,6 +59,8 @@ def calls_from( self, address: int, imports: dict[int, Any], functions: dict[int, Any] ) -> tuple[list[int], list[str]]: ... + def data_xrefs(self, addresses: Any) -> dict[int, list[tuple[int, bool]]]: ... + @dataclass(slots=True) class IdaProbe: diff --git a/src/tocode/backends/ida.py b/src/tocode/backends/ida.py index e390361..345bb67 100644 --- a/src/tocode/backends/ida.py +++ b/src/tocode/backends/ida.py @@ -74,6 +74,7 @@ def __init__( self._ida_fixup = self._optional_import("ida_fixup") self._ida_auto = self._optional_import("ida_auto") self._ida_nalt = self._optional_import("ida_nalt") + self._ida_xref = self._optional_import("ida_xref") self._db: Any = None if db_path is None: @@ -118,13 +119,9 @@ def __init__( self._strings_ready = False self._decompiler_ready = False - self._disasm_cache: dict[int, str] = {} - self._decompile_cache: dict[int, str] = {} - self._summary_cache: dict[int, str] = {} self._locals_cache: dict[int, list[Any]] = {} self._imports_cache: list[dict[str, Any]] | None = None self._relocs_cache: list[dict[str, Any]] | None = None - self._primed: set[int] = set() def _optional_import(self, module: str): try: @@ -145,23 +142,36 @@ def _wait_for_auto_analysis(self) -> None: def analyze(self) -> None: if self._strings_ready: return - try: - from ida_domain.strings import StringListConfig, StringType + # An already-analyzed database (e.g. a reused `.i64`) usually carries a + # populated string list, so rescanning the whole image is wasted work. + # Only rebuild when the list is empty; never skip when there is nothing + # to lose. 
+ if not self._has_strings(): + try: + from ida_domain.strings import StringListConfig, StringType - self._db.strings.rebuild( - StringListConfig( - string_types=[StringType.C, StringType.C_16], - min_len=4, - only_ascii_7bit=False, + self._db.strings.rebuild( + StringListConfig( + string_types=[StringType.C, StringType.C_16], + min_len=4, + only_ascii_7bit=False, + ) ) - ) - except Exception: # noqa: BLE001 - try: - self._db.strings.rebuild() except Exception: # noqa: BLE001 - pass + try: + self._db.strings.rebuild() + except Exception: # noqa: BLE001 + pass self._strings_ready = True + def _has_strings(self) -> bool: + try: + for _ in self._db.strings: + return True + except Exception: # noqa: BLE001 + return False + return False + def close(self) -> None: if self._db is None: return @@ -212,9 +222,6 @@ def restore_parallel_resources(self) -> None: self._open_existing_database(resolved_db) def release_render_memory(self) -> None: - self._disasm_cache.clear() - self._decompile_cache.clear() - self._summary_cache.clear() self._locals_cache.clear() if self._ida_hexrays is None: return @@ -256,11 +263,7 @@ def _open_existing_database(self, resolved_db: Path) -> None: def _clear_caches(self) -> None: self._decompiler_ready = False - self._disasm_cache.clear() - self._decompile_cache.clear() - self._summary_cache.clear() self._locals_cache.clear() - self._primed.clear() def worker(self) -> "IdaSession": if self._cache_db is not None and self._cache_db.exists(): @@ -505,8 +508,6 @@ def decompile(self, address: int) -> str: return "\n".join(lines) if isinstance(lines, list) else str(lines) def function_summary(self, address: int) -> str: - if address in self._summary_cache: - return self._summary_cache[address] func = self._need_function(address) signature = self._db.functions.get_signature( func @@ -557,6 +558,24 @@ def calls_from( imported.add(name) return sorted(edges), sorted(name for name in imported if name) + def data_xrefs(self, addresses: Any) -> dict[int, 
def data_xrefs(self, addresses) -> dict[int, list[tuple[int, bool]]]:
    """Return data references to each address as ``(from, is_write)`` pairs.

    Runs ``axtj`` once per address. Addresses without any usable reference
    are omitted from the result.

    Rows are skipped when they are not JSON objects, when ``from`` is
    missing or not an integer, or when their type marks a control-flow
    reference. Skipping calls and jumps keeps this backend in line with the
    IDA backend, which only asks for data references. A reference is a
    write when its ``perm`` contains ``w`` or its type is ``write``.
    """
    # NOTE(review): type names assumed from radare2's axtj output; confirm
    # against the radare2 versions this project supports.
    control_flow = {"call", "code", "jump", "jmp"}
    result: dict[int, list[tuple[int, bool]]] = {}
    for address in addresses:
        target = int(address)
        rows = self.cmdj(f"axtj @ 0x{target:x}") or []
        refs: list[tuple[int, bool]] = []
        for row in rows:
            if not isinstance(row, dict):
                continue
            kind = str(row.get("type", "")).lower()
            if kind in control_flow:
                continue
            try:
                source = int(row["from"])
            except (KeyError, TypeError, ValueError):
                # One malformed row must not discard every other reference.
                continue
            perm = str(row.get("perm", "")).lower()
            refs.append((source, "w" in perm or kind == "write"))
        if refs:
            result[target] = refs
    return result
+ ) context.progress.log( describe_jobs( function_count=count, @@ -871,6 +881,8 @@ def build_tree_cluster_file( else raw_resolved.with_suffix(".asm"), asm_line_start=raw_range.asm_line_start if raw_range is not None else 1, asm_line_end=raw_range.asm_line_end if raw_range is not None else 1, + arg_count=raw_range.arg_count if raw_range is not None else None, + local_count=raw_range.local_count if raw_range is not None else None, ) ) c_line = end + 2 @@ -937,6 +949,7 @@ def build_cluster_files( _summary_function(routine, summary_path, item.summary_text).rstrip() + "\n\n" ) + arg_count, local_count = _counts_from_summary(item.summary_text) ranges.append( FunctionRange( address=address, @@ -947,6 +960,8 @@ def build_cluster_files( asm_file=asm_resolved, asm_line_start=asm_start, asm_line_end=asm_end, + arg_count=arg_count, + local_count=local_count, ) ) c_line = c_end + 2 @@ -961,6 +976,35 @@ def build_cluster_files( } +def _counts_from_summary(summary_text: str) -> tuple[int | None, int | None]: + """Recover the argument and local counts the backend reported in a summary. + + The decompiler computes these while rendering, so reading them back from the + summary avoids re-deriving them (which, for IDA, would mean decompiling every + function a second time during inventory). Returns ``(None, None)`` when the + summary does not carry the fields, so callers can fall back to inventory data. 
+ """ + args: int | None = None + locals_: int | None = None + for line in summary_text.splitlines(): + key, sep, value = line.partition(":") + if not sep: + continue + label = key.strip() + if label == "args": + args = _safe_int(value) + elif label == "locals": + locals_ = _safe_int(value) + return args, locals_ + + +def _safe_int(value: str) -> int | None: + try: + return int(value.strip()) + except (TypeError, ValueError): + return None + + def render_functions( *, analyzer: BinaryAnalyzer, @@ -1450,7 +1494,7 @@ def write_header() -> None: def write_variables() -> None: context.data_variable_count = export_variables( - analysis, root, context.raw_ranges + analysis, root, entropy=context.entropy_enabled ) def write_indexes() -> None: @@ -1484,7 +1528,11 @@ def write_triage() -> None: write_json( root / "triage.json", triage_json( - analysis, context.clusters, context.raw_ranges, shared["reachable"] + analysis, + context.clusters, + context.raw_ranges, + shared["reachable"], + entropy=context.entropy_enabled, ), ) @@ -1506,13 +1554,14 @@ def write_docs() -> None: ("function index", write_indexes), ( "sections.json", - lambda: write_json(root / "sections.json", sections_json(analysis)), + lambda: write_json( + root / "sections.json", + sections_json(analysis, entropy=context.entropy_enabled), + ), ), ( "strings.json", - lambda: write_json( - root / "strings.json", strings_json(analysis, context.raw_ranges) - ), + lambda: write_json(root / "strings.json", strings_json(analysis)), ), ( "imports.json", diff --git a/src/tocode/metadata.py b/src/tocode/metadata.py index b907eae..318c19c 100644 --- a/src/tocode/metadata.py +++ b/src/tocode/metadata.py @@ -1,13 +1,14 @@ from __future__ import annotations -from collections import deque +from bisect import bisect_right +from collections import Counter, deque import math from pathlib import Path import re -from typing import Any +from typing import Callable from .naming import SHARED_CLUSTER_ID, c_file_name, 
clean_path_component -from .schema import Cluster, FunctionRange, ProgramAnalysis, Segment, StringEntry +from .schema import Cluster, FunctionRange, ProgramAnalysis, Routine, Segment def display_path(path: Path) -> str: @@ -59,7 +60,9 @@ def exports_json(analysis: ProgramAnalysis) -> dict[str, object]: } -def sections_json(analysis: ProgramAnalysis) -> dict[str, object]: +def sections_json( + analysis: ProgramAnalysis, *, entropy: bool = True +) -> dict[str, object]: return { "sections": [ { @@ -71,7 +74,7 @@ def sections_json(analysis: ProgramAnalysis) -> dict[str, object]: "type": item.kind, "permissions": item.perms, "rwx": item.readable and item.writable and item.executable, - "entropy": section_entropy(analysis, item), + "entropy": section_entropy(analysis, item, enabled=entropy), } for item in analysis.segments ] @@ -93,10 +96,8 @@ def relocations_json(analysis: ProgramAnalysis) -> dict[str, object]: } -def strings_json( - analysis: ProgramAnalysis, ranges: list[FunctionRange] -) -> dict[str, object]: - xrefs = string_xrefs(analysis.strings, ranges) +def strings_json(analysis: ProgramAnalysis) -> dict[str, object]: + locate = _function_locator(analysis) return { "strings": [ { @@ -107,7 +108,7 @@ def strings_json( "section": item.segment, "type": item.kind, "value": item.value, - "xrefs": xrefs.get(item.vaddr, []), + "xrefs": _data_xref_rows(analysis, item.vaddr, locate), } for item in analysis.strings ] @@ -139,8 +140,12 @@ def functions_json( "c_name": c_names.get(address, routine.name), "prototype": prototypes.get(address), "size": routine.size, - "nargs": routine.args_count, - "nlocals": routine.locals_count, + "nargs": raw_range.arg_count + if raw_range is not None and raw_range.arg_count is not None + else routine.args_count, + "nlocals": raw_range.local_count + if raw_range is not None and raw_range.local_count is not None + else routine.locals_count, "stackframe": routine.stack_size, "callees": [f"0x{item:x}" for item in callees], "callee_names": [ @@ 
-251,6 +256,8 @@ def triage_json( clusters: list[Cluster], ranges: list[FunctionRange], reachable_doc: dict[str, object], + *, + entropy: bool = True, ) -> dict[str, object]: reachable_rows = reachable_doc.get("reachable", []) return { @@ -258,10 +265,10 @@ def triage_json( "arch": analysis.binary.arch, "bits": analysis.binary.bits, "compiler": guess_compiler(analysis), - "packed": guess_packed(analysis), + "packed": guess_packed(analysis, entropy=entropy), "entry_clusters": entry_clusters(analysis, clusters, ranges), - "sections": triage_sections(analysis), - "rwx_sections": triage_sections(analysis, rwx_only=True), + "sections": triage_sections(analysis, entropy=entropy), + "rwx_sections": triage_sections(analysis, rwx_only=True, entropy=entropy), "export_count": len(analysis.exports), "import_count": len(analysis.imports), "strings_of_interest": interesting_strings(analysis)[:50], @@ -305,7 +312,10 @@ def entry_clusters( def export_variables( - analysis: ProgramAnalysis, root: Path, ranges: list[FunctionRange] + analysis: ProgramAnalysis, + root: Path, + *, + entropy: bool = True, ) -> int: data_dir = root / "data" sections: list[dict[str, object]] = [] @@ -319,7 +329,7 @@ def export_variables( handle.seek(segment.paddr) blob = handle.read(segment.size) (data_dir / file_name).write_bytes(blob) - if blob: + if blob and entropy: segment.entropy = round(shannon_entropy(blob), 6) sections.append( { @@ -329,10 +339,10 @@ def export_variables( "size": segment.vsize, "file_size": segment.size, "permissions": segment.perms, - "entropy": section_entropy(analysis, segment), + "entropy": section_entropy(analysis, segment, enabled=entropy), } ) - variables = variables_document(analysis, ranges) + variables = variables_document(analysis) write_json( data_dir / "variables.json", {"sections": sections, "variables": variables} ) @@ -344,7 +354,7 @@ def export_variables( def variables_document( - analysis: ProgramAnalysis, ranges: list[FunctionRange] + analysis: ProgramAnalysis, 
) -> dict[str, dict[str, object]]: variables: dict[str, dict[str, object]] = {} seen: set[int] = set() @@ -446,106 +456,80 @@ def unique(base: str, address: int) -> str: } seen.add(flag.offset) - add_variable_xrefs(variables, ranges) + add_variable_xrefs(variables, analysis) return dict(sorted(variables.items())) -def source_lines(ranges: list[FunctionRange]) -> list[dict[str, Any]]: - cache: dict[Path, list[str]] = {} - rows: list[dict[str, Any]] = [] - for item in ranges: - if item.c_file not in cache: - try: - cache[item.c_file] = item.c_file.read_text( - encoding="utf-8", errors="replace" - ).splitlines() - except OSError: - cache[item.c_file] = [] - for line_number in range(item.c_line_start, item.c_line_end + 1): - index = line_number - 1 - if 0 <= index < len(cache[item.c_file]): - rows.append( - { - "address": item.address, - "function": item.name, - "path": str(item.c_file), - "line": line_number, - "text": cache[item.c_file][index], - } - ) - return rows +def _function_locator( + analysis: ProgramAnalysis, +) -> Callable[[int], Routine | None]: + """Build a fast "which function contains this address" lookup.""" + items = sorted( + (routine.address, routine.address + max(routine.size, 1), routine) + for routine in analysis.routines.values() + ) + starts = [item[0] for item in items] + + def locate(address: int) -> Routine | None: + index = bisect_right(starts, address) - 1 + if 0 <= index < len(items): + start, end, routine = items[index] + if start <= address < end: + return routine + return None + return locate -def string_xrefs( - strings: list[StringEntry], ranges: list[FunctionRange] -) -> dict[int, list[dict[str, object]]]: - lines = source_lines(ranges) - result: dict[int, list[dict[str, object]]] = {} - for item in strings: - needles = string_needles(item.value) - found: list[dict[str, object]] = [] - seen: set[tuple[int, int, str]] = set() - for line in lines: - if not any(needle in str(line["text"]) for needle in needles): - continue - key = 
(int(line["address"]), int(line["line"]), str(line["path"])) - if key in seen: - continue - seen.add(key) - found.append( - { - "function": line["function"], - "address": f"0x{int(line['address']):x}", - "cluster": display_path(Path(str(line["path"]))), - "line": line["line"], - } - ) - result[item.vaddr] = found - return result + +def _data_xref_rows( + analysis: ProgramAnalysis, + address: int, + locate: Callable[[int], Routine | None], +) -> list[dict[str, object]]: + """Cross-references to a data address, taken from the decompiler's xref data. + + Resolves each referencing instruction to its containing function and reports + read/write access, deduplicated per (function, access). + """ + rows: list[dict[str, object]] = [] + seen: set[tuple[int, bool]] = set() + for ref_ea, is_write in analysis.data_xrefs.get(address, ()): + routine = locate(ref_ea) + if routine is None: + continue + key = (routine.address, is_write) + if key in seen: + continue + seen.add(key) + rows.append( + { + "function": routine.name, + "address": f"0x{routine.address:x}", + "access": "write" if is_write else "read", + } + ) + rows.sort(key=lambda row: (str(row["address"]), str(row["access"]))) + return rows def add_variable_xrefs( - variables: dict[str, dict[str, object]], ranges: list[FunctionRange] + variables: dict[str, dict[str, object]], analysis: ProgramAnalysis ) -> None: - lines = source_lines(ranges) - for name, variable in variables.items(): - needles = {name} - value = variable.get("value") - if isinstance(value, str) and len(value) >= 4: - needles.update(string_needles(value)) - flag = variable.get("flag") - if isinstance(flag, str): - needles.add(flag) - xrefs: list[dict[str, object]] = [] - seen: set[tuple[int, bool]] = set() - for line in lines: - text = str(line["text"]) - if not any(needle and needle in text for needle in needles): - continue - key = (int(line["address"]), looks_written(text)) - if key in seen: - continue - seen.add(key) - xrefs.append( - { - "function": 
def guess_packed(analysis: ProgramAnalysis, *, entropy: bool = True) -> bool:
    """Heuristically flag a binary as packed.

    A binary is reported as packed when it has at most 5 imports and at
    least one executable segment whose entropy is 7.2 or higher. With
    ``entropy=False`` no entropy is available, so the answer is always
    ``False``.
    """
    if not entropy:
        return False
    # Check the cheap condition first so segment entropy is only computed
    # when it can still change the answer.
    if len(analysis.imports) > 5:
        return False
    return any(
        (section_entropy(analysis, segment, enabled=entropy) or 0.0) >= 7.2
        for segment in analysis.segments
        if segment.executable
    )


def shannon_entropy(data: bytes) -> float:
    """Return the Shannon entropy of *data* in bits per byte.

    The result is ``0.0`` for empty input and for input made of a single
    repeated byte value.
    """
    if not data:
        return 0.0
    # Counter counts at C speed, far faster than a per-byte Python loop on the
    # large segments found in kernels and other big binaries.
    counts = Counter(data)
    total = len(data)
    entropy = -sum(
        (count / total) * math.log2(count / total) for count in counts.values()
    )
    # Negating a zero sum yields -0.0, which would be serialised as "-0.0".
    return max(0.0, entropy)
+ memory_ceiling = ( + _ida_memory_ceiling(available_memory_mb, ida_worker_memory_mb, database_size_mb) + if is_ida + else None + ) + + # An explicit `--jobs N` is honored, but still capped by the memory budget for + # IDA: each worker loads the whole database, so N workers that cannot fit in + # RAM get OOM-killed mid-export, which is strictly worse than running fewer. if requested is not None: - return max(1, min(requested, function_count or 1, limit)) + chosen = max(1, min(requested, function_count or 1, limit)) + if memory_ceiling is not None: + chosen = min(chosen, memory_ceiling) + return max(1, chosen) if function_count < MIN_FUNCTIONS_FOR_AUTO or analysis_seconds is None: return 1 @@ -44,12 +57,8 @@ def choose_jobs( cpus = cpu_count if cpu_count is not None else (os.cpu_count() or 1) backend_limit = MAX_AUTO_IDA_JOBS if is_ida else MAX_AUTO_JOBS ceiling = min(cpus, backend_limit, limit, function_count) - if is_ida: - memory_ceiling = _ida_memory_ceiling( - available_memory_mb, ida_worker_memory_mb, database_size_mb - ) - if memory_ceiling is not None: - ceiling = min(ceiling, memory_ceiling) + if memory_ceiling is not None: + ceiling = min(ceiling, memory_ceiling) target = math.ceil(function_count / FUNCTIONS_PER_WORKER) return max(1, min(ceiling, target)) @@ -67,8 +76,8 @@ def _ida_memory_ceiling( else configured_ida_worker_memory_mb() ) if database_size_mb is not None and database_size_mb > 0: - estimated = IDA_WORKER_BASE_MEMORY_MB + int( - database_size_mb * IDA_DB_RESIDENT_FACTOR + estimated = configured_ida_worker_base_mb() + int( + database_size_mb * configured_ida_db_resident_factor() ) worker_memory_mb = max(worker_memory_mb, estimated) return max(1, available_memory_mb // worker_memory_mb) @@ -83,6 +92,8 @@ def describe_jobs( backend: str, ) -> str: if requested is not None: + if selected < requested: + return f"Workers: {selected} (requested {requested}, capped for memory)" return f"Workers: {selected} requested" if function_count < 
def configured_ida_worker_base_mb() -> int:
    """Return the per-worker base memory estimate in MB.

    Reads ``TOCODE_IDA_WORKER_BASE_MEMORY_MB``. An unset or non-integer
    value yields ``DEFAULT_IDA_WORKER_BASE_MEMORY_MB``; negative values are
    clamped to 0.
    """
    raw = os.environ.get(
        "TOCODE_IDA_WORKER_BASE_MEMORY_MB", str(DEFAULT_IDA_WORKER_BASE_MEMORY_MB)
    ).strip()
    try:
        value = int(raw)
    except ValueError:
        return DEFAULT_IDA_WORKER_BASE_MEMORY_MB
    return max(0, value)


def configured_ida_db_resident_factor() -> float:
    """Return the multiplier applied to the database size per worker.

    Reads ``TOCODE_IDA_DB_RESIDENT_FACTOR``. An unset, unparsable or
    non-finite value yields ``DEFAULT_IDA_DB_RESIDENT_FACTOR``; negative
    values are clamped to 0.0.
    """
    raw = os.environ.get(
        "TOCODE_IDA_DB_RESIDENT_FACTOR", str(DEFAULT_IDA_DB_RESIDENT_FACTOR)
    ).strip()
    try:
        value = float(raw)
    except ValueError:
        return DEFAULT_IDA_DB_RESIDENT_FACTOR
    # float() accepts "inf" and "nan"; an infinite factor cannot be turned
    # into an integer memory estimate, so treat both as invalid input.
    if not math.isfinite(value):
        return DEFAULT_IDA_DB_RESIDENT_FACTOR
    return max(0.0, value)
+ data_xrefs: dict[int, list[tuple[int, bool]]] = field(default_factory=dict) _app_cache: list[Routine] | None = field(default=None, init=False, repr=False) def segment_at(self, address: int) -> Segment | None: diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index c136486..2d4191f 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -93,7 +93,9 @@ def test_choose_jobs_auto_ida_budget_allows_parallel_for_small_database() -> Non ) -def test_requested_jobs_are_not_limited_by_available_memory() -> None: +def test_requested_ida_jobs_are_capped_by_available_memory() -> None: + # Each IDA worker loads the whole database; 3 requested workers cannot fit in + # ~3.5 GB at 4 GB/worker, so the count is capped to avoid OOM-killed workers. assert ( choose_jobs( function_count=300, @@ -105,6 +107,50 @@ def test_requested_jobs_are_not_limited_by_available_memory() -> None: available_memory_mb=3500, ida_worker_memory_mb=4096, ) + == 1 + ) + + +def test_ida_memory_model_is_env_tunable(monkeypatch) -> None: + def select() -> int: + return choose_jobs( + function_count=18000, + analysis_seconds=20.0, + requested=8, + backend="ida", + cpu_count=8, + job_limit=16, + available_memory_mb=6800, + database_size_mb=1900, + ) + + for name in ( + "TOCODE_IDA_DB_RESIDENT_FACTOR", + "TOCODE_IDA_WORKER_BASE_MEMORY_MB", + "TOCODE_IDA_WORKER_MEMORY_MB", + ): + monkeypatch.delenv(name, raising=False) + # Default model caps a 1.9 GB database on ~6.8 GB to a single worker. + assert select() == 1 + # Operators can relax the model without code changes. 
+ monkeypatch.setenv("TOCODE_IDA_DB_RESIDENT_FACTOR", "0") + monkeypatch.setenv("TOCODE_IDA_WORKER_BASE_MEMORY_MB", "0") + monkeypatch.setenv("TOCODE_IDA_WORKER_MEMORY_MB", "1024") + assert select() == 6 # 6800 // 1024 + + +def test_requested_jobs_ignore_memory_for_non_ida_backends() -> None: + assert ( + choose_jobs( + function_count=300, + analysis_seconds=0.2, + requested=3, + backend="r2", + cpu_count=32, + job_limit=64, + available_memory_mb=3500, + ida_worker_memory_mb=4096, + ) == 3 ) diff --git a/tests/test_cli.py b/tests/test_cli.py index de8fe38..7191313 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,3 +27,10 @@ def test_parser_accepts_short_quiet_flag() -> None: args = build_parser().parse_args(["-q", "sample.bin"]) assert args.quiet is True + + +def test_parser_uses_entropy_as_opt_in_flag() -> None: + parser = build_parser() + + assert parser.parse_args(["sample.bin"]).entropy is False + assert parser.parse_args(["--entropy", "sample.bin"]).entropy is True diff --git a/tests/test_exporter.py b/tests/test_exporter.py index 298065e..4a0fba1 100644 --- a/tests/test_exporter.py +++ b/tests/test_exporter.py @@ -2,18 +2,22 @@ import json from pathlib import Path +from typing import Any, cast from tocode.exporter import ( + _counts_from_summary, _worker_spec, export_binary, fallback_prototype, render_one, tree_safe_function, ) +from tocode.metadata import functions_json from tocode.naming import NameBook from tocode.progress import Progress from tocode.schema import ( BinaryFacts, + FunctionRange, ProgramAnalysis, Routine, Segment, @@ -460,6 +464,168 @@ def test_default_output_uses_invoked_binary_path(tmp_path: Path) -> None: assert summary.root_dir == (invoked_root / "sample_decompiler").resolve() +def test_counts_from_summary_reads_args_and_locals() -> None: + summary = ( + "signature: int f(int a, int b)\n" + "address: 0x1000\n" + "size: 32 bytes\n" + "args: 2\n" + "locals: 5\n" + ) + assert _counts_from_summary(summary) == (2, 5) + + +def 
test_counts_from_summary_missing_fields_returns_none() -> None: + assert _counts_from_summary("radare-style summary with no counts") == (None, None) + + +def _single_routine_analysis() -> ProgramAnalysis: + binary = BinaryFacts( + path=Path("/bin/sample"), + arch="x86", + bits=64, + image_base=0, + os_name="linux", + format_name="elf", + file_type="elf", + entrypoints=[], + ) + routine = Routine( + address=0x1000, + name="sub_1000", + size=32, + signature=None, + calltype=None, + noreturn=False, + stack_size=0, + locals_count=0, + args_count=0, + outdegree=0, + indegree=0, + ) + return ProgramAnalysis( + binary=binary, + segments=[], + routines={0x1000: routine}, + imports={}, + exports=[], + symbols=[], + relocations=[], + strings=[], + flags=[], + callees={0x1000: []}, + callers={0x1000: []}, + import_calls={0x1000: []}, + roots=[0x1000], + thunks=set(), + ) + + +def test_strings_json_xrefs_come_from_backend_data() -> None: + from tocode.metadata import strings_json + + analysis = _single_routine_analysis() # routine 0x1000 covers 0x1000..0x1020 + analysis.strings.append( + StringEntry(0x4000, 0x4000, 5, 5, ".rodata", "ascii", "hello") + ) + analysis.data_xrefs[0x4000] = [(0x1008, False), (0x1010, True)] + + rows = cast(list[dict[str, Any]], strings_json(analysis)["strings"]) + + assert rows[0]["xrefs"] == [ + {"function": "sub_1000", "address": "0x1000", "access": "read"}, + {"function": "sub_1000", "address": "0x1000", "access": "write"}, + ] + + +def test_add_variable_xrefs_uses_backend_data() -> None: + from tocode.metadata import add_variable_xrefs + + analysis = _single_routine_analysis() + analysis.data_xrefs[0x4000] = [(0x1004, False)] + variables: dict[str, dict[str, object]] = {"obj_4000": {"va": "0x4000"}} + + add_variable_xrefs(variables, analysis) + + assert variables["obj_4000"]["xrefs"] == [ + {"function": "sub_1000", "address": "0x1000", "access": "read"} + ] + + +def test_data_xref_to_address_outside_any_function_is_dropped() -> None: + from 
tocode.metadata import add_variable_xrefs + + analysis = _single_routine_analysis() + analysis.data_xrefs[0x4000] = [(0x9999, False)] # not inside any routine + variables: dict[str, dict[str, object]] = {"obj_4000": {"va": "0x4000"}} + + add_variable_xrefs(variables, analysis) + + assert variables["obj_4000"]["xrefs"] == [] + + +def test_sections_json_omits_entropy_unless_enabled() -> None: + from tocode.metadata import sections_json + + analysis = _single_routine_analysis() + analysis.segments.append( + Segment(".text", 16, 16, "PROGBITS", "r-x", 0, 0x1000, entropy=5.0) + ) + + off = cast(list[dict[str, Any]], sections_json(analysis, entropy=False)["sections"]) + on = cast(list[dict[str, Any]], sections_json(analysis, entropy=True)["sections"]) + + assert off[0]["entropy"] is None + assert on[0]["entropy"] == 5.0 + + +def test_functions_json_prefers_render_time_counts() -> None: + analysis = _single_routine_analysis() + ranges = [ + FunctionRange( + address=0x1000, + name="sub_1000", + c_file=Path("a.c"), + c_line_start=1, + c_line_end=2, + asm_file=Path("a.asm"), + asm_line_start=1, + asm_line_end=2, + arg_count=3, + local_count=7, + ) + ] + rows = cast( + list[dict[str, Any]], functions_json(analysis, ranges, {}, {})["functions"] + ) + assert rows[0]["nargs"] == 3 + assert rows[0]["nlocals"] == 7 + + +def test_functions_json_falls_back_to_inventory_counts() -> None: + analysis = _single_routine_analysis() + analysis.routines[0x1000].args_count = 4 + analysis.routines[0x1000].locals_count = 1 + # Range without recovered counts (e.g. radare2 backend summary). 
+ ranges = [ + FunctionRange( + address=0x1000, + name="sub_1000", + c_file=Path("a.c"), + c_line_start=1, + c_line_end=2, + asm_file=Path("a.asm"), + asm_line_start=1, + asm_line_end=2, + ) + ] + rows = cast( + list[dict[str, Any]], functions_json(analysis, ranges, {}, {})["functions"] + ) + assert rows[0]["nargs"] == 4 + assert rows[0]["nlocals"] == 1 + + def test_tree_safe_function_preserves_scanner_calls() -> None: source = ( "__int64 __fastcall sub_1000@(char *a1)\n"