diff --git a/.github/workflows/binsync-compat-tests.yml b/.github/workflows/binsync-compat-tests.yml new file mode 100644 index 00000000..991c2df3 --- /dev/null +++ b/.github/workflows/binsync-compat-tests.yml @@ -0,0 +1,36 @@ +# This workflow tests declib compatibility with binsync +# It ensures that changes to declib don't break binsync functionality + +name: BinSync Compatibility Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + binsync-tests: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: "3.10" + - name: Set branch name + run: echo "BRANCH_NAME=${GITHUB_HEAD_REF}" >> $GITHUB_ENV + - name: Install declib from current branch + run: | + python -m pip install --upgrade pip + pip install . + - name: Install binsync and run its core tests + run: | + # Clone binsync and try to checkout the same branch if it exists + git clone https://github.com/binsync/binsync.git /tmp/binsync + cd /tmp/binsync + git checkout $BRANCH_NAME || true + pip install pytest .[extras] + # Run binsync core tests + pytest ./tests/test_client.py ./tests/test_state.py -v diff --git a/.github/workflows/core-tests.yml b/.github/workflows/core-tests.yml new file mode 100644 index 00000000..8778cbd6 --- /dev/null +++ b/.github/workflows/core-tests.yml @@ -0,0 +1,34 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Core Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: "${{ matrix.python-version }}" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest . ./examples/change_watcher_plugin/ + + - name: Pytest + run: | + pytest ./tests/test_artifacts.py ./tests/test_cli.py -v diff --git a/.github/workflows/dec-tests.yml b/.github/workflows/dec-tests.yml new file mode 100644 index 00000000..b0bd2ef4 --- /dev/null +++ b/.github/workflows/dec-tests.yml @@ -0,0 +1,58 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Decompiler Tests +env: + BN_SERIAL: ${{ secrets.BN_SERIAL }} + BN_LICENSE: ${{ secrets.BN_LICENSE }} + TOOLING_KEY: ${{ secrets.TOOLING_KEY }} + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Download BS Artifact & Install IDA + run: | + (git clone https://github.com/binsync/bs-artifacts.git /tmp/bs-artifacts && \ + cd /tmp/bs-artifacts && \ + ./helpers/setup_ida_ci.sh) + # taken from https://github.com/mandiant/capa/blob/master/.github/workflows/tests.yml#L107-L147 + - name: Install Binary Ninja + if: ${{ env.BN_SERIAL != 0 }} + run: | + mkdir ./.github/binja + curl "https://raw.githubusercontent.com/Vector35/binaryninja-api/6812c97/scripts/download_headless.py" -o ./.github/binja/download_headless.py + python ./.github/binja/download_headless.py --serial ${{ env.BN_SERIAL }} --output .github/binja/BinaryNinja-headless.zip + unzip .github/binja/BinaryNinja-headless.zip -d .github/binja/ + python .github/binja/binaryninja/scripts/install_api.py --install-on-root --silent + - name: Set up Java 21 + uses: actions/setup-java@v4 + with: + distribution: "oracle" + java-version: "21" + - name: Install Ghidra + uses: antoniovazquezblanco/setup-ghidra@v2.0.12 + with: + version: "12.0" + auth_token: ${{ secrets.GITHUB_TOKEN }} + - name: Pytest + run: | + # these two test must be run in separate python environments, due to the way ghidra bridge works + # you also must run these tests in the exact order shown here + TEST_BINARIES_DIR=/tmp/bs-artifacts/binaries pytest tests/test_decompilers.py tests/test_client_server.py -sv \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..5d918205 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,40 @@ +name: Release + +on: + push: + tags: + - "v**" + +jobs: + + release-github: + name: Create Github Release + permissions: write-all + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Create Release + uses: ncipollo/release-action@v1 + with: + generateReleaseNotes: true + + release-pypi: + name: Release pypi package + runs-on: ubuntu-latest + steps: + - name: Checkout source + uses: actions/checkout@v2 + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + - name: Install build + run: pip install build + - name: Build dists + run: python -m build + - name: Release to PyPI + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..e5301349 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,103 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +DecLib is a unified decompiler API that provides an abstracted interface for working with multiple decompilers (IDA Pro, Binary Ninja, Ghidra, angr-management). It enables writing plugins and scripts that work across all supported decompilers with minimal changes. + +## Development Commands + +### Installation and Setup +```bash +# Install declib in development mode +pip install -e . + +# Install declib plugins to decompilers (required after pip install) +declib --install + +# Install to specific decompiler +declib --single-decompiler-install ida /path/to/ida +``` + +### Testing +```bash +# Run tests with pytest +pytest tests/ + +# Run specific test files +pytest tests/test_artifacts.py +pytest tests/test_decompilers.py +pytest tests/test_cli.py +pytest tests/test_remote_ghidra.py +``` + +### Project Management +```bash +# Build the package +python -m build + +# Install test dependencies +pip install -e ".[test]" + +# Install ghidra dependencies +pip install -e ".[ghidra]" +``` + +## Architecture + +### Core Components + +**DecompilerInterface** (`declib/api/decompiler_interface.py`): The main abstraction layer that provides unified access to different decompilers. Can operate in GUI mode (default) or headless mode. + +**ArtifactLifter** (`declib/api/artifact_lifter.py`): Handles conversion between decompiler-specific objects and DecLib artifacts. + +**Artifacts** (`declib/artifacts/`): Unified data structures representing decompiler concepts: +- `Function`, `FunctionHeader`, `FunctionArgument` +- `StackVariable`, `GlobalVariable` +- `Struct`, `StructMember`, `Enum`, `Typedef` +- `Comment`, `Patch`, `Context`, `Decompilation` + +**Decompiler Implementations** (`declib/decompilers/`): +- `ida/`: IDA Pro integration +- `binja/`: Binary Ninja integration +- `ghidra/`: Ghidra integration (with bridge support) +- `angr/`: angr-management integration + +### Plugin System + +**Decompiler Stubs** (`declib/decompiler_stubs/`): Plugin entry points for each decompiler that bootstrap DecLib functionality. + +**Plugin Installer** (`declib/plugin_installer.py`): Automatically installs DecLib plugins to detected decompiler installations. + +### Key Design Patterns + +**Artifact Dictionary Access**: Artifacts use a lazy-loading pattern where `.items()`, `.keys()`, `.values()` return "light" objects, but `dict[key]` returns full objects (which may trigger decompilation). + +**Decompiler Discovery**: Use `DecompilerInterface.discover()` to auto-detect the current decompiler environment. + +**Serialization**: All artifacts support JSON/TOML serialization via `.dumps()` and `.loads()` methods. + +## Development Guidelines + +### Adding New Decompiler Support +1. Create new directory in `declib/decompilers/` +2. Implement `interface.py` inheriting from `DecompilerInterface` +3. Implement `artifact_lifter.py` inheriting from `ArtifactLifter` +4. Add decompiler stub in `declib/decompiler_stubs/` +5. Update `SUPPORTED_DECOMPILERS` constant + +### Working with Artifacts +- Always use `deci.functions[addr]` to get full Function objects +- Use `for addr, light_func in deci.functions.items()` for iteration +- Test serialization with both JSON and TOML formats +- Validate artifacts work across all supported decompilers + +### Testing Strategy +- Unit tests for artifacts (`test_artifacts.py`) +- Integration tests for decompiler interfaces (`test_decompilers.py`) +- CLI testing (`test_cli.py`) +- Remote Ghidra functionality (`test_remote_ghidra.py`) + +### Environment Variables +- `GHIDRA_HEADLESS_PATH`: Path to Ghidra headless binary for headless mode \ No newline at end of file diff --git a/README.md b/README.md index eced956b..c7ac2049 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,105 @@ -# libbs (deprecated) +# DecLib +The decompiler API that works everywhere! -> **`libbs` has been renamed to [`declib`](https://github.com/binsync/declib).** -> -> This package is now a deprecation shim. Installing it will pull in `declib` -> and emit a `DeprecationWarning` on import. No further updates will be made -> to `libbs`. - -## Migrate +DecLib is an abstracted decompiler API that enables you to write plugins/scripts that work, with minimal edit, +in every decompiler supported by DecLib. DecLib was originally designed to work with [BinSync](https://binsync.net), and is the backbone +for all BinSync based plugins. +As an example, with the same script, you can [redefine the types of function variables with custom structs](./examples/struct_and_variable_use.py), all in less +than 30 lines, in any supported decompilers. +## Install ```bash -pip uninstall libbs pip install declib ``` -Replace `libbs` with `declib` in your imports: +The minimum Python version is **3.10**. + +## Supported Decompilers +- IDA Pro: **>= 8.4** (if you have an older version, use `v1.26.0`) +- Binary Ninja: **>= 2.4** +- angr-management: **>= 9.0** +- Ghidra: **>= 12.0** (started in PyGhidra mode) + +## Usage +DecLib exposes all decompiler API through the abstract class `DecompilerInterface`. The `DecompilerInterface` +can be used in either the default mode, which assumes a GUI, or `headless` mode. In `headless` mode, the interface will +start a new process using a specified decompiler. + +You can find various examples using DecLib in the [examples](./examples) folder. Examples that are plugins show off +more of the complicated API that allows you to use an abstracted UI, artifacts, and more. + +If you want a simplified command line interface (especially well-suited for LLMs), see the +[`decompiler` CLI guide](./docs/decompiler_cli.md). + +### UI Mode (default) +To use the same script everywhere, use the convenience function `DecompilerInterface.discover_interface()`, which will +auto find the correct interface. Copy the below code into any supported decompiler and it should run without edit. ```python -# Before -from libbs.api import DecompilerInterface +from declib.api import DecompilerInterface + +deci = DecompilerInterface.discover() +for addr in deci.functions: + function = deci.functions[addr] + if function.header.type == "void": + function.header.type = "int" + deci.functions[function.addr] = function +``` -# After +Note that for Ghidra in UI mode you must first start it in PyGhidra mode. You can do this by going to your install dir +and running `./support/pyghidraRun`. + +### Headless Mode +To use headless mode you must specify a decompiler to use. You can get the traditional interface using the following: + +```python from declib.api import DecompilerInterface + +deci = DecompilerInterface.discover(force_decompiler="ghidra", headless=True) ``` -All sources, docs, examples, and tests now live in the `declib` repository. +In the case of Ghidra, you must have the environment variable `GHIDRA_INSTALL_DIR` set to the path of the Ghidra +installation (the place the `ghidraRun` script is located). + +### Artifact Access Caveats +In designing the dictionaries that contain all Artifacts in a decompiler, we had a clash between ease-of-use and speed. +When accessing some artifacts like a `Function`, we must decompile the function. Decompiling is slow. Due to this issue +we slightly changed how these dictionaries work to fast accessing. + +The only way to access a **full** artifact is to use the `getitem` interface of a dictionary. In practice this +looks like the following: +```python +for func_addr, light_func in deci.functions.items(): + full_function = deci.function[func_addr] +``` + +Notice, when using the `items` function the function is `light`, meaning it does not contain stack vars and other +info. This also means using `keys`, `values`, or `list` on an artifact dictionary will have the same affect. + +### Serializing Artifacts +All artifacts are serializable to the TOML and JSON formats. Serialization is done like so: +```python +from declib.artifacts import Function +import json + +my_func = Function(name="my_func", addr=0x4000, size=0x10) +json_str = my_func.dumps(fmt="json") +loaded_dict = json.loads(json_str) # now loadable through normal JSON parsing +loaded_func = Function.loads(json_str, fmt="json") +``` + +## Sponsors +BinSync and its associated projects would not be possible without sponsorship. +In no particular order, we'd like to thank all the organizations that have previously or are currently sponsoring +one of the many BinSync projects. + +

+ NSF +
+ DARPA +
+ ARPA-H +
+ RevEng AI +

+ diff --git a/declib/__init__.py b/declib/__init__.py new file mode 100644 index 00000000..0810d6f6 --- /dev/null +++ b/declib/__init__.py @@ -0,0 +1,9 @@ +__version__ = "4.0.0" + + +import logging +logging.getLogger("declib").addHandler(logging.NullHandler()) +from declib.logger import Loggers +loggers = Loggers() +del Loggers +del logging diff --git a/declib/__main__.py b/declib/__main__.py new file mode 100644 index 00000000..040d8452 --- /dev/null +++ b/declib/__main__.py @@ -0,0 +1,190 @@ +import argparse +import sys +import logging + +from declib.plugin_installer import DecLibPluginInstaller + +_l = logging.getLogger(__name__) + + +def install(): + DecLibPluginInstaller().install() + + +def start_server( + socket_path=None, decompiler=None, binary_path=None, headless=False, + server_id=None, project_dir=None, +): + """Start the DecompilerServer (AF_UNIX socket-based)""" + try: + from declib.api.decompiler_server import DecompilerServer + from declib.api.decompiler_interface import DecompilerInterface + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Prepare interface kwargs + interface_kwargs = {} + if decompiler: + interface_kwargs['force_decompiler'] = decompiler + if binary_path: + interface_kwargs['binary_path'] = binary_path + if headless: + interface_kwargs['headless'] = headless + if project_dir: + interface_kwargs['project_dir'] = project_dir + + # Create and start server + if socket_path: + _l.info(f"Starting AF_UNIX DecompilerServer on {socket_path}") + else: + _l.info("Starting AF_UNIX DecompilerServer with auto-generated socket path") + if interface_kwargs: + _l.info(f"Interface options: {interface_kwargs}") + + with DecompilerServer(socket_path=socket_path, server_id=server_id, **interface_kwargs) as server: + _l.info("Server started successfully. Press Ctrl+C to stop.") + _l.info("Connect with: DecompilerClient.discover('unix://{}')".format(server.socket_path)) + try: + server.wait_for_shutdown() + except KeyboardInterrupt: + _l.info("Shutting down server...") + + except ImportError as e: + _l.error(f"Failed to import required modules: {e}") + sys.exit(1) + except Exception as e: + _l.error(f"Failed to start server: {e}") + sys.exit(1) + + +def test_client(server_url=None): + """Test the DecompilerClient connection""" + try: + from declib.api.decompiler_client import DecompilerClient + + # Configure logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + if server_url: + _l.info(f"Testing connection to DecompilerServer at {server_url}") + else: + _l.info("Testing connection to auto-discovered DecompilerServer") + + with DecompilerClient.discover(server_url=server_url) as client: + _l.info(f"Successfully connected to {client.name} decompiler") + _l.info(f"Binary path: {client.binary_path}") + _l.info(f"Binary hash: {client.binary_hash}") + _l.info(f"Decompiler available: {client.decompiler_available}") + + # Test fast artifact collections (benchmark performance) + import time + start_time = time.time() + functions = list(client.functions.items()) + end_time = time.time() + _l.info(f"Retrieved {len(functions)} functions in {end_time - start_time:.3f}s") + + start_time = time.time() + comments = list(client.comments.keys()) + end_time = time.time() + _l.info(f"Retrieved {len(comments)} comment keys in {end_time - start_time:.3f}s") + + _l.info("✅ Client test completed successfully!") + + except ImportError as e: + _l.error(f"Failed to import required modules: {e}") + sys.exit(1) + except Exception as e: + _l.error(f"Client test failed: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description=""" + The DecLib Command Line Util. This is the script interface to DecLib that allows you to install and run + the Ghidra UI for running plugins, and start the DecompilerServer. + """, + epilog=""" + Examples: + declib --install | + declib --server --socket-path /tmp/my_server.sock | + declib --server --decompiler ghidra --binary-path /path/to/binary --headless + """ + ) + parser.add_argument( + "--install", action="store_true", help=""" + Install all the DecLib plugins to every decompiler. + """ + ) + parser.add_argument( + "--single-decompiler-install", nargs=2, metavar=('decompiler', 'path'), help="Install DAILA into a single decompiler. Decompiler must be one of: ida, ghidra, binja, angr." + ) + parser.add_argument( + "--server", action="store_true", help=""" + Start the DecompilerServer to expose DecompilerInterface APIs over AF_UNIX sockets. + """ + ) + parser.add_argument( + "--server-url", help=""" + URL of the DecompilerServer to connect to (e.g., unix:///tmp/server.sock). + If not specified, will auto-discover running servers. + """ + ) + parser.add_argument( + "--socket-path", help=""" + Path for the AF_UNIX socket (default: auto-generated in temp directory). + """ + ) + parser.add_argument( + "--decompiler", choices=["ida", "ghidra", "binja", "angr"], help=""" + Force a specific decompiler for the server. If not specified, auto-detection will be used. + """ + ) + parser.add_argument( + "--binary-path", help=""" + Path to the binary file to analyze (required for headless mode). + """ + ) + parser.add_argument( + "--headless", action="store_true", help=""" + Run the decompiler in headless mode (no GUI). Requires --binary-path. + """ + ) + parser.add_argument( + "--server-id", help=""" + Explicit server ID to use; if omitted, a unique one is generated. + """ + ) + parser.add_argument( + "--project-dir", help=""" + Directory where the backend should store its project/database files + (Ghidra project, IDA .id*, etc.). If omitted, backend defaults apply + (Ghidra creates a project next to the binary; IDA writes .id* next + to the binary). + """ + ) + args = parser.parse_args() + + if args.single_decompiler_install: + decompiler, path = args.single_decompiler_install + DecLibPluginInstaller().install(interactive=False, paths_by_target={decompiler: path}) + elif args.install: + install() + elif args.server: + if args.headless and not args.binary_path: + parser.error("--headless requires --binary-path") + start_server( + socket_path=args.socket_path, + decompiler=args.decompiler, + binary_path=args.binary_path, + headless=args.headless, + server_id=args.server_id, + project_dir=args.project_dir, + ) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/declib/api/__init__.py b/declib/api/__init__.py new file mode 100644 index 00000000..4f676740 --- /dev/null +++ b/declib/api/__init__.py @@ -0,0 +1,13 @@ +from .artifact_lifter import ArtifactLifter +from .decompiler_interface import DecompilerInterface +from .type_parser import CTypeParser, CType + +from .decompiler_interface import ( + DecompilerInterface +) +from .artifact_lifter import ( + ArtifactLifter +) +from .type_parser import ( + CTypeParser, CType +) diff --git a/declib/api/artifact_dict.py b/declib/api/artifact_dict.py new file mode 100644 index 00000000..af98d5dd --- /dev/null +++ b/declib/api/artifact_dict.py @@ -0,0 +1,153 @@ +import typing +import logging + +from declib.artifacts import ( + Artifact, Comment, Enum, FunctionHeader, Function, FunctionArgument, + GlobalVariable, Patch, Segment, StackVariable, Struct, StructMember, Typedef +) + +if typing.TYPE_CHECKING: + from declib.api import DecompilerInterface + +_l = logging.getLogger(__name__) + + +class ArtifactDict(dict): + """ + The ArtifactDict is a Dictionary wrapper around the getting/setting/listing of artifacts in the decompiler. This + allows for a more pythonic interface to the decompiler artifacts. For example, instead of doing: + deci._set_function(func) + + You can do: + >>> deci.functions[func.addr] = func + + This class is not meant to be instantiated directly, but rather through the DecompilerInterface class. + There is currently some interesting affects and caveats to using this class: + - When you list artifacts, by calling list(), you will get a light copy of the artifacts. This means that if you + modify the artifacts in the list, they will not be reflected in the decompiler. You also do need get current + data in the decompiler, only an acknowledgement that the artifact exists. + - You must reassign the artifact to the dictionary to update the decompiler. + - When assigning something to the dictionary, it must always be in its lifted form. You will also only get lifted + artifacts back from the dictionary. + - For convience, you can access functions by their lowered address + """ + + def __init__(self, artifact_cls, deci: "DecompilerInterface", error_on_duplicate=False, scopable=False): + super().__init__() + + self._deci = deci + self._error_on_duplicate = error_on_duplicate + self._scopable = scopable + self._art_function = { + # ArtifactType: (setter, getter, lister) + Function: (self._deci._set_function, self._deci._get_function, self._deci._functions, self._deci._del_function), + StackVariable: (self._deci._set_stack_variable, self._deci._get_stack_variable, self._deci._stack_variables, self._deci._del_stack_variable), + GlobalVariable: (self._deci._set_global_variable, self._deci._get_global_var, self._deci._global_vars, self._deci._del_global_var), + Struct: (self._deci._set_struct, self._deci._get_struct, self._deci._structs, self._deci._del_struct), + Enum: (self._deci._set_enum, self._deci._get_enum, self._deci._enums, self._deci._del_enum), + Typedef: (self._deci._set_typedef, self._deci._get_typedef, self._deci._typedefs, self._deci._del_typedef), + Comment: (self._deci._set_comment, self._deci._get_comment, self._deci._comments, self._deci._del_comment), + Patch: (self._deci._set_patch, self._deci._get_patch, self._deci._patches, self._deci._del_patch), + Segment: (self._deci._set_segment, self._deci._get_segment, self._deci._segments, self._deci._del_segment) + } + + functions = self._art_function.get(artifact_cls, None) + if functions is None: + raise ValueError(f"Attempting to create a dict for a Artifact class that is not supported: {artifact_cls}") + + self._artifact_class = artifact_cls + self._artifact_setter, self._artifact_getter, self._artifact_lister, self._artifact_remover = functions + + def __len__(self): + return len(self._artifact_lister()) + + def _lifted_art_lister(self): + d = self._artifact_lister() + d_items = list(d.items()) + if not d_items: + return {} + + is_addr = hasattr(d_items[0][1], "addr") + new_d = {} + for k, v in d_items: + if is_addr: + k = self._deci.art_lifter.lift_addr(k) + new_d[k] = self._deci.art_lifter.lift(v) + + return new_d + + def __getitem__(self, item): + """ + Takes a lifted identifier as input and returns a lifted artifact + """ + if isinstance(item, int): + item = self._deci.art_lifter.lower_addr(item) + if self._scopable and not self._deci.supports_type_scopes: + item, _ = self._deci.art_lifter.parse_scoped_type(item) + + art = self._artifact_getter(item) + if art is None: + raise KeyError + + return self._deci.art_lifter.lift(art) + + def __setitem__(self, key, value): + """ + Both key and value must be lifted artifacts + """ + if not isinstance(value, self._artifact_class): + raise ValueError(f"Attempting to set a value of type {type(value)} to a dict of type {self._artifact_class}") + + if isinstance(key, int): + key = self._deci.art_lifter.lower_addr(key) + if self._scopable and not self._deci.supports_type_scopes: + key, _ = self._deci.art_lifter.parse_scoped_type(key) + + art = self._deci.art_lifter.lower(value) + if not self._artifact_setter(art) and self._error_on_duplicate: + raise ValueError(f"Set value {value} is already present at key {key}") + + def __contains__(self, item): + if isinstance(item, int): + item = self._deci.art_lifter.lower_addr(item) + if self._scopable and not self._deci.supports_type_scopes: + item, _ = self._deci.art_lifter.parse_scoped_type(item) + + data = self._artifact_getter(item) + return data is not None + + def __delitem__(self, key): + if isinstance(key, int): + key = self._deci.art_lifter.lower_addr(key) + if self._scopable and not self._deci.supports_type_scopes: + key, _ = self._deci.art_lifter.parse_scoped_type(key) + + art = self._artifact_getter(key) + if isinstance(art, Struct): + self._artifact_remover(key) + self._deci.struct_changed(art, deleted=True) + else: + self._artifact_remover(key) + + def __iter__(self): + return iter(self._lifted_art_lister()) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self._artifact_class.__name__} len={self.__len__()}>" + + def __str__(self): + return f"{self._lifted_art_lister()}" + + def keys(self): + return self._lifted_art_lister().keys() + + def values(self): + return self._lifted_art_lister().values() + + def items(self): + return self._lifted_art_lister().items() + + def get(self, key, default=None): + if key in self: + return self[key] + return default diff --git a/declib/api/artifact_lifter.py b/declib/api/artifact_lifter.py new file mode 100644 index 00000000..98e1316c --- /dev/null +++ b/declib/api/artifact_lifter.py @@ -0,0 +1,161 @@ +import logging +import typing + +from declib.artifacts import StackVariable, Artifact, FunctionArgument, StructMember, Typedef, Enum, Struct +from declib.api.type_parser import CTypeParser + +if typing.TYPE_CHECKING: + from declib.api import DecompilerInterface + +_l = logging.getLogger(name=__name__) + + +class ArtifactLifter: + SCOPE_DELIMITER = "::" + + def __init__(self, deci: "DecompilerInterface", types=None): + self.deci = deci + self.type_parser = CTypeParser(extra_types=types) + + # + # Public API + # + + def lift(self, artifact: Artifact): + return self._lift_or_lower_artifact(artifact, "lift") + + def lower(self, artifact: Artifact): + return self._lift_or_lower_artifact(artifact, "lower") + + # + # Special handlers for scopes + # + + @staticmethod + def parse_scoped_type(type_str: str) -> tuple[str, str | None]: + """ + Parses a scoped type string into its base type and scope. + Note: the scope can be None if the type is not scoped. + + Examples: + 'stdint::uint32_t' -> ('uint32_t', 'stdint') + 'uint32_t' -> ('uint32_t', None) + """ + if not type_str: + return "", None + + # check if the type is scoped + scope = None + deli = ArtifactLifter.SCOPE_DELIMITER + if deli in type_str: + scope_parts = type_str.split(deli) + base_type = scope_parts[-1] + scope = deli.join(scope_parts[:-1]) + else: + base_type = type_str + + return base_type, scope + + @staticmethod + def scoped_type_to_str(name: str, scope: str | None = None) -> str: + """ + Converts a name and scope into a scoped type string. + Note: the scope can be None if the type is not scoped. + """ + return name if not scope else f"{scope}::{name}" + + # + # Override Mandatory Funcs + # + + def lift_type(self, type_str: str) -> str: + return type_str + + def lift_addr(self, addr: int) -> int: + base_addr = self.deci.binary_base_addr + if addr < base_addr: + self.deci.warning(f"Lifting an address that appears already lifted: {addr}...") + + return addr - base_addr + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + pass + + def lower_type(self, type_str: str) -> str: + if self.SCOPE_DELIMITER in type_str and not self.deci.supports_type_scopes: + type_str, scope = self.scoped_type_to_str(type_str) + + return type_str + + def lower_addr(self, addr: int) -> int: + base_addr = self.deci.binary_base_addr + if addr >= base_addr != 0: + self.deci.warning(f"Lowering an address that appears already lowered: {addr}...") + + return addr + base_addr + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + pass + + # + # Private + # + + def _lift_or_lower_artifact(self, artifact, mode): + target_attrs = ("name", "type", "offset", "addr", "func_addr", "line_map") + if mode not in ("lower", "lift"): + return None + + if not isinstance(artifact, Artifact): + return artifact + lifted_art = artifact.copy() + # correct simple properties in the artifact + for attr in target_attrs: + if hasattr(lifted_art, attr): + curr_val = getattr(lifted_art, attr) + if curr_val is None: + continue + + # special handling for stack variables + if attr == "offset": + if not isinstance(artifact, StackVariable): + continue + lifting_func = getattr(self, f"{mode}_stack_offset") + setattr(lifted_art, attr, lifting_func(curr_val, lifted_art.addr)) + # special handling for decompilation + elif attr == "line_map": + lifted_line_map = {} + lift_or_lower_func = self.lift_addr if mode == "lift" else self.lower_addr + for k, v in curr_val.items(): + lifted_line_map[k] = {lift_or_lower_func(_v) for _v in v} + + setattr(lifted_art, attr, lifted_line_map) + # special handling for types that have names + elif attr == "name": + if not isinstance(artifact, (Typedef, Enum, Struct)): + continue + lifted_type = self.lift_type(curr_val) if mode == "lift" else self.lower_type(curr_val) + setattr(lifted_art, attr, lifted_type) + else: + attr_func_name = attr if attr != "func_addr" else "addr" + lifting_func = getattr(self, f"{mode}_{attr_func_name}") + setattr(lifted_art, attr, lifting_func(curr_val)) + + # recursively correct nested artifacts + for attr in lifted_art.__slots__: + attr_val = getattr(lifted_art, attr) + if not attr_val: + continue + + # nested function headers + if attr == "header": + setattr(lifted_art, attr, self._lift_or_lower_artifact(attr_val, mode)) + # nested args, stack_vars, or struct_members + elif isinstance(attr_val, dict): + nested_arts = {} + for k, v in attr_val.items(): + nested_art = self._lift_or_lower_artifact(v, mode) + nested_arts[nested_art.offset if isinstance(nested_art, (StackVariable, FunctionArgument, StructMember)) else k] = nested_art + setattr(lifted_art, attr, nested_arts) + + return lifted_art diff --git a/declib/api/decompiler_client.py b/declib/api/decompiler_client.py new file mode 100644 index 00000000..6dd0b8a7 --- /dev/null +++ b/declib/api/decompiler_client.py @@ -0,0 +1,1219 @@ +# Note to reader: most of this code was generated by Claude 4.5. It may contain errors and was designed +# in tandem with decompiler_server.py and the tests/test_client_server.py file. This comment will be +# removed when the majority of the file is owned by a human author. + +import logging +import socket +import time +import os +import glob +import tempfile +from typing import Dict, Any, Optional, List, Union, Callable +from collections import defaultdict +import threading + +from declib.artifacts import ( + Artifact, Function, Comment, Patch, GlobalVariable, Segment, + Struct, Enum, Typedef, Context, Decompilation +) +from declib.artifacts.formatting import ArtifactFormat +from declib.api.decompiler_server import SocketProtocol +from declib.api.type_parser import CTypeParser +from declib.configuration import DecLibConfig +from declib.api import server_registry + +_l = logging.getLogger(__name__) + +# Must match decompiler_server._WIRE_FMT; JSON avoids the `toml` package's +# buggy handling of raw `\x` escapes inside decompilation text. +_WIRE_FMT = ArtifactFormat.JSON + + +class ArtLifterProxy: + """ + A proxy for the ArtifactLifter that delegates all operations to the remote server. + This maintains API compatibility with the local ArtifactLifter while sending + requests to the remote decompiler server. + """ + SCOPE_DELIMITER = "::" + + def __init__(self, client: 'DecompilerClient'): + self.client = client + + def lift(self, artifact: Artifact): + """Lift an artifact using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift", + "args": [artifact] + }) + + def lower(self, artifact: Artifact): + """Lower an artifact using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower", + "args": [artifact] + }) + + def lift_addr(self, addr: int) -> int: + """Lift an address using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_addr", + "args": [addr] + }) + + def lower_addr(self, addr: int) -> int: + """Lower an address using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_addr", + "args": [addr] + }) + + def lift_type(self, type_str: str) -> str: + """Lift a type string using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_type", + "args": [type_str] + }) + + def lower_type(self, type_str: str) -> str: + """Lower a type string using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_type", + "args": [type_str] + }) + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + """Lift a stack offset using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lift_stack_offset", + "args": [offset, func_addr] + }) + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + """Lower a stack offset using the remote decompiler""" + return self.client._send_request({ + "type": "method_call", + "method_name": "art_lifter.lower_stack_offset", + "args": [offset, func_addr] + }) + + @staticmethod + def parse_scoped_type(type_str: str) -> tuple[str, str | None]: + """ + Parse a scoped type string into its base type and scope. + This is a static method that doesn't need remote decompiler access. + """ + if not type_str: + return "", None + + # check if the type is scoped + scope = None + deli = ArtLifterProxy.SCOPE_DELIMITER + if deli in type_str: + scope_parts = type_str.split(deli) + base_type = scope_parts[-1] + scope = deli.join(scope_parts[:-1]) + else: + base_type = type_str + + return base_type, scope + + @staticmethod + def scoped_type_to_str(name: str, scope: str | None = None) -> str: + """ + Convert a name and scope into a scoped type string. + This is a static method that doesn't need remote decompiler access. + """ + return name if not scope else f"{scope}::{name}" + + +class FastClientArtifactDict(dict): + """ + A fast client-side proxy for ArtifactDict that communicates with DecompilerServer via AF_UNIX sockets. + + This class mimics the behavior of ArtifactDict but uses sockets for bulk operations + and maintains the same performance characteristics as the local version by using + the _lifted_art_lister pattern. + """ + + def __init__(self, collection_name: str, artifact_class, client: 'DecompilerClient'): + super().__init__() + self.collection_name = collection_name + self.artifact_class = artifact_class + self.client = client + self._light_cache = {} + self._light_cache_timestamp = 0 + self._cache_ttl = 10.0 # Cache for 10 seconds + + def _get_light_artifacts(self) -> Dict: + """Get all light artifacts using the server's fast bulk operation""" + current_time = time.time() + if current_time - self._light_cache_timestamp > self._cache_ttl: + # Cache expired, fetch from server using bulk endpoint + try: + _l.debug(f"Fetching light artifacts for {self.collection_name}") + request = { + "type": "get_light_artifacts", + "collection_name": self.collection_name + } + serialized_artifacts = self.client._send_request(request) + + # Reconstruct artifacts from serialized format + reconstructed_artifacts = {} + for addr, artifact_info in serialized_artifacts.items(): + try: + # Import the artifact class dynamically + module_name = artifact_info['module'] + class_name = artifact_info['type'] + serialized_data = artifact_info['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact using its loads method + artifact = artifact_class.loads(serialized_data, fmt=_WIRE_FMT) + reconstructed_artifacts[addr] = artifact + + except Exception as e: + _l.warning(f"Failed to reconstruct artifact at 0x{addr:x}: {e}") + # Skip problematic artifacts rather than failing completely + continue + + self._light_cache = reconstructed_artifacts + self._light_cache_timestamp = current_time + except Exception as e: + _l.warning(f"Failed to fetch light artifacts for {self.collection_name}: {e}") + + return self._light_cache + + def _invalidate_cache(self): + """Invalidate the light artifact cache""" + self._light_cache.clear() + self._light_cache_timestamp = 0 + + def __len__(self): + """Return the number of items in the collection""" + light_items = self._get_light_artifacts() + return len(light_items) + + def __iter__(self): + """Iterate over keys in the collection""" + light_items = self._get_light_artifacts() + return iter(light_items.keys()) + + def keys(self): + """Return an iterator over the keys (fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.keys() + + def values(self): + """Return an iterator over the values (light artifacts, fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.values() + + def items(self): + """Return an iterator over (key, value) pairs (fast bulk operation)""" + light_items = self._get_light_artifacts() + return light_items.items() + + def __getitem__(self, key): + """Get a full artifact by key (same behavior as local ArtifactDict)""" + # First, check if the key exists by looking at light artifacts + light_items = self._get_light_artifacts() + if key not in light_items: + raise KeyError(f"Key {key} not found in {self.collection_name}") + + # Key exists, get the full artifact from server + try: + request = { + "type": "get_full_artifact", + "collection_name": self.collection_name, + "key": key + } + return self.client._send_request(request) + except Exception as e: + if "not found" in str(e).lower(): + raise KeyError(f"Key {key} not found in {self.collection_name}") + else: + raise + + def get_light(self, key): + """Get a light artifact by key (fast, cached access)""" + light_items = self._get_light_artifacts() + if key not in light_items: + raise KeyError(f"Key {key} not found in {self.collection_name}") + return light_items[key] + + def get_full(self, key): + """Explicitly get a full artifact (with expensive operations like decompilation)""" + try: + request = { + "type": "get_full_artifact", + "collection_name": self.collection_name, + "key": key + } + return self.client._send_request(request) + except Exception as e: + if "not found" in str(e).lower(): + raise KeyError(f"Key {key} not found in {self.collection_name}") + else: + raise + + def __setitem__(self, key, value): + """Set an artifact by key on the server""" + if not isinstance(value, self.artifact_class): + raise ValueError(f"Expected {self.artifact_class.__name__}, got {type(value).__name__}") + + # Use the direct decompiler interface for setting artifacts + success = self.client.set_artifact(value) + + # Invalidate cache since we modified the collection + self._invalidate_cache() + + if not success: + raise RuntimeError(f"Failed to set artifact") + + def __delitem__(self, key): + """Delete an artifact by key (not implemented in decompiler interfaces)""" + raise NotImplementedError("Deletion not supported by DecompilerInterface") + + def __contains__(self, key): + """Check if a key exists in the collection""" + light_items = self._get_light_artifacts() + return key in light_items + + def get(self, key, default=None): + """Get a full artifact with a default value""" + try: + return self[key] # Use __getitem__ which returns full artifact + except KeyError: + return default + + +class DecompilerClient: + """ + A client that connects to DecompilerServer via AF_UNIX sockets and provides the same interface as DecompilerInterface. + + This class acts as a transparent proxy to a remote DecompilerInterface, allowing users to + write code that works identically whether using a local or remote decompiler. + """ + + def __init__(self, + socket_path: str, + timeout: float = 30.0): + """ + Initialize the DecompilerClient. + + Args: + socket_path: Path to the AF_UNIX socket + timeout: Connection timeout in seconds + """ + self.socket_path = socket_path + self.timeout = timeout + + # Connection state + self._socket = None + self._connected = False + self._server_info = None + self._socket_lock = threading.Lock() + + # Try to connect + self._connect() + + # Initialize fast artifact collections + self.functions = FastClientArtifactDict("functions", Function, self) + self.comments = FastClientArtifactDict("comments", Comment, self) + self.patches = FastClientArtifactDict("patches", Patch, self) + self.global_vars = FastClientArtifactDict("global_vars", GlobalVariable, self) + self.segments = FastClientArtifactDict("segments", Segment, self) + self.structs = FastClientArtifactDict("structs", Struct, self) + self.enums = FastClientArtifactDict("enums", Enum, self) + self.typedefs = FastClientArtifactDict("typedefs", Typedef, self) + + # Initialize callback attributes to match DecompilerInterface + self.artifact_change_callbacks = defaultdict(list) + self.decompiler_closed_callbacks = [] + self.decompiler_opened_callbacks = [] + self.undo_event_callbacks = [] + self._thread_artifact_callbacks = True + + # Create a proxy art_lifter that delegates to server + # art_lifter is typically used for address lifting operations + self.art_lifter = ArtLifterProxy(self) + + # Additional public attributes to match DecompilerInterface + self.type_parser = CTypeParser() # Local type parser + self.artifact_write_lock = threading.Lock() # Thread safety lock + self.config = DecLibConfig.update_or_make() # Configuration object + self.gui_plugin = None # GUI plugin reference + self.artifact_watchers_started = False # Watcher state + + # Event listener state for receiving callbacks from server + self._event_listener_running = False + self._subscribed_to_events = False + self._event_listener_thread = None + self._event_socket = None + self._event_socket_lock = threading.Lock() + + # These attributes will be fetched from server on first access + self._supports_undo = None + self._supports_type_scopes = None + self._qt_version = None + self._default_func_prefix = None + self._headless = None + self._force_click_recording = None + self._track_mouse_moves = None + + _l.info(f"DecompilerClient connected to {socket_path}") + + def _create_and_connect_socket(self) -> socket.socket: + """Create and connect a socket handling both AF_UNIX and AF_INET fallbacks.""" + if hasattr(socket, "AF_UNIX"): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(self.socket_path) + else: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + with open(self.socket_path, 'r') as f: + port = int(f.read().strip()) + sock.connect(('127.0.0.1', port)) + return sock + + def _connect(self): + """Establish connection to the server""" + try: + _l.debug(f"Attempting to connect to server at {self.socket_path}") + + self._socket = self._create_and_connect_socket() + + _l.debug("Socket connection established") + + # Test the connection by getting server info first + self._server_info = self._send_request({"type": "server_info"}) + _l.debug(f"Got server info: {self._server_info}") + + self._connected = True + + _l.info(f"Connected to {self._server_info.get('name', 'DecompilerServer')} " + f"using {self._server_info.get('decompiler', 'unknown')} decompiler") + except Exception as e: + _l.error(f"Failed to connect to DecompilerServer at {self.socket_path}: {e}") + + # Provide helpful error messages for common issues + if "No such file or directory" in str(e): + raise ConnectionError(f"Cannot connect to DecompilerServer at {self.socket_path}. " + f"Make sure the server is running with: declib --server") + elif "Connection refused" in str(e): + raise ConnectionError(f"Cannot connect to DecompilerServer at {self.socket_path}. " + f"Make sure the server is running.") + else: + raise ConnectionError(f"Cannot connect to DecompilerServer: {e}") + + def _send_request(self, request: Dict[str, Any]) -> Any: + """Send a request to the server and return the response""" + with self._socket_lock: + try: + SocketProtocol.send_message(self._socket, request) + response = SocketProtocol.recv_message(self._socket) + + # Check if response is an error + if isinstance(response, dict) and "error" in response: + error_type = response.get("type", "Exception") + error_msg = response.get("error", "Unknown error") + + # Try to reconstruct the original exception type + if error_type == "KeyError": + raise KeyError(error_msg) + elif error_type == "ValueError": + raise ValueError(error_msg) + elif error_type == "AttributeError": + raise AttributeError(error_msg) + else: + raise RuntimeError(f"{error_type}: {error_msg}") + + # Check if response is a serialized artifact + if isinstance(response, dict) and response.get("is_artifact"): + try: + # Reconstruct the artifact + module_name = response['module'] + class_name = response['type'] + serialized_data = response['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact using its loads method + artifact = artifact_class.loads(serialized_data, fmt=_WIRE_FMT) + return artifact + + except Exception as e: + _l.warning(f"Failed to reconstruct artifact response: {e}") + # Fall back to returning the raw response + return response + + return response + except Exception as e: + _l.error(f"Request failed: {e} for {request}") + raise + + # Properties - mirror DecompilerInterface properties + @property + def name(self) -> str: + """Name of the decompiler""" + return self._server_info.get('decompiler', 'remote') + + @property + def binary_base_addr(self) -> int: + """Base address of the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_base_addr"}) + + @property + def binary_hash(self) -> str: + """Hash of the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_hash"}) + + @property + def binary_path(self) -> Optional[str]: + """Path to the binary""" + return self._send_request({"type": "property_get", "property_name": "binary_path"}) + + @property + def decompiler_available(self) -> bool: + """Whether decompiler is available""" + return self._send_request({"type": "property_get", "property_name": "decompiler_available"}) + + @property + def default_pointer_size(self) -> int: + """Default pointer size""" + return self._send_request({"type": "property_get", "property_name": "default_pointer_size"}) + + # GUI API methods - delegate to remote decompiler + def gui_active_context(self) -> Optional[Context]: + """Get the active context from the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_active_context"}) + + def gui_goto(self, func_addr) -> None: + """Go to an address in the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_goto", "args": [func_addr]}) + + def gui_show_type(self, type_name: str) -> None: + """Show a type in the GUI""" + return self._send_request({"type": "method_call", "method_name": "gui_show_type", "args": [type_name]}) + + def gui_ask_for_string(self, question: str, title: str = "Plugin Question", default: str = "") -> str: + """Ask for a string input""" + return self._send_request({"type": "method_call", "method_name": "gui_ask_for_string", "args": [question, title, default]}) + + def gui_ask_for_choice(self, question: str, choices: list, title: str = "Plugin Question") -> str: + """Ask for a choice from a list""" + return self._send_request({"type": "method_call", "method_name": "gui_ask_for_choice", "args": [question, choices, title]}) + + def gui_popup_text(self, text: str, title: str = "Plugin Message") -> bool: + """Show a popup message""" + return self._send_request({"type": "method_call", "method_name": "gui_popup_text", "args": [text, title]}) + + # Core decompiler API methods - delegate to remote decompiler + def fast_get_function(self, func_addr) -> Optional[Function]: + """Get a light version of a function""" + return self._send_request({"type": "method_call", "method_name": "fast_get_function", "args": [func_addr]}) + + def get_func_size(self, func_addr) -> int: + """Get the size of a function""" + return self._send_request({"type": "method_call", "method_name": "get_func_size", "args": [func_addr]}) + + def decompile(self, addr: int, map_lines=False, **kwargs) -> Optional[Decompilation]: + """Decompile a function""" + return self._send_request({"type": "method_call", "method_name": "decompile", "args": [addr], "kwargs": {"map_lines": map_lines, **kwargs}}) + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + """Get cross-references to an artifact""" + return self._send_request({"type": "method_call", "method_name": "xrefs_to", "args": [artifact], "kwargs": {"decompile": decompile, "only_code": only_code}}) + + def xrefs_to_addr(self, addr: int, only_code: bool = False) -> List[Artifact]: + """Get references to a raw address (e.g. a string constant)""" + return self._send_request({"type": "method_call", "method_name": "xrefs_to_addr", "args": [addr], "kwargs": {"only_code": only_code}}) + + def xrefs_from(self, func_addr: int) -> List[Function]: + """Get the callees of a function (what the function calls).""" + return self._send_request({"type": "method_call", "method_name": "xrefs_from", "args": [func_addr]}) + + def get_callers(self, target) -> List[Function]: + """Get callers of a function (by target Function, address, or symbol name)""" + return self._send_request({"type": "method_call", "method_name": "get_callers", "args": [target]}) + + def list_strings(self, filter: Optional[str] = None) -> List: + """List strings in the binary with an optional regex filter""" + return self._send_request({"type": "method_call", "method_name": "list_strings", "kwargs": {"filter": filter}}) + + def disassemble(self, addr: int, **kwargs) -> Optional[str]: + """Disassemble a function""" + return self._send_request({"type": "method_call", "method_name": "disassemble", "args": [addr], "kwargs": kwargs}) + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + """Read raw bytes from the loaded program.""" + return self._send_request({"type": "method_call", "method_name": "read_memory", "args": [addr, size]}) + + def get_callgraph(self, only_names=False): + """Get the call graph""" + return self._send_request({"type": "method_call", "method_name": "get_callgraph", "kwargs": {"only_names": only_names}}) + + def get_dependencies(self, artifact: Artifact, decompile=True, max_resolves=50, **kwargs) -> List[Artifact]: + """Get dependencies for an artifact""" + return self._send_request({"type": "method_call", "method_name": "get_dependencies", "args": [artifact], "kwargs": {"decompile": decompile, "max_resolves": max_resolves, **kwargs}}) + + def get_func_containing(self, addr: int) -> Optional[Function]: + """Get the function containing an address""" + return self._send_request({"type": "method_call", "method_name": "get_func_containing", "args": [addr]}) + + def get_decompilation_object(self, function: Function, **kwargs): + """Get the decompilation object for a function""" + return self._send_request({"type": "method_call", "method_name": "get_decompilation_object", "args": [function], "kwargs": kwargs}) + + def set_artifact(self, artifact: Artifact, lower=True, **kwargs) -> bool: + """Set an artifact in the decompiler""" + return self._send_request({"type": "method_call", "method_name": "set_artifact", "args": [artifact], "kwargs": {"lower": lower, **kwargs}}) + + def get_defined_type(self, type_str: str): + """Get a defined type by string""" + return self._send_request({"type": "method_call", "method_name": "get_defined_type", "args": [type_str]}) + + # Optional API methods - delegate to remote decompiler + def undo(self) -> None: + """Undo the last operation""" + return self._send_request({"type": "method_call", "method_name": "undo"}) + + def local_variable_names(self, func: Function) -> List[str]: + """Get local variable names for a function""" + return self._send_request({"type": "method_call", "method_name": "local_variable_names", "args": [func]}) + + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + """Rename local variables by name map""" + return self._send_request({"type": "method_call", "method_name": "rename_local_variables_by_names", "args": [func, name_map], "kwargs": kwargs}) + + # Logging methods - delegate to remote decompiler + def print(self, msg: str, **kwargs) -> None: + """Print a message""" + return self._send_request({"type": "method_call", "method_name": "print", "args": [msg], "kwargs": kwargs}) + + def info(self, msg: str, **kwargs) -> None: + """Log an info message""" + return self._send_request({"type": "method_call", "method_name": "info", "args": [msg], "kwargs": kwargs}) + + def debug(self, msg: str, **kwargs) -> None: + """Log a debug message""" + return self._send_request({"type": "method_call", "method_name": "debug", "args": [msg], "kwargs": kwargs}) + + def warning(self, msg: str, **kwargs) -> None: + """Log a warning message""" + return self._send_request({"type": "method_call", "method_name": "warning", "args": [msg], "kwargs": kwargs}) + + def error(self, msg: str, **kwargs) -> None: + """Log an error message""" + return self._send_request({"type": "method_call", "method_name": "error", "args": [msg], "kwargs": kwargs}) + + def _start_event_listener(self) -> None: + """Start the event listener thread to receive callbacks from server""" + if self._event_listener_running: + _l.debug("Event listener already running") + return + + _l.debug("Starting event listener") + + # Create a separate socket connection for receiving events + try: + self._event_socket = self._create_and_connect_socket() + + # Send subscription request to server + SocketProtocol.send_message(self._event_socket, {"type": "subscribe_events"}) + response = SocketProtocol.recv_message(self._event_socket) + + if response.get("status") == "subscribed": + self._subscribed_to_events = True + _l.debug("Successfully subscribed to events") + + # Start event listener thread + self._event_listener_running = True + self._event_listener_thread = threading.Thread( + target=self._event_listener_loop, + daemon=True + ) + self._event_listener_thread.start() + _l.info("Event listener started") + else: + _l.error(f"Failed to subscribe to events: {response}") + self._event_socket.close() + self._event_socket = None + + except Exception as e: + _l.error(f"Failed to start event listener: {e}") + if self._event_socket: + self._event_socket.close() + self._event_socket = None + + def _stop_event_listener(self) -> None: + """Stop the event listener thread""" + if not self._event_listener_running: + _l.debug("Event listener not running") + return + + _l.debug("Stopping event listener") + self._event_listener_running = False + + # Send unsubscribe request + if self._event_socket and self._subscribed_to_events: + try: + SocketProtocol.send_message(self._event_socket, {"type": "unsubscribe_events"}) + except: + pass + + # Close event socket + if self._event_socket: + try: + self._event_socket.close() + except: + pass + self._event_socket = None + + # Wait for thread to finish + if self._event_listener_thread and self._event_listener_thread.is_alive(): + self._event_listener_thread.join(timeout=2.0) + + self._subscribed_to_events = False + _l.info("Event listener stopped") + + def _event_listener_loop(self) -> None: + """Event listener thread loop that receives events from server""" + _l.debug("Event listener loop started") + + try: + while self._event_listener_running: + try: + # Set a timeout so we can periodically check if we should stop + self._event_socket.settimeout(1.0) + event = SocketProtocol.recv_message(self._event_socket) + + # Process the event + self._process_event(event) + + except socket.timeout: + # Normal timeout, continue loop + continue + except ConnectionError as e: + _l.warning(f"Event listener connection error: {e}") + break + except Exception as e: + _l.error(f"Error in event listener loop: {e}") + break + + except Exception as e: + _l.error(f"Fatal error in event listener loop: {e}") + finally: + _l.debug("Event listener loop ended") + self._event_listener_running = False + + def _process_event(self, event: Dict[str, Any]) -> None: + """Process an event received from the server""" + try: + event_type = event.get("event_type") + artifact_data = event.get("artifact") + + if not event_type or not artifact_data: + _l.warning(f"Invalid event received: {event}") + return + + # Reconstruct the artifact from serialized data + if isinstance(artifact_data, dict) and artifact_data.get("is_artifact"): + module_name = artifact_data['module'] + class_name = artifact_data['type'] + serialized_data = artifact_data['data'] + + # Import the module and get the class + module = __import__(module_name, fromlist=[class_name]) + artifact_class = getattr(module, class_name) + + # Reconstruct the artifact + artifact = artifact_class.loads(serialized_data, fmt=_WIRE_FMT) + + # Extract additional kwargs + kwargs = event.get("kwargs", {}) + + # Dispatch to appropriate handler based on event type + if event_type == "comment_changed": + self.comment_changed(artifact, **kwargs) + elif event_type == "function_header_changed": + self.function_header_changed(artifact, **kwargs) + elif event_type == "stack_variable_changed": + self.stack_variable_changed(artifact, **kwargs) + elif event_type == "struct_changed": + self.struct_changed(artifact, **kwargs) + elif event_type == "enum_changed": + self.enum_changed(artifact, **kwargs) + elif event_type == "typedef_changed": + self.typedef_changed(artifact, **kwargs) + elif event_type == "global_variable_changed": + self.global_variable_changed(artifact, **kwargs) + else: + _l.warning(f"Unknown event type: {event_type}") + + except Exception as e: + _l.error(f"Error processing event: {e}") + + # Lifecycle methods + def shutdown(self) -> None: + """Disconnect this client from the server. + + This only tears down the *local* client; the server (and its loaded + decompiler project) keeps running so other clients can still connect. + To actually stop the server, use ``shutdown_server()`` or the + ``decompiler stop`` CLI. + """ + _l.info("DecompilerClient shutting down") + + # Stop event listener first + if self._event_listener_running: + self._stop_event_listener() + + if self._socket: + try: + self._socket.close() + except Exception: + pass + self._connected = False + _l.info("DecompilerClient shut down complete") + + def shutdown_server(self) -> None: + """Ask the server to tear down its decompiler interface, then disconnect. + + Used by CLI commands like ``decompiler stop``. Regular usage should + prefer :meth:`shutdown`, which leaves the server running. + """ + if self._socket: + try: + self._send_request({"type": "shutdown_deci"}) + except Exception: + pass + self.shutdown() + + def is_connected(self) -> bool: + """Check if connected to the server""" + return self._connected and self._socket + + def reconnect(self) -> None: + """Reconnect to the server""" + if self._socket: + self._socket.close() + self._connect() + + def ping(self) -> bool: + """Ping the server to check connectivity""" + try: + self._send_request({"type": "server_info"}) + return True + except Exception: + return False + + # Context manager support + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + @staticmethod + def discover_from_registry( + server_id: Optional[str] = None, + binary_path: Optional[str] = None, + binary_hash: Optional[str] = None, + backend: Optional[str] = None, + **kwargs, + ) -> 'DecompilerClient': + """ + Find a running server via the shared registry and connect to it. + + Filters narrow the pool in the order: server_id, binary_path, binary_hash, backend. + If no server matches, a ConnectionError is raised. + """ + record = server_registry.find_server( + server_id=server_id, + binary_path=binary_path, + binary_hash=binary_hash, + backend=backend, + ) + if not record: + filters = { + "server_id": server_id, + "binary_path": binary_path, + "binary_hash": binary_hash, + "backend": backend, + } + active = {k: v for k, v in filters.items() if v} + raise ConnectionError( + f"No matching DecompilerServer in registry. Filters: {active or 'none'}." + ) + return DecompilerClient(socket_path=record["socket_path"], **kwargs) + + # Static methods for compatibility + @staticmethod + def discover(server_url: str = None, binary_hash: str = None, **kwargs) -> 'DecompilerClient': + """ + Discover and connect to a DecompilerServer. + + This method provides a similar interface to DecompilerInterface.discover() + but connects to a remote server instead. It intelligently handles: + - Stale socket files from previous server instances + - Multiple running servers + - Binary hash matching to connect to the correct server + + Args: + server_url: URL of the server (e.g., "unix:///tmp/declib_server_abc123/decompiler.sock") + binary_hash: Optional binary hash to match against server's binary_hash + **kwargs: Additional arguments for DecompilerClient constructor + + Returns: + Connected DecompilerClient instance + + Raises: + ConnectionError: If no suitable server is found or connection fails + """ + if server_url: + # Parse server URL + if "://" in server_url: + protocol, path = server_url.split("://", 1) + if protocol != "unix": + _l.warning(f"Expected unix:// protocol, got {protocol}://") + socket_path = path + else: + # Assume it's a direct path + socket_path = server_url + + # If binary_hash is provided, validate it matches + if binary_hash: + try: + client = DecompilerClient(socket_path=socket_path, **kwargs) + server_hash = client.binary_hash + if server_hash != binary_hash: + client.shutdown() + raise ConnectionError( + f"Server at {socket_path} has binary_hash={server_hash}, " + f"but expected {binary_hash}" + ) + return client + except Exception as e: + raise ConnectionError(f"Failed to connect to server at {socket_path}: {e}") + else: + return DecompilerClient(socket_path=socket_path, **kwargs) + else: + # Auto-discovery: find all socket files and try to connect to each + temp_dir = tempfile.gettempdir() + pattern = os.path.join(temp_dir, "declib_server_*/decompiler.sock") + matches = glob.glob(pattern) + + if not matches: + raise ConnectionError("No DecompilerServer found. Start one with: declib --server") + + # Sort by modification time (newest first) to prefer recently started servers + matches.sort(key=lambda p: os.path.getmtime(p), reverse=True) + + _l.debug(f"Found {len(matches)} potential server socket(s)") + + # Try each socket, filtering by binary_hash if provided + successful_connections = [] + for socket_path in matches: + try: + _l.debug(f"Attempting connection to {socket_path}") + test_client = DecompilerClient(socket_path=socket_path, **kwargs) + + # Successfully connected, now check binary_hash if needed + if binary_hash: + try: + server_hash = test_client.binary_hash + if server_hash == binary_hash: + _l.info(f"Auto-discovered server at {socket_path} with matching binary_hash") + return test_client + else: + _l.debug(f"Server at {socket_path} has binary_hash={server_hash}, skipping") + test_client.shutdown() + except Exception as e: + _l.debug(f"Failed to get binary_hash from {socket_path}: {e}") + test_client.shutdown() + else: + # No binary_hash filter, use the first working server + _l.info(f"Auto-discovered server at {socket_path}") + return test_client + + except ConnectionError as e: + # This socket is defunct (server stopped), skip it + _l.debug(f"Failed to connect to {socket_path}: {e}") + continue + except Exception as e: + _l.debug(f"Unexpected error connecting to {socket_path}: {e}") + continue + + # No suitable server found + if binary_hash: + raise ConnectionError( + f"No DecompilerServer found with binary_hash={binary_hash}. " + f"Found {len(matches)} socket(s) but none matched." + ) + else: + raise ConnectionError( + f"No working DecompilerServer found. Found {len(matches)} socket(s) " + f"but all connections failed. Start a new server with: declib --server" + ) + + # Properties that fetch values from server on first access + @property + def supports_undo(self) -> bool: + """Check if the decompiler supports undo operations""" + if self._supports_undo is None: + self._supports_undo = self._send_request({"type": "property_get", "property_name": "supports_undo"}) + return self._supports_undo + + @property + def supports_type_scopes(self) -> bool: + """Check if the decompiler supports type scopes""" + if self._supports_type_scopes is None: + self._supports_type_scopes = self._send_request({"type": "property_get", "property_name": "supports_type_scopes"}) + return self._supports_type_scopes + + @property + def qt_version(self) -> str: + """Get the Qt version used by the decompiler""" + if self._qt_version is None: + self._qt_version = self._send_request({"type": "property_get", "property_name": "qt_version"}) + return self._qt_version + + @property + def default_func_prefix(self) -> str: + """Get the default function prefix used by the decompiler""" + if self._default_func_prefix is None: + self._default_func_prefix = self._send_request({"type": "property_get", "property_name": "default_func_prefix"}) + return self._default_func_prefix + + @property + def headless(self) -> bool: + """Check if the decompiler is running in headless mode""" + if self._headless is None: + self._headless = self._send_request({"type": "property_get", "property_name": "headless"}) + return self._headless + + @property + def force_click_recording(self) -> bool: + """Check if click recording is forced""" + if self._force_click_recording is None: + self._force_click_recording = self._send_request({"type": "property_get", "property_name": "force_click_recording"}) + return self._force_click_recording + + @property + def track_mouse_moves(self) -> bool: + """Check if mouse moves are tracked""" + if self._track_mouse_moves is None: + self._track_mouse_moves = self._send_request({"type": "property_get", "property_name": "track_mouse_moves"}) + return self._track_mouse_moves + + @property + def default_pointer_size(self) -> int: + """Get default pointer size""" + return self._send_request({"type": "property_get", "property_name": "default_pointer_size"}) + + # Artifact watcher methods + def start_artifact_watchers(self) -> None: + """Start artifact watchers on the remote decompiler""" + result = self._send_request({"type": "method_call", "method_name": "start_artifact_watchers"}) + self.artifact_watchers_started = True + + # Start event listener to receive callbacks from server + self._start_event_listener() + + return result + + def stop_artifact_watchers(self) -> None: + """Stop artifact watchers on the remote decompiler""" + # Stop event listener first + self._stop_event_listener() + + result = self._send_request({"type": "method_call", "method_name": "stop_artifact_watchers"}) + self.artifact_watchers_started = False + return result + + def should_watch_artifacts(self) -> bool: + """Check if artifacts should be watched""" + return self._send_request({"type": "method_call", "method_name": "should_watch_artifacts"}) + + # GUI registration methods (stubs since we can't proxy GUI operations) + def gui_register_ctx_menu(self, name: str, action_string: str, callback_func: Callable, category=None) -> bool: + """Register a context menu item (not supported in remote mode)""" + _l.warning("GUI context menu registration is not supported in remote decompiler mode") + return False + + def gui_register_ctx_menu_many(self, actions: dict) -> None: + """Register multiple context menu items (not supported in remote mode)""" + _l.warning("GUI context menu registration is not supported in remote decompiler mode") + + def gui_run_on_main_thread(self, func: Callable, *args, **kwargs): + """Run function on main thread (not supported in remote mode)""" + _l.warning("GUI main thread operations are not supported in remote decompiler mode") + raise NotImplementedError("GUI main thread operations not supported in remote mode") + + def gui_attach_qt_window(self, qt_window, title: str, target_window=None, position=None, *args, **kwargs) -> bool: + """Attach Qt window (not supported in remote mode)""" + _l.warning("GUI window attachment is not supported in remote decompiler mode") + return False + + # Event callback methods (these trigger callbacks locally but don't send to server) + def decompiler_opened_event(self, **kwargs): + """Handle decompiler opened event""" + for callback in self.decompiler_opened_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in decompiler opened callback: {e}") + + def decompiler_closed_event(self, **kwargs): + """Handle decompiler closed event""" + for callback in self.decompiler_closed_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in decompiler closed callback: {e}") + + def gui_undo_event(self, **kwargs): + """Handle GUI undo event""" + for callback in self.undo_event_callbacks: + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, kwargs=kwargs) + thread.start() + else: + callback(**kwargs) + except Exception as e: + _l.error(f"Error in undo event callback: {e}") + + def gui_context_changed(self, ctx: Context, **kwargs) -> Context: + """Handle GUI context changed event""" + # This would typically be handled by GUI callbacks locally + return ctx + + # Artifact change event methods (these handle local callbacks) + def function_header_changed(self, fheader, **kwargs): + """Handle function header changed event""" + for callback in self.artifact_change_callbacks.get(type(fheader), []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(fheader,), kwargs=kwargs) + thread.start() + else: + callback(fheader, **kwargs) + except Exception as e: + _l.error(f"Error in function header change callback: {e}") + return fheader + + def stack_variable_changed(self, svar, **kwargs): + """Handle stack variable changed event""" + for callback in self.artifact_change_callbacks.get(type(svar), []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(svar,), kwargs=kwargs) + thread.start() + else: + callback(svar, **kwargs) + except Exception as e: + _l.error(f"Error in stack variable change callback: {e}") + return svar + + def comment_changed(self, comment: Comment, deleted=False, **kwargs) -> Comment: + """Handle comment changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Comment, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(comment,), kwargs=kwargs) + thread.start() + else: + callback(comment, **kwargs) + except Exception as e: + _l.error(f"Error in comment change callback: {e}") + return comment + + def struct_changed(self, struct: Struct, deleted=False, **kwargs) -> Struct: + """Handle struct changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Struct, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(struct,), kwargs=kwargs) + thread.start() + else: + callback(struct, **kwargs) + except Exception as e: + _l.error(f"Error in struct change callback: {e}") + return struct + + def enum_changed(self, enum: Enum, deleted=False, **kwargs) -> Enum: + """Handle enum changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Enum, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(enum,), kwargs=kwargs) + thread.start() + else: + callback(enum, **kwargs) + except Exception as e: + _l.error(f"Error in enum change callback: {e}") + return enum + + def typedef_changed(self, typedef: Typedef, deleted=False, **kwargs) -> Typedef: + """Handle typedef changed event""" + kwargs["deleted"] = deleted + for callback in self.artifact_change_callbacks.get(Typedef, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(typedef,), kwargs=kwargs) + thread.start() + else: + callback(typedef, **kwargs) + except Exception as e: + _l.error(f"Error in typedef change callback: {e}") + return typedef + + def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVariable: + """Handle global variable changed event""" + for callback in self.artifact_change_callbacks.get(GlobalVariable, []): + try: + if self._thread_artifact_callbacks: + import threading + thread = threading.Thread(target=callback, args=(gvar,), kwargs=kwargs) + thread.start() + else: + callback(gvar, **kwargs) + except Exception as e: + _l.error(f"Error in global variable change callback: {e}") + return gvar \ No newline at end of file diff --git a/declib/api/decompiler_interface.py b/declib/api/decompiler_interface.py new file mode 100644 index 00000000..5616f272 --- /dev/null +++ b/declib/api/decompiler_interface.py @@ -0,0 +1,1261 @@ +import inspect +import logging +import re +import threading +from collections import defaultdict +from functools import wraps +from typing import Dict, Optional, Tuple, List, Callable, Type, Union +from pathlib import Path +import os + +import networkx as nx + +import declib +from declib.api.artifact_lifter import ArtifactLifter +from declib.api.artifact_dict import ArtifactDict +from declib.api.type_parser import CTypeParser, CType +from declib.configuration import DecLibConfig +from declib.artifacts import ( + Artifact, + Function, FunctionHeader, StackVariable, + Comment, GlobalVariable, Patch, Segment, + Enum, Struct, FunctionArgument, Context, Decompilation, Typedef +) +from declib.decompilers import SUPPORTED_DECOMPILERS, ANGR_DECOMPILER, \ + BINJA_DECOMPILER, IDA_DECOMPILER, GHIDRA_DECOMPILER + +_l = logging.getLogger(name=__name__) + + +def requires_decompilation(f): + @wraps(f) + def _requires_decompilation(self, *args, **kwargs): + if self._decompiler_available: + for arg in args: + if isinstance(arg, Function) and arg.dec_obj is None: + arg.dec_obj = self.get_decompilation_object(arg) + + return f(self, *args, **kwargs) + return _requires_decompilation + + +class DecompilerInterface: + def __init__( + self, + # these flags should mostly be unchanged when passed through subclasses + name: str = "generic", + qt_version: str = "PySide6", + default_func_prefix: str = "sub_", + artifact_lifter: Optional[ArtifactLifter] = None, + error_on_artifact_duplicates: bool = False, + decompiler_available: bool = True, + supports_undo: bool = False, + supports_type_scopes: bool = False, + # these flags can be changed by subclassed decis + headless: bool = False, + binary_path: Optional[Union[Path, str]] = None, + init_plugin: bool = False, + plugin_name: str = f"generic_declib_plugin", + config: Optional[DecLibConfig] = None, + # [category/name] = (action_string, callback_func) + gui_ctx_menu_actions: Optional[dict] = None, + gui_init_args: Optional[Tuple] = None, + gui_init_kwargs: Optional[Dict] = None, + # [artifact_class] = list(callback_func) + artifact_change_callbacks: Optional[Dict[Type[Artifact], List[Callable]]] = None, + undo_event_callbacks: Optional[List[Callable]] = None, + decompiler_opened_callbacks: Optional[List[Callable]] = None, + decompiler_closed_callbacks: Optional[List[Callable]] = None, + thread_artifact_callbacks: bool = True, + force_click_recording: bool = False, + track_mouse_moves: bool = False, + **kwargs, + ): + self.name = name + self.art_lifter = artifact_lifter + self.type_parser = CTypeParser() + self.supports_undo = supports_undo + self.supports_type_scopes = supports_type_scopes + self.qt_version = qt_version + self.default_func_prefix = default_func_prefix + self._error_on_artifact_duplicates = error_on_artifact_duplicates + + self.headless = headless + self._binary_path = Path(binary_path) if binary_path else None + self._init_plugin = init_plugin + self._unparsed_gui_ctx_actions: dict[str, tuple[str, Callable]] = gui_ctx_menu_actions or {} + # (category, name, action_string, callback_func) + self._gui_ctx_menu_actions = [] + self._plugin_name = plugin_name + self.gui_plugin = None + self.artifact_watchers_started = False + self.force_click_recording = force_click_recording + self.track_mouse_moves = track_mouse_moves + + # locks + self.artifact_write_lock = threading.Lock() + + # callback functions, keyed by Artifact class + self.artifact_change_callbacks = artifact_change_callbacks or defaultdict(list) + self.undo_event_callbacks = undo_event_callbacks or [] + self.decompiler_opened_callbacks = decompiler_opened_callbacks or [] + self.decompiler_closed_callbacks = decompiler_closed_callbacks or [] + self._thread_artifact_callbacks = thread_artifact_callbacks + + # artifact dict aliases: + # these are the public API for artifacts that are used by the decompiler interface + self.functions = ArtifactDict(Function, self, error_on_duplicate=error_on_artifact_duplicates) + self.comments = ArtifactDict(Comment, self, error_on_duplicate=error_on_artifact_duplicates) + self.patches = ArtifactDict(Patch, self, error_on_duplicate=error_on_artifact_duplicates) + self.global_vars = ArtifactDict(GlobalVariable, self, error_on_duplicate=error_on_artifact_duplicates) + self.segments = ArtifactDict(Segment, self, error_on_duplicate=error_on_artifact_duplicates) + self.structs = ArtifactDict(Struct, self, error_on_duplicate=error_on_artifact_duplicates, scopable=True) + self.enums = ArtifactDict(Enum, self, error_on_duplicate=error_on_artifact_duplicates, scopable=True) + self.typedefs = ArtifactDict(Typedef, self, error_on_duplicate=error_on_artifact_duplicates, scopable=True) + + self._decompiler_available = decompiler_available + # override the file-saved config when one is passed in manually, otherwise + # either load it from the filesystem or create a new one and place it there + self.config = config if config is not None else DecLibConfig.update_or_make() + + if not self.headless: + args = gui_init_args or [] + kwargs = gui_init_kwargs or {} + self._init_gui_components(*args, **kwargs) + else: + self._init_headless_components() + + self.debug(f"Using configuration file: {self.config.save_location}") + self.config.save() + + def _init_headless_components(self, *args, **kwargs): + if not self._binary_path.exists(): + raise FileNotFoundError("You must provide a valid target binary path when using headless mode.") + + def _deinit_headless_components(self): + pass + + def _init_gui_components(self, *args, **kwargs): + from declib.ui.version import set_ui_version + set_ui_version(self.qt_version) + + # register a real plugin in the GUI + if self._init_plugin: + self.gui_plugin = self._init_gui_plugin(*args, **kwargs) + + # parse & register all context menu actions + self.gui_register_ctx_menu_many(self._unparsed_gui_ctx_actions) + + def _init_gui_plugin(self, *args, **kwargs): + return None + + def shutdown(self): + if self.artifact_watchers_started: + self.stop_artifact_watchers() + if self.headless: + self._deinit_headless_components() + + # + # Public API: + # These functions are the main API for interacting with the decompiler. In general, every function that takes + # an Artifact (including addresses) should be in the lifted form. Additionally, every function that returns an + # Artifact should be in the lifted form. This is to ensure that the decompiler interface is always in sync with + # the lifter. For getting and setting artifacts, the ArtifactDicts defined in the init should be used. + # + + # + # GUI API + # + + def gui_active_context(self) -> Optional[declib.artifacts.Context]: + """ + Returns the active location that the user is currently _clicked_ on in the decompiler. + This is returned as a Context object, which can address and screen naming information dependent + on the decompilers exposed data. + """ + raise NotImplementedError + + def gui_goto(self, func_addr) -> None: + """ + Relocates decompiler display to provided address + + @param func_addr: + @return: + """ + raise NotImplementedError + + def gui_show_type(self, type_name: str) -> None: + """ + Relocates decompiler display to type definition + + Does nothing if not implemented in a subclass + """ + pass + + def gui_register_ctx_menu(self, name, action_string, callback_func, category=None, shortcut=None) -> bool: + """ + Register a context menu / plugin action. + + :param name: unique identifier for the action + :param action_string: human-readable label shown in the menu + :param callback_func: function to invoke when the action fires + :param category: optional menu category / sub-path + :param shortcut: optional keyboard shortcut in Qt format (e.g. "Ctrl+Shift+D"). + Implementations translate this to their native format. When the native + decompiler cannot bind a shortcut programmatically, this is a no-op. + """ + raise NotImplementedError + + def gui_ask_for_string(self, question, title="Plugin Question", default="") -> str: + """ + Opens a GUI dialog box that asks the user for a string. If not overriden by the decompiler interface, + this will default to a Qt dialog box that is based on the decompilers Qt version. + """ + from declib.ui.utils import gui_ask_for_string + return gui_ask_for_string(question, title=title, default=default) + + def gui_ask_for_choice(self, question: str, choices: list, title="Plugin Question") -> str: + """ + Opens a GUI dialog box that asks the user for a choice. If not overriden by the decompiler interface, + this will default to a Qt dialog box that is based on the decompilers Qt version. + """ + from declib.ui.utils import gui_ask_for_choice + return gui_ask_for_choice(question, choices, title=title) + + def gui_popup_text(self, text: str, title: str = "Plugin Message") -> bool: + """ + Opens a GUI dialog box that displays a message. If not overriden by the decompiler interface, + this will default to a Qt dialog box that is based on the decompilers Qt version. + """ + from declib.ui.utils import gui_popup_text + return gui_popup_text(text, title=title) + + def gui_run_on_main_thread(self, func: Callable, *args, **kwargs): + """ + Runs the provided function on the main thread of the GUI. This is useful for updating the GUI from a + background thread. Only in Ghidra is this useful. + """ + return func(*args, **kwargs) + + def gui_attach_qt_window(self, qt_window: type["QWidgt"], title: str, target_window=None, position=None, *args, **kwargs) -> bool: + """ + Attaches a Qt window to the decompiler interface. This is useful for embedding custom Qt windows + into the decompiler interface. + """ + raise NotImplementedError + + @staticmethod + def _parse_ctx_menu_actions(actions: dict[str, tuple[str, Callable]]) -> List[Tuple[str, str, str, Callable]]: + gui_ctx_menu_actions = [] + for combined_name, items in actions.items(): + slashes = list(re.finditer("/", combined_name)) + if not slashes: + category = "" + name = combined_name + else: + last_slash = slashes[-1] + category = combined_name[:last_slash.start()] + name = combined_name[last_slash.start()+1:] + + gui_ctx_menu_actions.append((category, name,) + items) + + return gui_ctx_menu_actions + + def gui_register_ctx_menu_many(self, actions: dict[str, tuple[str, Callable]]): + parsed_actions = self._parse_ctx_menu_actions(actions) + for action in parsed_actions: + category, name, action_string, callback_func = action[:4] + shortcut = action[4] if len(action) > 4 else None + self.gui_register_ctx_menu( + name, action_string, callback_func, category=category, shortcut=shortcut + ) + + # + # Override Mandatory API + # + + def start_artifact_watchers(self): + """ + Starts the artifact watchers for the decompiler. This is a special function that is called + by the decompiler interface when the decompiler is ready to start watching for changes in the + decompiler. This is useful for plugins that want to watch for changes in the decompiler and + react to them. + + @return: + """ + self.debug("Starting BinSync artifact watchers...") + self.artifact_watchers_started = True + + def stop_artifact_watchers(self): + """ + Stops the artifact watchers for the decompiler. This is a special function that is called + by the decompiler interface when the decompiler is ready to stop watching for changes in the + decompiler. This is useful for plugins that want to watch for changes in the decompiler and + react to them. + """ + self.debug("Stopping BinSync artifact watchers...") + self.artifact_watchers_started = False + + @property + def binary_base_addr(self) -> int: + """ + Returns the base address of the binary in the decompiler. This is useful for calculating offsets + in the binary. Also mandatory for using the lifting and lowering API. + """ + raise NotImplementedError + + @property + def binary_hash(self) -> str: + """ + Returns a hex string of the currently loaded binary in the decompiler. For most cases, + this will simply be a md5hash of the binary. + + @rtype: hex string + """ + raise NotImplementedError + + @property + def binary_path(self) -> Optional[str]: + """ + Returns a string that is the path of the currently loaded binary. If there is no binary loaded + then None should be returned. + + @rtype: path-like string (/path/to/binary) + """ + return self._binary_path + + def fast_get_function(self, func_addr) -> Optional[Function]: + """ + Attempts to get a light version of the Function at func_addr. + This function implements special logic to be faster than grabbing all light-functions, or grabbing + a decompiled function. Use this API in the case where you may need to get a single functions info + many times in a loop. + + @param func_addr: + @return: + """ + raise NotImplementedError + + def get_func_size(self, func_addr) -> int: + """ + Returns the size of a function + + @param func_addr: + @return: + """ + raise NotImplementedError + + @property + def decompiler_available(self) -> bool: + """ + @return: True if decompiler is available for decompilation, False if otherwise + """ + return True + + def decompile(self, addr: int, map_lines=False, **kwargs) -> Optional[Decompilation]: + lowered_addr = self.art_lifter.lower_addr(addr) + if not self.decompiler_available: + _l.error("Decompiler is not available.") + return None + + sorted_funcs = sorted(self._functions().items(), key=lambda x: x[0]) + func_by_addr = {_addr: func for _addr, func in sorted_funcs} + func = None + if lowered_addr in func_by_addr: + func = func_by_addr[lowered_addr] + else: + _l.debug("Address is not a function start, searching for function...") + for func_addr, _func in sorted_funcs: + if _func.addr <= lowered_addr < (_func.addr + _func.size): + func = _func + break + + if func is None: + self.warning(f"Failed to find function for address {hex(lowered_addr)}") + return None + + try: + decompilation = self._decompile(func, map_lines=map_lines, **kwargs) + except Exception as e: + self.warning(f"Failed to decompile function at {hex(lowered_addr)}: {e}") + decompilation = None + + if decompilation is not None: + decompilation = self.art_lifter.lift(decompilation) + + return decompilation + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + """ + Returns a list of artifacts that reference the provided artifact. + @param artifact: Artifact to find references to + @param decompile: If True, decompile the function before searching for xrefs + @return: List of artifacts that reference the provided artifact + """ + if not isinstance(artifact, Function): + raise ValueError("Only functions are supported for xrefs_to") + + return [] + + def xrefs_to_addr(self, addr: int, only_code: bool = False) -> List[Artifact]: + """Return artifacts that reference ``addr``. + + Unlike :meth:`xrefs_to`, which assumes a Function target and therefore + only fires on function entry points, this is a raw "who references + this address?" query. It's what you want after ``list_strings`` finds + a candidate string and you need to know which functions read it. + + The default implementation turns the address into a stub Function and + delegates to :meth:`xrefs_to`; subclasses should override this with a + real data-xref query when their backend exposes one. + + @param addr: Address (lifted) to find references to. + @param only_code: Restrict to code references if the backend supports it. + @return: List of referencing artifacts (typically Function stubs). + """ + return self.xrefs_to(Function(addr, 0), only_code=only_code) + + def xrefs_from(self, func_addr: int) -> List[Function]: + """Return the functions that ``func_addr`` calls (its direct callees). + + The default implementation falls back to get_callgraph + out_edges, + which is expensive because it computes xrefs for every function in + the binary. Subclasses should override with a direct per-function + callee query when their backend exposes one. + """ + try: + cg = self.get_callgraph(only_names=False) + except Exception as exc: + _l.debug("get_callgraph failed: %s", exc) + return [] + callees: List[Function] = [] + seen = set() + for caller, callee in cg.out_edges(nbunch=None): + if getattr(caller, "addr", None) != func_addr: + continue + callee_addr = getattr(callee, "addr", None) + if callee_addr in seen: + continue + seen.add(callee_addr) + callees.append(callee) + return callees + + def get_callers(self, target) -> List[Function]: + """ + Returns a list of Functions that call/reference the provided target. + + @param target: A Function, address (int), or symbol name (str). + @return: List of Function objects whose bodies reference `target`. Each result is a (light) + Function; only its addr (and name when resolvable) are guaranteed to be populated. + """ + func: Optional[Function] = None + if isinstance(target, Function): + func = target + elif isinstance(target, int): + func = self.fast_get_function(target) + if func is None: + func = Function(target, 0) + elif isinstance(target, str): + for addr, light_func in self.functions.items(): + if light_func.name == target: + func = self.fast_get_function(addr) or Function(addr, 0) + break + if func is None: + raise ValueError(f"Unable to locate function named {target!r}") + else: + raise ValueError(f"Unsupported target type for get_callers: {type(target)}") + + callers: List[Function] = [] + seen = set() + for xref in self.xrefs_to(func): + if not isinstance(xref, Function): + continue + if xref.addr in seen: + continue + seen.add(xref.addr) + if not xref.name: + resolved = self.fast_get_function(xref.addr) + if resolved is not None: + xref = resolved + callers.append(xref) + + return callers + + def list_strings(self, filter: Optional[str] = None) -> List[Tuple[int, str]]: + """ + Returns a list of (addr, string) tuples for strings found in the binary. + + This reports **only what the decompiler's own string detector + surfaced** — it is deliberately not a substitute for a full-file + scan. Backend fidelity varies (angr in particular misses most of + ``.rodata``); callers that need an exhaustive list should fall + back to external tools such as ``strings(1)``, ``rabin2 -z``, or + ``readelf -p`` and then use the resulting addresses with the + other APIs (``decompile``, ``xrefs_to_addr``, etc.). + + Subclasses are expected to override this with native, fast string + discovery. The base implementation returns an empty list. + + @param filter: Optional regex string; only strings that match will be returned. + @return: List of (address, string) tuples. + """ + return [] + + def disassemble(self, addr: int, **kwargs) -> Optional[str]: + """ + Returns the disassembly of a function as a single string. + + Subclasses should override this to emit decompiler-native disassembly. The default + implementation returns None. + + @param addr: Address of the function (or any address inside the function). + @return: The disassembly string, or None if unavailable. + """ + return None + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + """Read ``size`` bytes from the loaded program at ``addr``. + + Returns the raw bytes the backend has for the requested span. ``None`` + means "I couldn't satisfy the read at all" — out-of-range, uninitialized, + or the backend can't reach that memory. A short read (fewer bytes than + requested) is still valid and returned as-is; callers should check + ``len(result)`` if they need an exact count. + + @param addr: Lifted address to start reading from. + @param size: Number of bytes to read. Must be > 0. + @return: Bytes read, or ``None`` if the backend can't read this region. + """ + raise NotImplementedError + + def get_callgraph(self, only_names=False) -> nx.DiGraph: + """ + Returns the callgraph of the binary. This is a dict of function addresses to a list of function addresses + that the function calls. + """ + callgraph = nx.DiGraph() + for func in self.functions.values(): + callers = self.xrefs_to(func) + for caller in callers: + if isinstance(caller, Function): + if only_names: + callgraph.add_edge(caller.name, func.name) + else: + callgraph.add_edge(caller, func) + + return callgraph + + def get_dependencies(self, artifact: Artifact, decompile=True, max_resolves=50, **kwargs) -> List[Artifact]: + if not isinstance(artifact, Function): + raise ValueError("Only functions are supported for get_dependencies") + + # collect all xrefs to the function (for global variables) + if decompile: + # the function was never decompiled + if artifact.dec_obj is None: + # TODO: this needs to be fixed so that it still works without redecompiling. What if we want + # to do analysis on a function that is not set yet. + artifact = self.functions[artifact.addr] + + art_users = self.xrefs_to(artifact, decompile=decompile) + gvars = [art for art in art_users if isinstance(art, GlobalVariable)] + + # collect all structs/enums used in the function types + imported_types = set() + imported_types.add(self.get_defined_type(artifact.header.type)) + for arg in artifact.header.args.values(): + imported_types.add(self.get_defined_type(arg.type)) + for svar in artifact.stack_vars.values(): + imported_types.add(self.get_defined_type(svar.type)) + + # start resolving dependencies in structs + for _ in range(max_resolves): + new_imports = False + for imported_type in list(imported_types): + if isinstance(imported_type, Struct): + for member in imported_type.members.values(): + new_type = self.get_defined_type(member.type) + if new_type is not None and new_type not in imported_types: + imported_types.add(new_type) + new_imports = True + break + + if new_imports: + break + + if isinstance(imported_type, Typedef): + new_type = self.get_defined_type(imported_type.type) + if new_type is not None and new_type not in imported_types: + imported_types.add(new_type) + new_imports = True + + if not new_imports: + break + else: + self.warning("Max dependency resolves reached, returning partial results") + + all_deps = [art for art in list(imported_types) + gvars if art is not None] + return all_deps + + def get_func_containing(self, addr: int) -> Optional[Function]: + raise NotImplementedError + + def _decompile(self, function: Function, map_lines=False, **kwargs) -> Optional[Decompilation]: + raise NotImplementedError + + def get_decompilation_object(self, function: Function, **kwargs) -> Optional[object]: + raise NotImplementedError + + def should_watch_artifacts(self) -> bool: + return True + + # + # Override Optional API: + # These are API that provide extra introspection for plugins that may rely on DecLib Interface + # + + @property + def binary_arch(self) -> str: + """ + Returns a string of the currently loaded binary's architecture. + """ + raise NotImplementedError + + @property + def default_pointer_size(self) -> int: + """ + Returns the default pointer size of the binary. This is useful for calculating offsets + in the binary. + """ + raise NotImplementedError + + def undo(self): + """ + Undoes the last change made to the decompiler. + """ + raise NotImplementedError + + def local_variable_names(self, func: Function) -> List[str]: + """ + Returns a list of local variable names for a function. Note, these also include register variables + that are normally not liftable in DecLib. + @param func: Function to get local variable names for + @return: List of local variable names + """ + return [] + + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + """ + Renames local variables in a function by a name map. Note, these also include register variables + that are normally not liftable in DecLib. + @param func: Function to rename local variables in + @param name_map: Dictionary of old name to new name + @return: True if any local variables were renamed, False if otherwise + """ + return False + + # + # Private Artifact API: + # Unlike the public API, every function in this section should take and return artifacts in their native (lowered) + # form. + # + + # functions + def _set_function(self, func: Function, **kwargs) -> bool: + update = False + header = func.header + if header is not None: + update |= self._set_function_header(header, **kwargs) + + if func.stack_vars: + update |= self._set_stack_variables(list(func.stack_vars.values()), **kwargs) + + return update + + def _get_function(self, addr, **kwargs) -> Optional[Function]: + return None + + def _del_function(self, addr, **kwargs) -> bool: + return False + + def _functions(self) -> Dict[int, Function]: + """ + Returns a dict of declib.Functions that contain the addr, name, and size of each function in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # stack vars + def _set_stack_variables(self, svars: List[StackVariable], **kwargs) -> bool: + update = False + for svar in svars: + update |= self._set_stack_variable(svar, **kwargs) + + return update + + def _set_stack_variable(self, svar: StackVariable, **kwargs) -> bool: + return False + + def _get_stack_variable(self, addr: int, offset: int, **kwargs) -> Optional[StackVariable]: + func = self._get_function(addr, **kwargs) + if func is None: + return None + + return func.stack_vars.get(offset, None) + + def _del_stack_variable(self, addr: int, offset: int, **kwargs) -> bool: + return False + + def _stack_variables(self, **kwargs) -> Dict[int,Dict[int, StackVariable]]: + stack_vars = defaultdict(dict) + for addr in self._functions(): + func = self._get_function(addr, **kwargs) + for svar in func.stack_vars.values(): + stack_vars[addr][svar.offset] = svar + + return dict(stack_vars) + + # global variables + def _set_global_variable(self, gvar: GlobalVariable, **kwargs) -> bool: + return False + + def _get_global_var(self, addr) -> Optional[GlobalVariable]: + return None + + def _del_global_var(self, addr) -> bool: + return False + + def _global_vars(self, **kwargs) -> Dict[int, GlobalVariable]: + """ + Returns a dict of declib.GlobalVariable that contain the addr and size of each global var. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # structs + def _set_struct(self, struct: Struct, header=True, members=True, **kwargs) -> bool: + return False + + def _get_struct(self, name) -> Optional[Struct]: + return None + + def _del_struct(self, name) -> bool: + return False + + def _structs(self) -> Dict[str, Struct]: + """ + Returns a dict of declib.Structs that contain the name and size of each struct in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # enums + def _set_enum(self, enum: Enum, **kwargs) -> bool: + return False + + def _get_enum(self, name) -> Optional[Enum]: + return None + + def _del_enum(self, name) -> bool: + return False + + def _enums(self) -> Dict[str, Enum]: + """ + Returns a dict of declib.Enum that contain the name of the enums in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # typedefs + def _set_typedef(self, typedef: Typedef, **kwargs) -> bool: + return False + + def _get_typedef(self, name) -> Optional[Typedef]: + return None + + def _del_typedef(self, name) -> bool: + return False + + def _typedefs(self) -> Dict[str, Typedef]: + """ + Returns a dict of declib.Typedef that contain the name of the typedefs in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # patches + def _set_patch(self, patch: Patch, **kwargs) -> bool: + return False + + def _get_patch(self, addr) -> Optional[Patch]: + return None + + def _del_patch(self, addr) -> bool: + return False + + def _patches(self) -> Dict[int, Patch]: + """ + Returns a dict of declib.Patch that contain the addr of each Patch and the bytes. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # comments + def _set_comment(self, comment: Comment, **kwargs) -> bool: + return False + + def _get_comment(self, addr) -> Optional[Comment]: + return None + + def _del_comment(self, addr) -> bool: + return False + + def _comments(self) -> Dict[int, Comment]: + return {} + + # segments + def _set_segment(self, segment: Segment, **kwargs) -> bool: + return False + + def _get_segment(self, name) -> Optional[Segment]: + return None + + def _del_segment(self, name) -> bool: + return False + + def _segments(self) -> Dict[str, Segment]: + """ + Returns a dict of declib.Segment that contain the name, start_addr, end_addr, and permissions of each segment. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return {} + + # others... + def _set_function_header(self, fheader: FunctionHeader, **kwargs) -> bool: + return False + + # + # Change Callback API: + # Every callback in this group assumes the input will be decompiler-specific (lowered) and will + # lift it ONCE inside this function. Each one will return the lifted form, for easier overriding. + # + + def decompiler_opened_event(self, **kwargs): + """ + This function is called when the decompiler platform this interface is running on is opened for the first time. + In the presence of a decompiler with multiple tabs, this function will still only be called once. + """ + for callback_func in self.decompiler_opened_callbacks: + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, kwargs=kwargs, daemon=True).start() + else: + callback_func(**kwargs) + + def decompiler_closed_event(self, **kwargs): + """ + This function is called when the decompiler platform this interface is running on is closing/closed. + In the presence of a decompiler with multiple tabs, this function will still only be called once. + """ + for callback_func in self.decompiler_closed_callbacks: + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, kwargs=kwargs, daemon=True).start() + else: + callback_func(**kwargs) + + def gui_undo_event(self, **kwargs): + for callback_func in self.undo_event_callbacks: + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, kwargs=kwargs, daemon=True).start() + else: + callback_func(**kwargs) + + def gui_context_changed(self, ctx: Context, **kwargs) -> declib.artifacts.Context: + # XXX: should this be lifted? + for callback_func in self.artifact_change_callbacks[Context]: + args = (ctx,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return ctx + + def segment_changed(self, segment: Segment, **kwargs) -> Segment: + lifted_segment = self.art_lifter.lift(segment) + for callback_func in self.artifact_change_callbacks[Segment]: + args = (lifted_segment,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_segment + + def function_header_changed(self, fheader: FunctionHeader, **kwargs) -> FunctionHeader: + lifted_fheader = self.art_lifter.lift(fheader) + for callback_func in self.artifact_change_callbacks[FunctionHeader]: + args = (lifted_fheader,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_fheader + + def stack_variable_changed(self, svar: StackVariable, **kwargs) -> StackVariable: + lifted_svar = self.art_lifter.lift(svar) + for callback_func in self.artifact_change_callbacks[StackVariable]: + args = (lifted_svar,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_svar + + def comment_changed(self, comment: Comment, deleted=False, **kwargs) -> Comment: + kwargs["deleted"] = deleted + lifted_cmt = self.art_lifter.lift(comment) + for callback_func in self.artifact_change_callbacks[Comment]: + args = (lifted_cmt,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_cmt + + def struct_changed(self, struct: Struct, deleted=False, **kwargs) -> Struct: + kwargs["deleted"] = deleted + lifted_struct = self.art_lifter.lift(struct) + for callback_func in self.artifact_change_callbacks[Struct]: + args = (lifted_struct,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_struct + + def decompilation_changed(self, decompilation: Decompilation, **kwargs) -> Decompilation: + lifted_dcmp = self.art_lifter.lift(decompilation) + for callback_func in self.artifact_change_callbacks[Decompilation]: + args = (lifted_dcmp,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_dcmp + + def enum_changed(self, enum: Enum, deleted=False, **kwargs) -> Enum: + kwargs["deleted"] = deleted + lifted_enum = self.art_lifter.lift(enum) + for callback_func in self.artifact_change_callbacks[Enum]: + args = (lifted_enum,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_enum + + def typedef_changed(self, typedef: Typedef, deleted=False, **kwargs) -> Typedef: + kwargs["deleted"] = deleted + lifted_typedef = self.art_lifter.lift(typedef) + for callback_func in self.artifact_change_callbacks[Typedef]: + args = (lifted_typedef,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_typedef + + def global_variable_changed(self, gvar: GlobalVariable, **kwargs) -> GlobalVariable: + lifted_gvar = self.art_lifter.lift(gvar) + for callback_func in self.artifact_change_callbacks[GlobalVariable]: + args = (lifted_gvar,) + if self._thread_artifact_callbacks: + threading.Thread(target=callback_func, args=args, kwargs=kwargs, daemon=True).start() + else: + callback_func(*args, **kwargs) + + return lifted_gvar + + # + # Special Loggers and Printers + # + + def print(self, msg: str, **kwargs): + print(msg) + + def info(self, msg: str, **kwargs): + _l.info(msg) + + def debug(self, msg: str, **kwargs): + _l.debug(msg) + + def warning(self, msg: str, **kwargs): + _l.warning(msg) + + def error(self, msg: str, **kwargs): + _l.error(msg) + + # + # Utils + # + + def set_artifact(self, artifact: Artifact, lower=True, **kwargs) -> bool: + """ + Sets a declib Artifact into the decompilers local database. This operations allows you to change + what the native decompiler sees with declib Artifacts. This is different from opertions on a declib State, + since this is native to the decompiler + + >>> func = Function(0xdeadbeef, 0x800) + >>> func.name = "main" + >>> deci.set_artifact(func) + + @param artifact: + @param lower: Wether to convert the Artifacts types and offset into the local decompilers format + @return: True if the Artifact was succesfuly set into the decompiler + """ + set_map = { + Function: self._set_function, + FunctionHeader: self._set_function_header, + StackVariable: self._set_stack_variable, + Comment: self._set_comment, + GlobalVariable: self._set_global_variable, + Struct: self._set_struct, + Enum: self._set_enum, + Patch: self._set_patch, + Segment: self._set_segment, + Artifact: None, + } + + if lower: + artifact = self.art_lifter.lower(artifact) + + setter = set_map.get(type(artifact), None) + if setter is None: + _l.critical("Unsupported object is attempting to be set, please check your object: %s", artifact) + return False + + return setter(artifact, **kwargs) + + @staticmethod + def get_identifiers(artifact: Artifact) -> Tuple: + if isinstance(artifact, (Function, FunctionHeader, GlobalVariable, Patch, Comment)): + return (artifact.addr,) + elif isinstance(artifact, StackVariable): + return artifact.addr, artifact.offset + elif isinstance(artifact, FunctionArgument): + # TODO: add addr to function arguments + return (artifact.offset,) + elif isinstance(artifact, (Struct, Enum, Typedef, Segment)): + return (artifact.name,) + else: + raise ValueError(f"Unsupported artifact type: {type(artifact)}") + + def get_defined_type(self, type_str) -> Optional[Artifact]: + if not type_str: + return None + + normalized_type, scope = self.art_lifter.parse_scoped_type(type_str) + type_: CType = self.type_parser.parse_type(normalized_type) + if not type_: + # it was not parseable + return None + + # type is a primitive that returns no base type + base_type = type_.base_type + if base_type is None: + return None + + # if we trigger here, it means it's not a user-defined type + if not base_type.is_unknown: + return None + + base_type_str = base_type.type + lifted_scoped_type = self.art_lifter.scoped_type_to_str(base_type_str, scope) + if lifted_scoped_type in self.structs: + return self.structs[lifted_scoped_type] + elif lifted_scoped_type in self.enums: + return self.enums[lifted_scoped_type] + elif lifted_scoped_type in self.typedefs: + return self.typedefs[lifted_scoped_type] + else: + return None + + @staticmethod + def _find_global_in_call_frames(global_name, max_frames=10): + curr_frame = inspect.currentframe() + outer_frames = inspect.getouterframes(curr_frame, max_frames) + for frame in outer_frames: + global_data = frame.frame.f_globals.get(global_name, None) + if global_data is not None: + return global_data + else: + return None + + @staticmethod + def find_current_decompiler(force: str = None) -> Optional[str]: + """ + Finds the name of the current decompiler that this function is running inside of. Note, this function + does not create an interface, but instead finds the name of the decompiler that is currently running. + """ + available = set() + + # Binary Ninja + # this check needs to be done last since there is no way to traverse the stack frame to find the correct + # BV at this point in time. + try: + import binaryninja + has_bn_ui = False + try: + import binaryninjaui + has_bn_ui = True + except Exception: + pass + + if has_bn_ui: + return BINJA_DECOMPILER + available.add(BINJA_DECOMPILER) + # error can be thrown for an invalid license + except Exception as e: + if "License is not valid" in str(e): + _l.warning("Binary Ninja license is invalid, skipping...") + + # Ghidra + this_obj = DecompilerInterface._find_global_in_call_frames("__this__") + if (this_obj is not None) and (hasattr(this_obj, "currentProgram")): + available.add(GHIDRA_DECOMPILER) + if not force: + return GHIDRA_DECOMPILER + + # angr-management + try: + import angr + available.add(ANGR_DECOMPILER) + import angrmanagement + if DecompilerInterface._find_global_in_call_frames('workspace') is not None: + return ANGR_DECOMPILER + except Exception: + pass + + # IDA Pro + try: + import idaapi + available.add(IDA_DECOMPILER) + if not force: + return IDA_DECOMPILER + except Exception: + pass + + try: + # for IDA 9 Beta + import ida + available.add(IDA_DECOMPILER) + except ImportError: + pass + try: + # for IDA 9+ + import idapro + available.add(IDA_DECOMPILER) + except Exception: + pass + + if not available: + _l.critical("DecLib was unable to find the current decompiler you are running in or any headless instances!") + return None + + if force is not None and force not in available: + _l.critical("DecLib was unable to force the decompiler you requested... please check your environment.") + return None + + if force is None: + return available.pop() + + if force in available: + return force + + return None + + @staticmethod + def discover( + force_decompiler: str = None, + interface_overrides: Optional[Dict[str, "DecompilerInterface"]] = None, + **interface_kwargs + ) -> Optional["DecompilerInterface"]: + """ + This function is a special API helper that will attempt to detect the decompiler it is running in and + return the valid DLController for that decompiler. You may also force the chosen deci. + + @param force_decompiler: The optional string used to force a specific decompiler interface + @param interface_overrides: The optional dict used to override the class of a decompiler interface + @return: The DecompilerInterface associated with the current decompiler env + """ + if force_decompiler and force_decompiler not in SUPPORTED_DECOMPILERS: + raise ValueError(f"Unsupported decompiler {force_decompiler}") + + if force_decompiler: + if force_decompiler not in SUPPORTED_DECOMPILERS: + raise ValueError(f"Unsupported decompiler {force_decompiler}, please use one of {SUPPORTED_DECOMPILERS}") + current_decompiler = force_decompiler + else: + current_decompiler = DecompilerInterface.find_current_decompiler(force=force_decompiler) + + # `project_dir` is a user-facing kwarg that translates to the + # backend-specific cache/project location. Backends without a concept + # of this simply ignore it. + project_dir = interface_kwargs.pop("project_dir", None) + + if current_decompiler == IDA_DECOMPILER: + from declib.decompilers.ida.interface import IDAInterface + deci_class = IDAInterface + extra_kwargs = {} + if project_dir: + extra_kwargs["project_dir"] = project_dir + elif current_decompiler == BINJA_DECOMPILER: + from declib.decompilers.binja.interface import BinjaInterface + deci_class = BinjaInterface + extra_kwargs = {"bv": DecompilerInterface._find_global_in_call_frames('bv')} + elif current_decompiler == ANGR_DECOMPILER: + from declib.decompilers.angr.interface import AngrInterface + deci_class = AngrInterface + extra_kwargs = {"workspace": DecompilerInterface._find_global_in_call_frames('workspace')} + elif current_decompiler == GHIDRA_DECOMPILER: + from declib.decompilers.ghidra.interface import GhidraDecompilerInterface + deci_class = GhidraDecompilerInterface + extra_kwargs = {"flat_api": DecompilerInterface._find_global_in_call_frames('__this__')} + if project_dir: + extra_kwargs["project_location"] = project_dir + else: + raise ValueError("Please use DecLib with our supported decompiler set!") + + if interface_overrides is not None and current_decompiler in interface_overrides: + deci_class = interface_overrides[current_decompiler] + + interface_kwargs.update(extra_kwargs) + return deci_class(**interface_kwargs) diff --git a/declib/api/decompiler_server.py b/declib/api/decompiler_server.py new file mode 100644 index 00000000..e6407ae1 --- /dev/null +++ b/declib/api/decompiler_server.py @@ -0,0 +1,782 @@ +# Note to reader: most of this code was generated by Claude 4.5. It may contain errors and was designed +# in tandem with decompiler_client.py and the tests/test_client_server.py file. This comment will be +# removed when the majority of the file is owned by a human author. + +import logging +import pickle +import queue +import socket +import struct +import threading +import time +import tempfile +import os +from typing import Optional, Dict, Any, List + +from declib.api.decompiler_interface import DecompilerInterface +from declib.api import server_registry +from declib.artifacts.formatting import ArtifactFormat + +_l = logging.getLogger(__name__) + +# JSON, not TOML: the `toml` package's encoder mangles raw `\x` escapes, +# which show up in decompilation text for C char literals like `'\x01'`. +_WIRE_FMT = ArtifactFormat.JSON + +# Sentinel used to poke the main-thread dispatcher awake on shutdown. +_MAIN_THREAD_SHUTDOWN = object() + + +class _MainThreadError: + """Wrap exceptions that occurred on the main thread so the waiting client + thread can re-raise them after receiving the result.""" + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException): + self.exc = exc + + +class SocketProtocol: + """Helper class for socket protocol message framing""" + + @staticmethod + def send_message(sock: socket.socket, data: Any) -> None: + """Send a pickled message with length prefix""" + try: + pickled_data = pickle.dumps(data) + msg_len = len(pickled_data) + + # Send 4-byte length prefix + sock.sendall(struct.pack('!I', msg_len)) + # Send pickled data + sock.sendall(pickled_data) + except (ConnectionError, BrokenPipeError, OSError) as e: + # Expected during shutdown when socket is closed, just re-raise + raise + except Exception as e: + # Unexpected error - log it + _l.error(f"Failed to send message (pickle.dumps): {e}") + _l.error(f"Data type: {type(data)}") + if hasattr(data, '__dict__'): + _l.error(f"Data dict: {data.__dict__}") + raise + + @staticmethod + def recv_message(sock: socket.socket) -> Any: + """Receive a pickled message with length prefix""" + pickled_data = b'' + try: + # Receive 4-byte length prefix + len_data = sock.recv(4) + if len(len_data) != 4: + raise ConnectionError("Failed to receive message length") + + msg_len = struct.unpack('!I', len_data)[0] + + # Receive the pickled data + while len(pickled_data) < msg_len: + chunk = sock.recv(msg_len - len(pickled_data)) + if not chunk: + raise ConnectionError("Connection closed while receiving message") + pickled_data += chunk + + return pickle.loads(pickled_data) + except (ConnectionError, socket.timeout): + # Expected during shutdown or normal timeout, just re-raise without logging + raise + except Exception as e: + # Unexpected error - log it + _l.error(f"Failed to receive message (pickle.loads): {e}") + if pickled_data: + _l.error(f"Received {len(pickled_data)} bytes of pickle data") + raise + + +class SocketServerHandler: + """Handler for individual client connections""" + + def __init__(self, deci: DecompilerInterface, server: 'DecompilerServer' = None): + self.deci = deci + self.server = server + self._light_caches = {} + self._cache_lock = threading.Lock() + self._cache_ttl = 10.0 + + def _dispatch(self, func, *args, **kwargs): + """Call ``func`` either directly or via the server's main-thread queue. + + Backends like IDA reject cross-thread API access, so the server + declares ``_requires_main_thread`` and we route everything through + its dispatcher. For thread-safe backends (ghidra headless, angr, + binja) we short-circuit to a direct call. + """ + if self.server is None or not self.server.requires_main_thread: + return func(*args, **kwargs) + return self.server.run_on_main_thread(func, *args, **kwargs) + + def handle_client(self, client_socket: socket.socket, addr: str): + """Handle a client connection""" + _l.info(f"Client connected: {addr}") + + try: + while True: + try: + request = SocketProtocol.recv_message(client_socket) + response = self._process_request(request, client_socket=client_socket) + SocketProtocol.send_message(client_socket, response) + except ConnectionError: + # Client disconnected + break + except Exception as e: + # Send error response + error_response = {"error": str(e), "type": type(e).__name__} + try: + SocketProtocol.send_message(client_socket, error_response) + except: + break + finally: + # Remove from event subscribers if subscribed + if self.server: + with self.server._event_subscribers_lock: + if client_socket in self.server._event_subscribers: + self.server._event_subscribers.remove(client_socket) + _l.debug("Removed client from event subscribers") + + client_socket.close() + _l.info(f"Client disconnected: {addr}") + + def _process_request(self, request: Dict[str, Any], client_socket: socket.socket = None) -> Any: + """Process a client request and return response""" + request_type = request.get("type") + + if request_type == "subscribe_events": + # Client wants to subscribe to artifact change events + if self.server and client_socket: + with self.server._event_subscribers_lock: + if client_socket not in self.server._event_subscribers: + self.server._event_subscribers.append(client_socket) + _l.info(f"Client subscribed to events (total subscribers: {len(self.server._event_subscribers)})") + return {"status": "subscribed"} + else: + return {"status": "error", "message": "Server not available"} + + elif request_type == "unsubscribe_events": + # Client wants to unsubscribe from events + if self.server and client_socket: + with self.server._event_subscribers_lock: + if client_socket in self.server._event_subscribers: + self.server._event_subscribers.remove(client_socket) + _l.info(f"Client unsubscribed from events (total subscribers: {len(self.server._event_subscribers)})") + return {"status": "unsubscribed"} + else: + return {"status": "error", "message": "Server not available"} + + elif request_type == "server_info": + # Return the metadata cached by the server at init time. Reading + # ``deci.binary_hash`` or ``deci.binary_path`` here would re-enter + # the backend from a worker thread — which IDA/idalib rejects + # with "Function can be called from the main thread only". + if self.server is not None and self.server._cached_server_info is not None: + return dict(self.server._cached_server_info) + return { + "name": "DecLib DecompilerServer (AF_UNIX)", + "version": "3.0.0", + "decompiler": self.deci.name if self.deci else "unknown", + "protocol": "unix_socket", + "binary_hash": None, + "binary_path": None, + "server_id": self.server.server_id if self.server else None, + } + + elif request_type == "get_light_artifacts": + collection_name = request.get("collection_name") + return self._get_light_artifacts(collection_name) + + elif request_type == "get_full_artifact": + collection_name = request.get("collection_name") + key = request.get("key") + + def _fetch_full_artifact(): + collection = getattr(self.deci, collection_name) + return collection[key] + + artifact = self._dispatch(_fetch_full_artifact) + + # Serialize the full artifact safely + if hasattr(artifact, 'dumps') and hasattr(artifact, '__class__'): + try: + return { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': artifact.dumps(fmt=_WIRE_FMT), + 'is_artifact': True + } + except Exception as e: + _l.warning(f"Failed to serialize full artifact: {e}") + # Fall back to direct return, which might fail with pickle + return artifact + else: + return artifact + + elif request_type == "method_call": + method_name = request.get("method_name") + args = request.get("args", []) + kwargs = request.get("kwargs", {}) + + # Handle dotted method names like "art_lifter.lift" + if "." in method_name: + obj = self.deci + for attr in method_name.split("."): + obj = getattr(obj, attr) + method = obj + else: + # Get the method from the decompiler interface + method = getattr(self.deci, method_name) + result = self._dispatch(method, *args, **kwargs) + + # Check if result is an artifact and serialize it properly + if hasattr(result, 'dumps') and hasattr(result, '__class__'): + # This looks like an artifact, serialize it safely + try: + return { + 'type': result.__class__.__name__, + 'module': result.__class__.__module__, + 'data': result.dumps(fmt=_WIRE_FMT), + 'is_artifact': True + } + except Exception as e: + _l.warning(f"Failed to serialize result artifact: {e}") + # Fall back to direct return, which might fail with pickle + return result + else: + # Not an artifact, return as-is + return result + + elif request_type == "property_get": + property_name = request.get("property_name") + return self._dispatch(lambda: getattr(self.deci, property_name)) + + elif request_type == "shutdown_deci": + if self.deci and self.server is not None and not self.server._deci_shutdown_done: + # Route through the main-thread dispatcher for IDA — calling + # idapro.close_database() from a worker thread raises + # "Function can be called from the main thread only". + self._dispatch(self.deci.shutdown) + self.server._deci_shutdown_done = True + return {"status": "shutdown"} + + elif request_type == "shutdown_server": + # Tear the server down asynchronously so we can still reply. + if self.server is not None: + threading.Thread( + target=self.server.stop, name="declib-server-shutdown", daemon=True + ).start() + return {"status": "stopping"} + + else: + raise ValueError(f"Unknown request type: {request_type}") + + def _get_light_artifacts(self, collection_name: str) -> Dict: + """Get light artifacts for a collection, computing and caching on first request""" + with self._cache_lock: + cache_entry = self._light_caches.get(collection_name) + + # Check if we have a valid cache entry + if cache_entry and time.time() - cache_entry["timestamp"] < self._cache_ttl: + return cache_entry["items"] + + # Cache miss or stale - compute light artifacts on-demand + _l.debug(f"Computing light artifacts for {collection_name} on-demand") + try: + collection = getattr(self.deci, collection_name) + if hasattr(collection, '_lifted_art_lister'): + start_time = time.time() + light_items = self._dispatch(collection._lifted_art_lister) + end_time = time.time() + + # Convert artifacts to serializable format using their own serialization + serializable_items = {} + for addr, artifact in light_items.items(): + try: + # Use the artifact's built-in serialization which handles complex objects + serialized = artifact.dumps(fmt=_WIRE_FMT) + # Store as a tuple of (type_name, serialized_data) for reconstruction + serializable_items[addr] = { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': serialized + } + except Exception as e: + _l.warning(f"Failed to serialize {artifact.__class__.__name__} at 0x{addr:x}: {e}") + # Skip problematic artifacts rather than failing completely + continue + + # Cache the serializable artifacts + self._light_caches[collection_name] = { + "items": serializable_items, + "timestamp": time.time() + } + + _l.info(f"Computed {len(serializable_items)} light {collection_name} in {end_time - start_time:.3f}s") + return serializable_items + else: + _l.warning(f"Collection {collection_name} does not support light artifacts") + return {} + + except Exception as e: + _l.warning(f"Failed to compute light artifacts for {collection_name}: {e}") + # Return stale cache if available, otherwise empty dict + if cache_entry: + _l.debug(f"Returning stale cache for {collection_name} due to error") + return cache_entry["items"] + return {} + + +class DecompilerServer: + """ + A server that exposes DecompilerInterface APIs over AF_UNIX sockets. + + This server wraps a DecompilerInterface instance and provides network access + to all its public methods and artifact collections through AF_UNIX sockets. + """ + + def __init__(self, + decompiler_interface: Optional[DecompilerInterface] = None, + socket_path: Optional[str] = None, + server_id: Optional[str] = None, + register: bool = True, + **interface_kwargs): + """ + Initialize the DecompilerServer. + + Args: + decompiler_interface: An existing DecompilerInterface instance. If None, + one will be created using DecompilerInterface.discover() + socket_path: Path for the AF_UNIX socket. If None, a path is derived from server_id. + server_id: Optional explicit server ID. If None, a new one is generated. + register: If True, write the server info into the shared registry. + **interface_kwargs: Arguments passed to DecompilerInterface.discover() if + decompiler_interface is None + """ + + self.server_id = server_id or server_registry.new_server_id() + self.socket_path = socket_path + self._register = register + self._registered = False + self._server_socket = None + self._server_thread = None + self._running = False + self._clients = [] + self._client_threads = [] + + # Main-thread dispatch: some backends (notably IDA/idalib) reject + # cross-thread API access. For those we route backend calls through + # a queue so they run on the thread that set the backend up. + self._main_thread_queue: "queue.Queue" = queue.Queue() + self._main_thread_ident: Optional[int] = None + + # Event subscription tracking + self._event_subscribers = [] # List of sockets subscribed to events + self._event_subscribers_lock = threading.Lock() + + # Track whether deci.shutdown() already ran, so teardown is idempotent + # across the worker-initiated stop() and the main-thread __exit__. + self._deci_shutdown_done = False + + # Initialize the decompiler interface + if decompiler_interface is not None: + self.deci = decompiler_interface + else: + if interface_kwargs and interface_kwargs.get("headless", False): + forced_decompiler = interface_kwargs.get("force_decompiler", None) + if forced_decompiler is None: + _l.warning(f"Using a headless interface without setting a decompiler has unpredictable behavior!") + _l.info(f"Using headless interface utilizing %s", forced_decompiler) + else: + _l.info("Discovering decompiler interface...") + + self.deci = DecompilerInterface.discover(**interface_kwargs) + if self.deci is None: + raise RuntimeError("Failed to discover decompiler interface") + + # Cache static metadata on the *main* thread so that the connection + # handshake (`server_info`) never touches the backend from a worker + # thread — IDA/idalib raises "Function can be called from the main + # thread only" the moment such access happens. + self._cached_server_info = self._build_static_server_info() + + # Create socket handler + self.handler = SocketServerHandler(self.deci, server=self) + + # Register artifact change callbacks to broadcast events + self._register_artifact_callbacks() + + # Generate socket path if not provided + if self.socket_path is None: + socket_path = server_registry.default_socket_path(self.server_id) + self.socket_path = socket_path + self._temp_dir = os.path.dirname(socket_path) + else: + self._temp_dir = None + + _l.info(f"DecompilerServer initialized with {self.deci.name} interface (id={self.server_id})") + _l.info(f"Socket path: {self.socket_path}") + + def _build_static_server_info(self) -> Dict[str, Any]: + """Collect immutable server metadata on whatever thread calls us. + + This runs from ``__init__`` — i.e. the thread that constructed the + deci (the main thread in the CLI path). Capturing the values here + means ``server_info`` replies can be served from any worker thread + without re-entering backends like IDA that reject cross-thread API + calls. + """ + binary_path = None + binary_hash = None + if self.deci: + try: + raw_path = self.deci.binary_path + binary_path = str(raw_path) if raw_path else None + except Exception as exc: + _l.debug("Failed to cache binary_path: %s", exc) + try: + binary_hash = self.deci.binary_hash + except Exception as exc: + _l.debug("Failed to cache binary_hash: %s", exc) + + return { + "name": "DecLib DecompilerServer (AF_UNIX)", + "version": "3.0.0", + "decompiler": self.deci.name if self.deci else "unknown", + "protocol": "unix_socket", + "binary_hash": binary_hash, + "binary_path": binary_path, + "server_id": self.server_id, + } + + def _register_artifact_callbacks(self): + """Register callbacks to broadcast artifact changes to subscribed clients""" + from declib.artifacts import Comment, Struct, Enum, Typedef, GlobalVariable, FunctionHeader, StackVariable + + # Register callbacks for different artifact types + self.deci.artifact_change_callbacks[Comment].append( + lambda artifact, **kwargs: self._broadcast_event("comment_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Struct].append( + lambda artifact, **kwargs: self._broadcast_event("struct_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Enum].append( + lambda artifact, **kwargs: self._broadcast_event("enum_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[Typedef].append( + lambda artifact, **kwargs: self._broadcast_event("typedef_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[GlobalVariable].append( + lambda artifact, **kwargs: self._broadcast_event("global_variable_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[FunctionHeader].append( + lambda artifact, **kwargs: self._broadcast_event("function_header_changed", artifact, **kwargs) + ) + self.deci.artifact_change_callbacks[StackVariable].append( + lambda artifact, **kwargs: self._broadcast_event("stack_variable_changed", artifact, **kwargs) + ) + + def _broadcast_event(self, event_type: str, artifact, **kwargs): + """Broadcast an artifact change event to all subscribed clients""" + with self._event_subscribers_lock: + if not self._event_subscribers: + _l.debug(f"No subscribers for event: {event_type}") + return + + # Serialize the artifact + try: + serialized_artifact = { + 'type': artifact.__class__.__name__, + 'module': artifact.__class__.__module__, + 'data': artifact.dumps(fmt=_WIRE_FMT), + 'is_artifact': True + } + + event_message = { + "event_type": event_type, + "artifact": serialized_artifact, + "kwargs": kwargs + } + + # Send to all subscribers + dead_subscribers = [] + for subscriber_socket in self._event_subscribers: + try: + SocketProtocol.send_message(subscriber_socket, event_message) + _l.debug(f"Broadcasted {event_type} to subscriber") + except Exception as e: + _l.warning(f"Failed to send event to subscriber: {e}") + dead_subscribers.append(subscriber_socket) + + # Remove dead subscribers + for dead_socket in dead_subscribers: + self._event_subscribers.remove(dead_socket) + _l.debug("Removed dead subscriber") + + except Exception as e: + _l.error(f"Failed to broadcast event {event_type}: {e}") + + def start(self): + """Start the server in a separate thread""" + if self._running: + _l.warning("Server is already running") + return + + _l.info(f"Starting DecompilerServer on {self.socket_path}") + + # Create socket (AF_UNIX if available, else AF_INET) + if hasattr(socket, "AF_UNIX"): + self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._server_socket.settimeout(1.0) + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + self._server_socket.bind(self.socket_path) + else: + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.settimeout(1.0) + self._server_socket.bind(('127.0.0.1', 0)) + port = self._server_socket.getsockname()[1] + try: + with open(self.socket_path, 'w') as f: + f.write(str(port)) + except Exception as e: + _l.error(f"Failed to write port to {self.socket_path}: {e}") + self._server_socket.listen(5) + + # Set running flag before starting thread + self._running = True + + # Start server in a separate thread + self._server_thread = threading.Thread(target=self._server_loop, daemon=True) + self._server_thread.start() + + # Register in shared registry so other processes can find us. + if self._register: + try: + binary_path = str(self.deci.binary_path) if self.deci and self.deci.binary_path else None + binary_hash = None + try: + binary_hash = self.deci.binary_hash if self.deci else None + except Exception: + binary_hash = None + server_registry.register_server({ + "id": self.server_id, + "socket_path": self.socket_path, + "backend": self.deci.name if self.deci else None, + "binary_path": binary_path, + "binary_hash": binary_hash, + }) + self._registered = True + except Exception as exc: + _l.warning("Failed to register server: %s", exc) + + _l.info(f"DecompilerServer started successfully on unix://{self.socket_path}") + _l.info("Connect with: DecompilerClient.discover('unix://{}')".format(self.socket_path)) + + def _server_loop(self): + """Main server loop""" + try: + while self._running: + try: + client_socket, addr = self._server_socket.accept() + self._clients.append(client_socket) + + # Handle client in separate thread + client_thread = threading.Thread( + target=self.handler.handle_client, + args=(client_socket, str(addr)), + daemon=True + ) + self._client_threads.append(client_thread) + client_thread.start() + + except socket.timeout: + # Normal timeout, continue loop to check if we should stop + continue + except OSError: + # Socket was closed + break + except Exception as e: + _l.error(f"Error accepting client: {e}") + + except Exception as e: + _l.error(f"Server loop error: {e}") + finally: + _l.info("Server loop ended") + + def stop(self): + """Stop the server""" + if not self._running: + _l.warning("Server is not running") + return + + _l.info("Stopping DecompilerServer...") + self._running = False + + # Wake the main-thread dispatcher so it can exit `wait_for_shutdown`. + try: + self._main_thread_queue.put_nowait(_MAIN_THREAD_SHUTDOWN) + except Exception: + pass + + # Close all client connections + for client in self._clients: + try: + client.close() + except: + pass + + # Close server socket + if self._server_socket: + self._server_socket.close() + + # Wait for threads to finish (short timeout since we use daemon threads) + if self._server_thread and self._server_thread.is_alive(): + self._server_thread.join(timeout=2.0) + + for thread in self._client_threads: + if thread.is_alive(): + thread.join(timeout=0.5) + + # Clean up socket file and temp directory + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + + if self._temp_dir and os.path.exists(self._temp_dir): + try: + os.rmdir(self._temp_dir) + except: + pass + + # Remove from registry + if self._registered: + try: + server_registry.unregister_server(self.server_id) + except Exception as exc: + _l.debug("Failed to unregister server %s: %s", self.server_id, exc) + self._registered = False + + # Shutdown the decompiler interface. For backends that need the main + # thread (IDA/idalib), defer to the main thread which will run the + # shutdown after leaving the dispatch loop — doing it from a worker + # thread here raises "Function can be called from the main thread + # only". wait_for_shutdown() / __exit__ pick it up via + # _shutdown_deci_if_needed(). + if self.deci and not self._deci_shutdown_done: + on_main = ( + self._main_thread_ident is None + or threading.get_ident() == self._main_thread_ident + ) + if on_main or not self.requires_main_thread: + try: + self.deci.shutdown() + self._deci_shutdown_done = True + except Exception as e: + _l.warning(f"Error shutting down decompiler: {e}") + + _l.info("DecompilerServer stopped") + + def is_running(self) -> bool: + """Check if the server is currently running""" + return self._running + + @property + def requires_main_thread(self) -> bool: + """Whether backend API calls must be routed to the main thread. + + Set by the decompiler interface; IDA's idalib is the canonical case. + """ + if not self.deci: + return False + return bool(getattr(self.deci, "requires_main_thread_dispatch", False)) + + def run_on_main_thread(self, func, *args, **kwargs): + """Run ``func(*args, **kwargs)`` on the server's main thread. + + If the calling thread *is* the main thread, execute inline — this + avoids a deadlock when the main thread is itself invoking a method + (e.g. during ``__enter__`` / ``start``). + """ + if self._main_thread_ident is not None and threading.get_ident() == self._main_thread_ident: + return func(*args, **kwargs) + + result_q: "queue.Queue" = queue.Queue(maxsize=1) + self._main_thread_queue.put((func, args, kwargs, result_q)) + result = result_q.get() + if isinstance(result, _MainThreadError): + raise result.exc + return result + + def _main_thread_dispatch_loop(self): + """Drain backend work from the main-thread queue until shutdown. + + Only used for backends that require main-thread dispatch (IDA). + Runs on the thread that called ``wait_for_shutdown`` — i.e. the + thread that originally created the ``deci``. + """ + self._main_thread_ident = threading.get_ident() + while self._running: + try: + item = self._main_thread_queue.get(timeout=0.25) + except queue.Empty: + continue + if item is _MAIN_THREAD_SHUTDOWN: + break + func, args, kwargs, result_q = item + try: + result = func(*args, **kwargs) + except BaseException as exc: # relay every failure, including Java exceptions + result = _MainThreadError(exc) + result_q.put(result) + + def wait_for_shutdown(self): + """Wait for the server to be shut down (blocking)""" + if self.requires_main_thread: + # Become the main-thread dispatcher. This blocks until stop(). + try: + self._main_thread_dispatch_loop() + except KeyboardInterrupt: + _l.info("Received interrupt signal, stopping server...") + self.stop() + # Now that we're back on the main thread with the dispatch loop + # drained, finish any backend teardown stop() had to defer. + self._shutdown_deci_if_needed() + return + + if self._server_thread and self._server_thread.is_alive(): + try: + self._server_thread.join() + except KeyboardInterrupt: + _l.info("Received interrupt signal, stopping server...") + self.stop() + + def _shutdown_deci_if_needed(self): + """Run deci.shutdown() once, from the caller's thread. + + Callers must ensure they are on the thread that owns the backend + (typically the main thread). Idempotent. + """ + if not self.deci or self._deci_shutdown_done: + return + try: + self.deci.shutdown() + except Exception as e: + _l.warning(f"Error shutting down decompiler: {e}") + finally: + self._deci_shutdown_done = True + + def __enter__(self): + """Context manager entry""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit""" + self.stop() + self._shutdown_deci_if_needed() \ No newline at end of file diff --git a/declib/api/server_registry.py b/declib/api/server_registry.py new file mode 100644 index 00000000..6b1c8164 --- /dev/null +++ b/declib/api/server_registry.py @@ -0,0 +1,171 @@ +""" +Server registry for declib DecompilerServer instances. + +Each running server writes a small JSON descriptor into a shared registry +directory so that the `decompiler` CLI (and DecompilerClient.discover) can +find, filter, and connect to the right server instance. Stale records +(servers whose process has exited or whose socket has vanished) are pruned +on read. +""" +import json +import logging +import os +import tempfile +import time +import uuid +from pathlib import Path +from typing import Dict, List, Optional + +import psutil +from platformdirs import user_state_dir + +_l = logging.getLogger(__name__) + + +def _registry_dir() -> Path: + """Return the registry directory, creating it if missing.""" + env_override = os.environ.get("DECLIB_SERVER_REGISTRY") + if env_override: + path = Path(env_override) + else: + path = Path(user_state_dir("declib")) / "servers" + path.mkdir(parents=True, exist_ok=True) + return path + + +def new_server_id() -> str: + """Generate a short unique ID for a new server.""" + return uuid.uuid4().hex[:10] + + +def default_socket_path(server_id: str) -> str: + """Compute a default socket path for a server with the given ID.""" + temp_dir = Path(tempfile.gettempdir()) / f"declib_server_{server_id}" + temp_dir.mkdir(parents=True, exist_ok=True) + return str(temp_dir / "decompiler.sock") + + +def registry_path(server_id: str) -> Path: + return _registry_dir() / f"{server_id}.json" + + +def register_server(info: Dict) -> Path: + """Write a server descriptor into the registry. Required keys: id, socket_path.""" + server_id = info["id"] + path = registry_path(server_id) + payload = dict(info) + payload.setdefault("started_at", time.time()) + payload.setdefault("pid", os.getpid()) + tmp_path = path.with_suffix(".json.tmp") + with open(tmp_path, "w") as f: + json.dump(payload, f, indent=2, default=str) + os.replace(tmp_path, path) + return path + + +def unregister_server(server_id: str) -> bool: + path = registry_path(server_id) + try: + path.unlink() + return True + except FileNotFoundError: + return False + + +def _is_record_live(record: Dict) -> bool: + pid = record.get("pid") + socket_path = record.get("socket_path") + if not socket_path or not os.path.exists(socket_path): + return False + if pid is not None: + try: + if not psutil.pid_exists(int(pid)): + return False + except Exception: + return False + return True + + +def list_servers(prune_stale: bool = True) -> List[Dict]: + """Return all server records, optionally dropping and removing stale entries.""" + records: List[Dict] = [] + try: + entries = sorted(_registry_dir().glob("*.json")) + except FileNotFoundError: + return [] + + for entry in entries: + try: + with open(entry, "r") as f: + record = json.load(f) + except Exception as exc: + _l.debug("Failed to read server registry file %s: %s", entry, exc) + continue + + if prune_stale and not _is_record_live(record): + try: + entry.unlink() + except FileNotFoundError: + pass + except Exception as exc: + _l.debug("Failed to remove stale registry entry %s: %s", entry, exc) + continue + + records.append(record) + return records + + +def find_server( + server_id: Optional[str] = None, + binary_path: Optional[str] = None, + binary_hash: Optional[str] = None, + backend: Optional[str] = None, +) -> Optional[Dict]: + """Return the first server record matching all provided filters, else None.""" + binary_path_resolved = str(Path(binary_path).expanduser().resolve()) if binary_path else None + for record in list_servers(): + if server_id and record.get("id") != server_id: + continue + if binary_path_resolved: + record_path = record.get("binary_path") + if not record_path: + continue + try: + if str(Path(record_path).expanduser().resolve()) != binary_path_resolved: + continue + except Exception: + if record_path != binary_path_resolved: + continue + if binary_hash and record.get("binary_hash") != binary_hash: + continue + if backend and record.get("backend") != backend: + continue + return record + return None + + +def find_servers( + binary_path: Optional[str] = None, + binary_hash: Optional[str] = None, + backend: Optional[str] = None, +) -> List[Dict]: + """Return all server records matching the provided filters.""" + matches: List[Dict] = [] + binary_path_resolved = str(Path(binary_path).expanduser().resolve()) if binary_path else None + for record in list_servers(): + if binary_path_resolved: + record_path = record.get("binary_path") + if not record_path: + continue + try: + if str(Path(record_path).expanduser().resolve()) != binary_path_resolved: + continue + except Exception: + if record_path != binary_path_resolved: + continue + if binary_hash and record.get("binary_hash") != binary_hash: + continue + if backend and record.get("backend") != backend: + continue + matches.append(record) + return matches diff --git a/declib/api/type_definition_parser.py b/declib/api/type_definition_parser.py new file mode 100644 index 00000000..d57757ca --- /dev/null +++ b/declib/api/type_definition_parser.py @@ -0,0 +1,201 @@ +""" +Parse a single C type *definition* string into the matching declib artifact. + +Unlike ``CTypeParser`` (declib/api/type_parser.py), which is deliberately scoped to +type *expressions* ("int *", "struct Foo *"), this module handles full type +*definitions* with bodies: + + - ``struct Name { };`` -> :class:`declib.artifacts.Struct` + - ``enum Name { A, B=5, C };`` -> :class:`declib.artifacts.Enum` + - ``typedef Name;`` -> :class:`declib.artifacts.Typedef` + +It is intentionally decompiler-free and unit-testable: the heavy lifting is done by +``pycparser`` (already a declib dependency) for the AST and member type-string +rendering, and by ``CTypeParser`` for member sizing. The resulting artifact is then +applied to a decompiler via the normal ``deci.structs[name] = struct`` / +``deci.set_artifact(...)`` path, which is portable across every backend. +""" +import logging +import re +from typing import Optional, Union + +import pycparser +from pycparser import c_ast, c_generator +from pycparser.c_parser import ParseError + +from declib.artifacts import Struct, StructMember, Enum, Typedef +from declib.api.type_parser import CTypeParser + +_l = logging.getLogger(__name__) + +# Reuse single instances; both are stateless across parses. +_GENERATOR = c_generator.CGenerator() +_PARSER = pycparser.CParser() +_DEFAULT_TYPE_PARSER = CTypeParser() + +# Member natural alignment is its own size in System V, capped at the platform +# word width (pointers/long are 8 in CTypeParser's defaults). +_MAX_ALIGN = 8 + + +class TypeDefinitionParseError(ValueError): + """Raised when a C type-definition string cannot be turned into a declib artifact.""" + + +def parse_type_definition( + text: str, + type_parser: Optional[CTypeParser] = None, +) -> Union[Struct, Enum, Typedef]: + """ + Parse a single C type *definition* into the matching declib artifact. + + Supports exactly one top-level definition: a named ``struct``, ``enum``, or + ``typedef``. Raises :class:`TypeDefinitionParseError` on anything unparseable, + anonymous, multi-definition, or otherwise unsupported. + + >>> parse_type_definition("struct Point { int x; int y; }") + + """ + tp = type_parser or _DEFAULT_TYPE_PARSER + ast = _parse_ast(_normalize(text)) + top = ast.ext[0] + + if isinstance(top, c_ast.Typedef): + return _typedef_from_ast(top) + + # struct/enum arrive wrapped in a Decl whose .type is the Struct/Enum node + if isinstance(top, c_ast.Decl): + inner = top.type + if isinstance(inner, c_ast.Struct): + return _struct_from_ast(inner, tp) + if isinstance(inner, c_ast.Enum): + return _enum_from_ast(inner, tp) + + raise TypeDefinitionParseError( + f"Unsupported top-level definition: {type(top).__name__}. " + "Expected a named struct, enum, or typedef." + ) + + +def _normalize(text: str) -> str: + if not text or not text.strip(): + raise TypeDefinitionParseError("Empty type definition.") + # strip C comments (same approach as CTypeParser.parse_type_with_name) + text = re.sub(r"/\*.*?\*/", "", text, flags=re.DOTALL) + text = re.sub(r"//.*?$", "", text, flags=re.MULTILINE) + text = text.strip() + if not text.endswith(";"): + text += ";" + return text + + +def _parse_ast(text: str) -> c_ast.FileAST: + try: + ast = _PARSER.parse(text) + except ParseError as exc: + raise TypeDefinitionParseError(f"could not parse C definition: {exc}") + if not ast.ext: + raise TypeDefinitionParseError("no type definition found.") + if len(ast.ext) != 1: + raise TypeDefinitionParseError( + "expected exactly one type definition, got " + f"{len(ast.ext)}. Define one type at a time." + ) + return ast + + +def _render_type(node) -> str: + """Render a member/typedef type node back to a C type string, e.g. "char *".""" + rendered = _GENERATOR.visit(node).strip() + if "\n" in rendered or "{" in rendered: + raise TypeDefinitionParseError( + "inline/nested type definitions are unsupported here; define the " + "inner type separately and reference it by name." + ) + return rendered + + +def _member_size(tp: CTypeParser, type_str: str) -> int: + ct = tp.parse_type(type_str) + if ct is None or not ct.size: + # Unknown, user-defined non-pointer type (e.g. "struct Bar" before Bar + # exists): we cannot reliably size it, so reject rather than emit a + # 0-size member that would corrupt every subsequent offset. + raise TypeDefinitionParseError( + f"could not determine the size of member type {type_str!r}. " + "Define referenced types first, or use a pointer/primitive." + ) + return ct.size + + +def _struct_from_ast(struct_node: c_ast.Struct, tp: CTypeParser) -> Struct: + if not struct_node.name: + raise TypeDefinitionParseError( + "anonymous structs are not supported; give the struct a name." + ) + if not struct_node.decls: + raise TypeDefinitionParseError( + f"struct {struct_node.name!r} has no members to define." + ) + + members = {} + offset = 0 + max_align = 1 + for decl in struct_node.decls: + if decl.name is None: + raise TypeDefinitionParseError( + f"unnamed member in struct {struct_node.name!r} is unsupported." + ) + type_str = _render_type(decl.type) + size = _member_size(tp, type_str) + align = min(size, _MAX_ALIGN) if size else 1 + # round the running offset up to this member's natural alignment + if align > 1 and offset % align: + offset += align - (offset % align) + members[offset] = StructMember( + name=decl.name, offset=offset, type_=type_str, size=size, + ) + offset += size + max_align = max(max_align, align) + + total = offset + if max_align > 1 and total % max_align: + total += max_align - (total % max_align) + + return Struct(name=struct_node.name, size=total, members=members) + + +def _enum_from_ast(enum_node: c_ast.Enum, tp: CTypeParser) -> Enum: + if not enum_node.name: + raise TypeDefinitionParseError( + "anonymous enums are not supported; give the enum a name." + ) + if not enum_node.values or not enum_node.values.enumerators: + raise TypeDefinitionParseError( + f"enum {enum_node.name!r} has no members to define." + ) + + members = {} + next_val = 0 + for en in enum_node.values.enumerators: + if en.value is None: + val = next_val + else: + try: + val = tp._parse_const(en.value) + except Exception: + raise TypeDefinitionParseError( + f"could not evaluate enum value for {en.name!r}." + ) + members[en.name] = val + next_val = val + 1 + + return Enum(name=enum_node.name, members=members) + + +def _typedef_from_ast(typedef_node: c_ast.Typedef) -> Typedef: + name = typedef_node.name + if not name: + raise TypeDefinitionParseError("typedef is missing a name.") + type_str = _render_type(typedef_node.type) + return Typedef(name=name, type_=type_str) diff --git a/declib/api/type_parser.py b/declib/api/type_parser.py new file mode 100644 index 00000000..31981be3 --- /dev/null +++ b/declib/api/type_parser.py @@ -0,0 +1,409 @@ +import re +import logging +from collections import OrderedDict, defaultdict, ChainMap +from typing import Optional + +import pycparser +from pycparser import c_ast +from pycparser.c_parser import ParseError + +# pycparser hack to parse type expressions +errorlog = logging.getLogger(name=__name__ + ".yacc") +errorlog.setLevel(logging.ERROR) + + +l = logging.getLogger(__name__) + + +def _patch_pycparser(): + """ + Adds a `parse_type_with_name` method to pycparser.CParser that parses a bare + type expression (like "int *") rather than a full translation unit. pycparser + 3.0 removed the ability to customize the start production via ply.yacc. + """ + if hasattr(pycparser.CParser, "parse_type_with_name"): + return + + def parse_type_with_name(self, text, filename="", scope_stack=None) -> c_ast.Typename: + self.clex._filename = filename + self.clex._lineno = 1 + self._scope_stack = [{}] if scope_stack is None else scope_stack + + self.clex.input(text, filename) + self._tokens = pycparser.c_parser._TokenStream(self.clex) + + return self._parse_type_name() + + pycparser.CParser.parse_type_with_name = parse_type_with_name + + +_patch_pycparser() + + +class CType: + def __init__( + self, + type_=None, + size=0, + is_primitive=True, + is_array=False, + is_ptr=False, + is_unknown=False + ): + self.type = type_ + self._size = size + + self.is_primitive = is_primitive + self.is_array = is_array + self.is_ptr = is_ptr + self.is_unknown = is_unknown + + def __str__(self): + return f"<{self.__class__.__name__}: {self.type} {'[]' if self.is_array else ''}{'*' if self.is_ptr else ''}" \ + f"{'U' if self.is_unknown else ''} ({self._size})>" + + def __repr__(self): + return self.__str__() + + @property + def type_str(self): + if isinstance(self.type, CType) and self.is_array: + return self.type.type_str + f"[{self._size}]" + + return self.type + + @property + def base_type(self): + if isinstance(self.type, str): + return self + elif isinstance(self.type, CType): + return self.type.base_type + + return self.type + + @property + def size(self): + if isinstance(self.type, CType) and self.is_array: + return self.type.size * self._size + + return self._size + + +class CTypeParser: + """ + Most of this code is ripped from angr's sim_type: + https://github.com/angr/angr/blob/master/angr/sim_type.py + + It is highly simplified and drops a lot of support for real declaration parsing (like a struct dec). + Instead, we just use it to parse types. + """ + def __init__( + self, + sizeof_ptr=8, + sizeof_long=8, + sizeof_double=8, + sizeof_int=4, + sizeof_float=4, + sizeof_short=2, + sizeof_char=1, + sizeof_bool=1, + extra_types=None + ): + # sizes + self.sizeof_ptr = sizeof_ptr + self.sizeof_long = sizeof_long + self.sizeof_double = sizeof_double + self.sizeof_int = sizeof_int + self.sizeof_float = sizeof_float + self.sizeof_short = sizeof_short + self.sizeof_char = sizeof_char + self.sizeof_bool = sizeof_bool + + # hack in type parsing + self._type_parser_singleton = pycparser.CParser() + self.ALL_TYPES = {} + self.BASIC_TYPES = {} + self.STDINT_TYPES = {} + self.extra_types = extra_types or {} + self.SIZE_TO_TYPES = {} + self._init_all_types() + + def _init_all_types(self): + self.BASIC_TYPES = { + "char": CType(type_="char", size=self.sizeof_char), + "signed char": CType(type_="signed char", size=self.sizeof_char), + "unsigned char": CType(type_="unsigned char", size=self.sizeof_char), + "short": CType(type_="short", size=self.sizeof_short), + "signed short": CType(type_="signed short", size=self.sizeof_short), + "unsigned short": CType(type_="unsigned short", size=self.sizeof_short), + "short int": CType(type_="short int", size=self.sizeof_short), + "signed short int": CType(type_="signed short int", size=self.sizeof_short), + "unsigned short int": CType(type_="unsigned short int", size=self.sizeof_short), + "int": CType(type_="int", size=self.sizeof_int), + "signed": CType(type_="signed", size=self.sizeof_int), + "unsigned": CType(type_="unsigned", size=self.sizeof_int), + "signed int": CType(type_="signed int", size=self.sizeof_int), + "unsigned int": CType(type_="unsigned int", size=self.sizeof_int), + "long": CType(type_="long", size=self.sizeof_long), + "signed long": CType(type_="signed long", size=self.sizeof_long), + "long signed": CType(type_="long signed", size=self.sizeof_long), + "unsigned long": CType(type_="unsigned long", size=self.sizeof_long), + "long int": CType(type_="long int", size=self.sizeof_long), + "signed long int": CType(type_="signed long int", size=self.sizeof_long), + "unsigned long int": CType(type_="unsigned long int", size=self.sizeof_long), + "long unsigned int": CType(type_="long unsigned int", size=self.sizeof_long), + "long long": CType(type_="long long", size=self.sizeof_long), + "signed long long": CType(type_="signed long long", size=self.sizeof_long), + "unsigned long long": CType(type_="unsigned long long", size=self.sizeof_long), + "long long int": CType(type_="long long int", size=self.sizeof_long), + "signed long long int": CType(type_="signed long long int", size=self.sizeof_long), + "unsigned long long int": CType(type_="unsigned long long int", size=self.sizeof_long), + "__int128": CType(type_="__int128", size=16), + "unsigned __int128": CType(type_="unsigned __int128", size=16), + "__int256": CType(type_="__int256", size=32), + "unsigned __int256": CType(type_="unsigned __int256", size=32), + "bool": CType(type_="bool", size=self.sizeof_bool), + "_Bool": CType(type_="_Bool", size=self.sizeof_bool), + "float": CType(type_="float", size=self.sizeof_float), + "double": CType(type_="double", size=self.sizeof_double), + "long double": CType(type_="double", size=self.sizeof_double), + "void": CType(type_="void", size=self.sizeof_ptr), + } + self.ALL_TYPES.update(self.BASIC_TYPES) + + self.STDINT_TYPES = { + "int8_t": CType(type_="int8_t", size=1), + "uint8_t": CType(type_="uint8_t", size=1), + "byte": CType(type_="byte", size=1), + "int16_t": CType(type_="int16_t", size=2), + "uint16_t": CType(type_="uint16_t", size=2), + "word": CType(type_="word", size=2), + "int32_t": CType(type_="int32_t", size=4), + "uint32_t": CType(type_="uint32_t", size=4), + "dword": CType(type_="dword", size=4), + "int64_t": CType(type_="int64_t", size=8), + "uint64_t": CType(type_="uint64_t", size=8), + } + self.ALL_TYPES.update(self.STDINT_TYPES) + self.ALL_TYPES.update(self.extra_types) + + for name, ctype in self.ALL_TYPES.items(): + if name.startswith("unsigned"): + self.SIZE_TO_TYPES[ctype.size] = ctype + + def extract_type_name(self, type_str: str) -> str | None: + """ + Normalizes types that may come in as declarations, removing any extraneous information that may be present + and just getting the name of that type. + + In the case of: + "char *" that would return "char *" + "struct foo *" that would return "foo *" + "typedef int my_type" that would return "my_type" + """ + # XXX: this could be subverted by a type name that contains a scope or external ref ("extern") + is_defined = any(type_str.strip().startswith(t) for t in ["struct", "enum", "union", "typedef"]) + if not is_defined: + return type_str + + parsable_type = type_str.replace(";", "").strip() + type_name = None + try: + ast = self._type_parser_singleton.parse(text=parsable_type + ";") + if ast.ext: + type_name = ast.ext[0].name + except ParseError: + pass + + if type_name is None: + # do a hackish parse to get the type name, which may be inside a defined type-in-place + # remove "struct", "enum", "union", and "typedef" keywords, select the final type + if any(parsable_type.startswith(t) for t in ["struct", "enum", "union", "typedef"]): + final_type = type_str.split(" ")[-1] + final_type = final_type.replace(";", "").strip() + if " " not in final_type: + # final sanity check that it really is just a name + type_name = final_type + + return type_name + + def parse_type(self, defn, predefined_types=None, arch=None) -> Optional[CType]: # pylint:disable=unused-argument + """ + Parse a simple type expression into a SimType + + >>> self.parse_type('int *') + """ + return self.parse_type_with_name(defn, predefined_types=predefined_types, arch=arch)[0] + + def parse_type_with_name(self, defn, predefined_types=None, arch=None): # pylint:disable=unused-argument + """ + Parse a simple type expression into a SimType, returning the a tuple of the type object and any associated name + that might be found in the place a name would go in a type declaration. + + >>> self.parse_type_with_name('int *foo') + """ + if not defn: + return None + + if pycparser is None: + raise ImportError("Please install pycparser in order to parse C definitions") + + defn = re.sub(r"/\*.*?\*/", r"", defn, flags=re.DOTALL) + defn = re.sub(r"//.*?$", r"", defn, flags=re.MULTILINE) + + failed_parse = False + try: + node = self._type_parser_singleton.parse_type_with_name(text=defn) + except ParseError: + failed_parse = True + + # + # in the event of a failed type parse it may just be a custom type, so we should try again + # with the struct specifier and see if it works out + # + if failed_parse: + try: + node = self._type_parser_singleton.parse_type_with_name(text="struct " + defn) + except Exception: + return (None, ) + + if not isinstance(node, pycparser.c_ast.Typename) and \ + not isinstance(node, pycparser.c_ast.Decl): + raise pycparser.c_parser.ParseError("Got an unexpected type out of pycparser") + + decl = node.type + extra_types = {} if not predefined_types else dict(predefined_types) + return self._decl_to_type(decl, extra_types=extra_types), node.name + + def _decl_to_type(self, decl, extra_types=None) -> Optional[CType]: + if not decl: + return decl + + if extra_types is None: extra_types = {} + + if isinstance(decl, pycparser.c_ast.FuncDecl): + return None + + elif isinstance(decl, pycparser.c_ast.TypeDecl): + return self._decl_to_type(decl.type, extra_types) + + elif isinstance(decl, pycparser.c_ast.Typedef): + return self._decl_to_type(decl.type, extra_types) + + elif isinstance(decl, pycparser.c_ast.PtrDecl): + pts_to = self._decl_to_type(decl.type, extra_types) + if not pts_to: + return None + + return CType(type_=pts_to.type, size=self.sizeof_ptr, is_ptr=True, is_unknown=pts_to.is_unknown) + + elif isinstance(decl, pycparser.c_ast.ArrayDecl): + elem_type = self._decl_to_type(decl.type, extra_types) + + if decl.dim is None: + """ + r = SimTypeArray(elem_type) + r._arch = arch + return r + """ + return CType(type_=elem_type, is_array=True, size=0) + try: + size = self._parse_const(decl.dim, extra_types=extra_types) + except ValueError as e: + #l.warning("Got error parsing array dimension, defaulting to zero: %s", e) + size = 0 + """ + r = SimTypeFixedSizeArray(elem_type, size) + r._arch = arch + """ + return CType(type_=elem_type, is_array=True, size=size) + + elif isinstance(decl, pycparser.c_ast.Struct): + if decl is None: + return None + + return CType(type_=decl.name, is_unknown=True) + + elif isinstance(decl, pycparser.c_ast.Union): + return None + + elif isinstance(decl, pycparser.c_ast.IdentifierType): + key = ' '.join(decl.names) + if key in extra_types: + return extra_types[key] + elif key in self.ALL_TYPES: + return self.ALL_TYPES[key] + else: + #raise TypeError("Unknown type '%s'" % key) + return CType(type_=key, is_unknown=True) + + elif isinstance(decl, pycparser.c_ast.Enum): + # See C99 at 6.7.2.2 + return self.ALL_TYPES['int'] + + raise ValueError("Unknown type!") + + def _make_scope(self, predefined_types=None): + """ + Generate CParser scope_stack argument to parse method + """ + all_types = ChainMap(predefined_types or {}, self.ALL_TYPES) + scope = dict() + for ty in all_types: + if ty in self.BASIC_TYPES: + continue + if ' ' in ty: + continue + + typ = all_types[ty] + scope[ty] = True + return [scope] + + def _parse_const(self, c, extra_types=None): + if type(c) is pycparser.c_ast.Constant: + return int(c.value, base=0) + elif type(c) is pycparser.c_ast.BinaryOp: + if c.op == '+': + return self._parse_const(c.children()[0][1], extra_types=extra_types) + self._parse_const( + c.children()[1][1], extra_types=extra_types) + if c.op == '-': + return self._parse_const(c.children()[0][1], extra_types=extra_types) - self._parse_const( + c.children()[1][1], extra_types=extra_types) + if c.op == '*': + return self._parse_const(c.children()[0][1], extra_types=extra_types) * self._parse_const( + c.children()[1][1], extra_types=extra_types) + if c.op == '/': + return self._parse_const(c.children()[0][1], extra_types=extra_types) // self._parse_const( + c.children()[1][1], extra_types=extra_types) + if c.op == '<<': + return self._parse_const(c.children()[0][1], extra_types=extra_types) << self._parse_const( + c.children()[1][1], extra_types=extra_types) + if c.op == '>>': + return self._parse_const(c.children()[0][1], extra_types=extra_types) >> self._parse_const( + c.children()[1][1], extra_types=extra_types) + raise ValueError('Binary op %s' % c.op) + elif type(c) is pycparser.c_ast.UnaryOp: + if c.op == 'sizeof': + return self._decl_to_type(c.expr.type, extra_types=extra_types).size + else: + raise ValueError("Unary op %s" % c.op) + elif type(c) is pycparser.c_ast.Cast: + return self._parse_const(c.expr, extra_types=extra_types) + else: + raise ValueError(c) + + def size_to_type(self, size: int) -> CType: + if not size: + raise ValueError("A type size must be greater than 0") + + ctype = self.SIZE_TO_TYPES.get(size, None) + if ctype is None: + # one of two possible things have happend here: + # 1. this is a type that is larger than the simple types, in which case it's an array + # 2. this is a type with a non-aligned size, in which case we default to an array as well + array_size = size // self.sizeof_char + ctype = self.parse_type(f"char[{array_size}]") + if ctype is None: + raise RuntimeError(f"Failed to create a CType of array size {array_size}") + + return ctype \ No newline at end of file diff --git a/declib/api/utils.py b/declib/api/utils.py new file mode 100644 index 00000000..3fda0bd6 --- /dev/null +++ b/declib/api/utils.py @@ -0,0 +1,31 @@ +import math + +import tqdm + + +def progress_bar(items, gui=True, desc="Progressing..."): + if gui: + from declib.ui.utils import QProgressBarDialog + pbar = QProgressBarDialog(label_text=desc) + pbar.show() + callback_stub = pbar.update_progress + else: + t = tqdm.tqdm(desc=desc) + callback_stub = t.update + + bucket_size = len(items) / 100.0 + if bucket_size < 1: + callback_amt = int(1 / (bucket_size)) + bucket_size = 1 + else: + callback_amt = 1 + bucket_size = math.ceil(bucket_size) + + for i, item in enumerate(items): + yield item + if i % bucket_size == 0: + callback_stub(callback_amt) + + if gui: + # close the progress bar since it may not hit 100% + pbar.close() diff --git a/declib/artifacts/__init__.py b/declib/artifacts/__init__.py new file mode 100644 index 00000000..dc96bc2a --- /dev/null +++ b/declib/artifacts/__init__.py @@ -0,0 +1,93 @@ +import json + +import toml + +from .formatting import TomlHexEncoder, ArtifactFormat +from .artifact import Artifact +from .comment import Comment +from .decompilation import Decompilation +from .enum import Enum +from .func import Function, FunctionHeader, FunctionArgument +from .global_variable import GlobalVariable +from .patch import Patch +from .segment import Segment +from .stack_variable import StackVariable +from .struct import Struct, StructMember +from .context import Context +from .typedef import Typedef + +ART_NAME_TO_CLS = { + Function.__name__: Function, + FunctionHeader.__name__: FunctionHeader, + FunctionArgument.__name__: FunctionArgument, + StackVariable.__name__: StackVariable, + Comment.__name__: Comment, + GlobalVariable.__name__: GlobalVariable, + Enum.__name__: Enum, + Struct.__name__: Struct, + StructMember.__name__: StructMember, + Patch.__name__: Patch, + Decompilation.__name__: Decompilation, + Context.__name__: Context, + Typedef.__name__: Typedef, + Segment.__name__: Segment, +} + +ALL_ARTIFACTS = list(ART_NAME_TO_CLS.values()) + + +def _dict_from_str(art_str: str, fmt=ArtifactFormat.TOML) -> dict: + if fmt == ArtifactFormat.TOML: + return toml.loads(art_str) + elif fmt == ArtifactFormat.JSON: + return json.loads(art_str) + else: + raise ValueError(f"Loading from format {fmt} is not yet supported.") + + +def _art_from_dict(art_dict: dict) -> Artifact: + art_type_str = art_dict.get(Artifact.ART_TYPE_STR, None) + if art_type_str is None: + raise ValueError(f"Artifact type string not found in artifact data: {art_dict}. Is this a valid artifact?") + + art_cls = ART_NAME_TO_CLS[art_type_str] + art = art_cls() + art.__setstate__(art_dict) + return art + + +def _load_arts_from_list(art_strs: list[str], fmt=ArtifactFormat.TOML) -> list[Artifact]: + arts = [] + for art_str in art_strs: + data_dict = _dict_from_str(art_str, fmt=fmt) + art = _art_from_dict(data_dict) + arts.append(art) + return arts + + +def _load_arts_from_string(art_str: str, fmt=ArtifactFormat.TOML) -> list[Artifact]: + data_dict = _dict_from_str(art_str, fmt=fmt) + if isinstance(data_dict, dict): + data_dicts = list(data_dict.values()) + elif isinstance(data_dict, list): + data_dicts = data_dict + else: + raise ValueError(f"Unexpected data type: {type(data_dict)}") + + arts = [] + for v in data_dicts: + art = _art_from_dict(v) + arts.append(art) + + return arts + + +def load_many_artifacts(art_strings: list[str], fmt=ArtifactFormat.TOML) -> list[Artifact]: + """ + A helper function to load many dumped artifacts from a list of strings. Each string should have been dumped + using the `dumps` method of an artifact. + + :param art_strings: A list of strings or a single string containing multiple dumped artifacts. + :param fmt: The format of the dumped artifacts. + """ + return _load_arts_from_list(art_strings, fmt=fmt) diff --git a/declib/artifacts/artifact.py b/declib/artifacts/artifact.py new file mode 100644 index 00000000..dca48ac3 --- /dev/null +++ b/declib/artifacts/artifact.py @@ -0,0 +1,311 @@ +import json +from typing import Dict, Optional, List +import datetime + +import toml + +from .formatting import ArtifactFormat, TomlHexEncoder + +from toml.tz import TomlTz + +class Artifact: + """ + The Artifact class acts as the base for all other artifacts that can be produced by a decompiler (or decompiler + adjacent tool). In general, the comparisons of these derived classes should only be done on the attributes in + __slots__, except for the last_change property. + """ + LST_CHNG_ATTR = "last_change" + ADDR_ATTR = "addr" + ART_TYPE_STR = "artifact_type" + SCOPE_ATTR = "scope" + + ATTR_ATTR_IGNORE_SET = "_attr_ignore_set" + __slots__ = ( + LST_CHNG_ATTR, + ATTR_ATTR_IGNORE_SET, + SCOPE_ATTR + ) + + def __init__(self, last_change: Optional[datetime.datetime] = None, scope: Optional[str] = None): + self.last_change = last_change + self.scope = scope + self._attr_ignore_set = set() + + @staticmethod + def _normalize_datetime(dt): + """ + Convert TomlTz datetime objects to standard Python datetime objects. + TomlTz objects from TOML deserialization don't pickle correctly. + """ + if not isinstance(dt, datetime.datetime): + return dt + + # If the datetime has a TomlTz tzinfo, convert it to standard timezone + if dt.tzinfo is not None and isinstance(dt.tzinfo, TomlTz): + # Get the offset and convert to standard timezone + offset = dt.utcoffset() + if offset is not None: + std_tz = datetime.timezone(offset) + # Replace the TomlTz with standard timezone + return dt.replace(tzinfo=std_tz) + + return dt + + def __getstate__(self) -> Dict: + state = {} + for k in self.slots: + value = getattr(self, k) + # Normalize datetime objects to ensure they pickle correctly + if isinstance(value, datetime.datetime): + value = self._normalize_datetime(value) + state[k] = value + return state + + def __setstate__(self, state): + # When pickle calls __setstate__, __init__ is never called, so we need to + # initialize _attr_ignore_set before accessing self.slots (which uses it) + if not hasattr(self, '_attr_ignore_set'): + self._attr_ignore_set = set() + + for k in self.slots: + if k in state: + setattr(self, k, state[k]) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + + for k in self.slots: + if k == self.LST_CHNG_ATTR: + continue + elif k == self.SCOPE_ATTR: + # special case scopes: a scope of None indicates that the artifact could be in any scope + this_scope = getattr(self, k) + that_scope = getattr(other, k) + if this_scope is None or that_scope is None: + continue + + if getattr(self, k) != getattr(other, k): + return False + + return True + + def __hash__(self): + long_str = "" + for attr in self.slots: + long_str += str(getattr(self, attr)) + + return hash(long_str) + + def __repr__(self): + return self.__str__() + + @property + def slots(self): + return [s for s in self.__slots__ if s != self.ATTR_ATTR_IGNORE_SET and s not in self._attr_ignore_set] + + def copy(self) -> "Artifact": + new_obj = self.__class__() + for attr in self.slots: + attr_v = getattr(self, attr) + if isinstance(attr_v, list): + new_list = [] + for lobj in attr_v: + if hasattr(lobj, "copy"): + new_list.append(lobj.copy()) + setattr(new_obj, attr, new_list) + elif isinstance(attr_v, dict): + new_dict = {} + for dk, dv in attr_v.items(): + new_dk = dk.copy() if hasattr(dk, "copy") else dk + new_dv = dv.copy() if hasattr(dv, "copy") else dv + new_dict[new_dk] = new_dv + setattr(new_obj, attr, new_dict) + elif isinstance(attr_v, Artifact): + setattr(new_obj, attr, attr_v.copy()) + else: + setattr(new_obj, attr, attr_v) + + return new_obj + + # + # Serialization + # + + def _to_c_string(self): + raise NotImplementedError + + @classmethod + def _from_c_string(cls, cstring) -> Dict: + raise NotImplementedError + + def dumps(self, fmt=ArtifactFormat.TOML) -> str: + dict_data = self.__getstate__() + # encode the artifact type only for JSON format + if fmt == ArtifactFormat.JSON: + dict_data.update({self.ART_TYPE_STR: self.__class__.__name__}) + + if fmt == ArtifactFormat.TOML: + return toml.dumps(dict_data, encoder=TomlHexEncoder()) + elif fmt == ArtifactFormat.JSON: + return json.dumps(dict_data) + elif fmt == ArtifactFormat.C_LANG: + return self._to_c_string() + else: + raise ValueError(f"Dumping to format {fmt} is not yet supported.") + + def dump(self, fp, fmt=ArtifactFormat.TOML): + data = self.dumps(fmt=fmt) + fp.write(data) + + @classmethod + def loads(cls, string, fmt=ArtifactFormat.TOML) -> "Artifact": + if fmt == ArtifactFormat.TOML: + dict_data = toml.loads(string) + elif fmt == ArtifactFormat.JSON: + dict_data = json.loads(string) + elif fmt == ArtifactFormat.C_LANG: + dict_data = cls._from_c_string(string) + else: + raise ValueError(f"Loading from format {fmt} is not yet supported.") + + # remove the artifact type (if it exists) + dict_data.pop(Artifact.ART_TYPE_STR, None) + art = cls() + art.__setstate__(dict_data) + return art + + @classmethod + def load(cls, fp, fmt=ArtifactFormat.TOML): + data = fp.read() + return cls.loads(data, fmt=fmt) + + @classmethod + def dumps_many(cls, artifacts: List["Artifact"], key_attr=ADDR_ATTR, fmt=ArtifactFormat.TOML) -> str: + artifacts_dict = {} + for art in artifacts: + k = getattr(art, key_attr) + if isinstance(k, int): + k = hex(k) + + artifacts_dict[k] = art.__getstate__() + + if fmt == ArtifactFormat.TOML: + return toml.dumps(artifacts_dict, encoder=TomlHexEncoder()) + elif fmt == ArtifactFormat.JSON: + return json.dumps(artifacts_dict) + else: + raise ValueError(f"Dumping many to format {fmt} is not yet supported.") + + @classmethod + def loads_many(cls, string: str, fmt=ArtifactFormat.TOML) -> List["Artifact"]: + if fmt == ArtifactFormat.TOML: + dict_data = toml.loads(string) + elif fmt == ArtifactFormat.JSON: + dict_data = json.loads(string) + else: + raise ValueError(f"Loading many from format {fmt} is not yet supported.") + + arts = [] + for _, v in dict_data.items(): + art = cls() + art.__setstate__(v) + arts.append(art) + + return arts + + # + # Public API + # + + @property + def scoped_name(self) -> str: + """ + Returns the name of the artifact with its scope, if it has one. + """ + if hasattr(self, "name"): + if self.scope: + return f"{self.scope}::{self.name}" + + return self.name + return "" + + @property + def commit_msg(self) -> str: + return f"Updated {self}" + + def diff(self, other, **kwargs) -> Dict: + diff_dict = {} + if not isinstance(other, self.__class__): + for k in self.slots: + if k == self.LST_CHNG_ATTR: + continue + + diff_dict[k] = { + "before": getattr(self, k), + "after": None + } + return diff_dict + + for k in self.slots: + self_attr, other_attr = getattr(self, k), getattr(other, k) + if self_attr != other_attr: + if k == self.LST_CHNG_ATTR: + continue + + diff_dict[k] = { + "before": self_attr, + "after": other_attr + } + return diff_dict + + @classmethod + def invert_diff(cls, diff_dict: Dict): + inverted_diff = {} + for k, v in diff_dict.items(): + if k == "before": + inverted_diff["after"] = v + elif k == "after": + inverted_diff["before"] = v + elif isinstance(v, Dict): + inverted_diff[k] = cls.invert_diff(v) + else: + inverted_diff[k] = v + + return inverted_diff + + def reset_last_change(self): + """ + Resets the change time of the Artifact. + In subclasses, this should also reset all artifacts with nested artifacts + """ + self.last_change = None + + def overwrite_merge(self, obj2: "Artifact", **kwargs): + """ + This function should really be overwritten by its subclass + """ + merge_obj = self.copy() + if not obj2 or merge_obj == obj2: + return merge_obj + + for attr in self.slots: + a2 = getattr(obj2, attr) + if a2 is not None: + setattr(merge_obj, attr, a2) + + return merge_obj + + def nonconflict_merge(self, obj2: "Artifact", **kwargs): + obj1 = self.copy() + if not obj2 or obj1 == obj2: + return obj1 + + obj_diff = obj1.diff(obj2) + merge_obj = obj1.copy() + + for attr in self.slots: + if attr in obj_diff and obj_diff[attr]["before"] is None: + setattr(merge_obj, attr, getattr(obj2, attr)) + + return merge_obj diff --git a/declib/artifacts/comment.py b/declib/artifacts/comment.py new file mode 100644 index 00000000..857be55d --- /dev/null +++ b/declib/artifacts/comment.py @@ -0,0 +1,49 @@ +import textwrap +from typing import Optional + +from .artifact import Artifact + + +class Comment(Artifact): + __slots__ = Artifact.__slots__ + ( + "addr", + "func_addr", + "comment", + "decompiled", + ) + + def __init__( + self, + addr: int = None, + comment: Optional[str] = None, + func_addr: int = None, + decompiled: bool = False, + **kwargs + ): + super().__init__(**kwargs) + self.addr = addr + self.comment = self.linewrap_comment(comment) if comment else None + self.func_addr = func_addr + self.decompiled = decompiled + + def __str__(self): + cmt_len = len(self.comment) if self.comment else 0 + return f"" + + @staticmethod + def linewrap_comment(comment: str, width=100) -> str: + # Split the comment into lines based on existing newlines + lines = comment.split('\n') + # Wrap each line individually and preserve newlines + wrapped_lines = [textwrap.fill(line, width=width) for line in lines] + # Join the wrapped lines with newline characters + wrapped_text = '\n'.join(wrapped_lines) + return wrapped_text + + def nonconflict_merge(self, obj2: "Comment", **kwargs) -> "Comment": + obj1: "Comment" = self.copy() + if not obj2 or obj1 == obj2: + return obj1 + + merge_comment = obj1 + return merge_comment diff --git a/declib/artifacts/context.py b/declib/artifacts/context.py new file mode 100644 index 00000000..f497059e --- /dev/null +++ b/declib/artifacts/context.py @@ -0,0 +1,61 @@ +from typing import Optional + +from .artifact import Artifact + + +class Context(Artifact): + ACT_VIEW_OPEN = "view_open" + ACT_MOUSE_CLICK = "mouse_click" + ACT_MOUSE_MOVE = "mouse_move" + ACT_UNKNOWN = "unknown" + + __slots__ = Artifact.__slots__ + ( + "addr", + "func_addr", + "line_number", + "col_number", + "screen_name", + "variable", + "action", + "extras", + ) + + def __init__( + self, + addr: Optional[int] = None, + func_addr: Optional[int] = None, + line_number: Optional[int] = None, + col_number: Optional[int] = None, + screen_name: Optional[str] = None, + variable: Optional[str] = None, + action: Optional[str] = None, + extras: Optional[dict] = None, + **kwargs + ): + self.addr = addr + self.func_addr = func_addr + self.line_number = line_number + self.col_number = col_number + self.screen_name = screen_name + self.variable = variable + self.action: str = action or self.ACT_UNKNOWN + self.extras = extras or {} + super().__init__(**kwargs) + + def __str__(self): + post_text = f" screen={self.screen_name}" if self.screen_name else "" + post_text += f" var={self.variable}" if self.variable else "" + if self.func_addr is not None: + post_text = f"@{hex(self.func_addr)}" + post_text + if self.addr is not None: + post_text = hex(self.addr) + post_text + if self.line_number is not None: + post_text += f" line={self.line_number}" + if self.col_number is not None: + post_text += f" col={self.col_number}" + if self.action != self.ACT_UNKNOWN: + post_text += f" action={self.action}" + if self.extras: + post_text += f" extras={self.extras}" + + return f"" diff --git a/declib/artifacts/decompilation.py b/declib/artifacts/decompilation.py new file mode 100644 index 00000000..88d1e1ca --- /dev/null +++ b/declib/artifacts/decompilation.py @@ -0,0 +1,35 @@ +import toml + +from .artifact import Artifact + + +class Decompilation(Artifact): + __slots__ = Artifact.__slots__ + ( + "addr", + "text", + "line_map", + "decompiler", + "bs_func", + ) + + def __init__( + self, + addr: int = None, + text: str = None, + line_map: dict = None, + decompiler: str = None, + bs_func = None, + **kwargs + ): + super().__init__(**kwargs) + self.addr = addr + self.text = text + self.line_map = line_map or {} + self.decompiler = decompiler + self.bs_func = bs_func + + def __str__(self): + return f"//ADDR: {hex(self.addr)}\n// SOURCE: {self.decompiler}\n{self.text}" + + def __repr__(self): + return f"" diff --git a/declib/artifacts/enum.py b/declib/artifacts/enum.py new file mode 100644 index 00000000..c47944b8 --- /dev/null +++ b/declib/artifacts/enum.py @@ -0,0 +1,53 @@ +from collections import OrderedDict +from typing import Dict + +from .artifact import Artifact + + +class Enum(Artifact): + __slots__ = Artifact.__slots__ + ( + "name", + "members", + ) + + def __init__( + self, + name: str = None, + members: Dict[str, int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.name = name + # sorts map by the int value + self.members = self._order_members(members) if members else {} + + def __str__(self): + scope_str = f" scope={self.scope}" if self.scope else "" + return f"" + + @staticmethod + def _order_members(members): + return OrderedDict(sorted(members.items(), key=lambda kv: kv[1])) + + def nonconflict_merge(self, enum2: "Enum", **kwargs): + enum1: Enum = self.copy() + if not enum2 or enum1 == enum2: + return enum1.copy() + + master_state = kwargs.get("master_state", None) + local_names = {mem for mem in enum1.members} + if master_state: + for _, enum in master_state.get_enums().items(): + local_names.union(set(enum.members.keys())) + else: + local_names = enum1.members + + constants = { + value for value in enum1.members.values() + } + + for name, constant in enum2.members.items(): + if name in local_names or constant in constants: + continue + enum1.members[name] = constant + return enum1 diff --git a/declib/artifacts/formatting.py b/declib/artifacts/formatting.py new file mode 100644 index 00000000..ece4adba --- /dev/null +++ b/declib/artifacts/formatting.py @@ -0,0 +1,27 @@ +import typing + +from toml import TomlEncoder + +if typing.TYPE_CHECKING: + from ..api import CTypeParser, CType + + +class ArtifactFormat: + TOML = "toml" + JSON = "json" + C_LANG = "c" + + +class TomlHexEncoder(TomlEncoder): + def __init__(self, _dict=dict, preserve=False): + super(TomlHexEncoder, self).__init__(_dict, preserve=preserve) + self.dump_funcs[int] = lambda v: hex(v) if v >= 0 else v + + +def ctype_from_size(size, type_parser: typing.Optional["CTypeParser"] = None) -> "CType": + if type_parser is None: + from ..api.type_parser import CTypeParser + type_parser = CTypeParser() + + ctype = type_parser.size_to_type(size) + return ctype diff --git a/declib/artifacts/func.py b/declib/artifacts/func.py new file mode 100644 index 00000000..4566235c --- /dev/null +++ b/declib/artifacts/func.py @@ -0,0 +1,433 @@ +from typing import Dict, Optional + +from .artifact import Artifact +from .stack_variable import StackVariable + + +# +# Function Header Classes +# + +class FunctionArgument(Artifact): + __slots__ = Artifact.__slots__ + ( + "offset", + "name", + "type", + "size", + ) + + def __init__( + self, + offset: int = None, + name: str = None, + type_: str = None, + size: int = None, + **kwargs + ): + super().__init__(**kwargs) + self.offset = offset + self.name = name + self.type = type_ + self.size = size + + def __str__(self): + return f"" + + +class FunctionHeader(Artifact): + __slots__ = Artifact.__slots__ + ( + "name", + "addr", + "type", + "args" + ) + + def __init__( + self, + name: str = None, + addr: int = None, + type_: str = None, + args: Optional[Dict[int, FunctionArgument]] = None, + **kwargs + ): + super().__init__(**kwargs) + self.name = name + self.addr = addr + self.type = type_ + self.args: Dict = args or {} + + def __str__(self): + return f"" + + def __getstate__(self): + data_dict = super().__getstate__() + args_dict = data_dict["args"] + if args_dict is None: + return data_dict + + new_args_dict = {hex(k): v.__getstate__() for k, v in args_dict.items()} + data_dict["args"] = new_args_dict + return data_dict + + def __setstate__(self, state): + # Pop nested object data and reconstruct in local variable + args_dict = state.pop("args", {}) + new_args_dict = {} + for k, v in args_dict.items(): + fa = FunctionArgument() + fa.__setstate__(v) + new_args_dict[int(k, 0)] = fa + + # Put reconstructed objects back in state + state["args"] = new_args_dict + + # Let super set all attributes at once + super().__setstate__(state) + + def diff(self, other, **kwargs) -> Dict: + diff_dict = {} + # early exit if the two do not match type + if not isinstance(other, FunctionHeader): + for k in ["name", "addr", "type"]: + diff_dict[k] = { + "before": getattr(self, k), + "after": None + } + + diff_dict["args"] = {idx: arg.diff(None) for idx, arg in self.args.items()} + return diff_dict + + # metadata + for k in ["name", "addr", "type"]: + if getattr(self, k) == getattr(other, k): + continue + + diff_dict[k] = { + "before": getattr(self, k), + "after": getattr(other, k) + } + + # args + diff_dict["args"] = {} + for idx, self_arg in self.args.items(): + try: + other_arg = other.args[idx] + except KeyError: + other_arg = None + + diff_dict["args"][idx] = self_arg.diff(other_arg) + + for idx, other_arg in other.args.items(): + if idx in diff_dict["args"]: + continue + + diff_dict["args"][idx] = self.invert_diff(other_arg.diff(None)) + + return diff_dict + + def reset_last_change(self): + if self.args: + for arg in self.args.values(): + arg.reset_last_change() + + def overwrite_merge(self, obj2: "Artifact", **kwargs): + fh2: "FunctionHeader" = obj2 + merged_fh: "FunctionHeader" = self.copy() + if not fh2 or not isinstance(fh2, FunctionHeader) or self == fh2: + return merged_fh + + if fh2.name is not None: + merged_fh.name = fh2.name + if fh2.type is not None: + merged_fh.type = fh2.type + + # header args + for off, var in fh2.args.items(): + if var is not None: + if off in merged_fh.args: + merged_var = merged_fh.args[off].overwrite_merge(var) + else: + merged_var = var + + merged_fh.args[off] = merged_var + + return merged_fh + + def nonconflict_merge(self, fh2: "FunctionHeader", **kwargs): + fh1: "FunctionHeader" = self.copy() + if not fh2 or not isinstance(fh2, FunctionHeader): + return fh1 + + if fh1.name is None: + fh1.name = fh2.name + + if fh1.type is None: + fh1.type = fh2.type + + # header args + for off, var in fh2.args.items(): + merge_var: FunctionArgument = fh1.args[off].copy() if off in fh1.args else var + merge_var = merge_var.nonconflict_merge(var) + fh1.args[off] = merge_var + + return fh1 + + +# +# Full Function Class +# + +class Function(Artifact): + """ + The Function class describes a Function found a decompiler. There are three components to a function: + 1. Metadata + 2. Header + 3. Stack Vars + + The metadata contains info on changes and size. The header holds the return type, + and arguments (including their types). The stack vars contain StackVariables. + """ + + __slots__ = Artifact.__slots__ + ( + "addr", + "size", + "header", + "stack_vars", + "dec_obj", + ) + + def __init__( + self, + addr: int = None, + size: int = None, + header: Optional[FunctionHeader] = None, + stack_vars: Optional[Dict[int, StackVariable]] = None, + dec_obj: Optional[object] = None, + name: str = None, + **kwargs + ): + super().__init__(**kwargs) + # never use dec_obj for comparison, dumping, etc. + self._attr_ignore_set.add("dec_obj") + + self.addr = addr + self.size = size + self.header = header + if name is not None: + self.name = name + self.stack_vars: Dict[int, StackVariable] = stack_vars or {} + + # a special property which can only be set while running inside the decompiler. + # contains a reference to the decompiler object associated with this function. + self.dec_obj = dec_obj + + def __str__(self): + if self.header: + return f"" + + return f"" + + def __getstate__(self): + header = self.header.__getstate__() if self.header else None + stack_vars = { + hex(offset): stack_var.__getstate__() for offset, stack_var in self.stack_vars.items() + } if self.stack_vars else None + + state = super().__getstate__() + # give alias for name and type for convenience + state["name"] = self.name + state["type"] = self.type + state["header"] = header + state["stack_vars"] = stack_vars + return state + + def __setstate__(self, state): + # When pickle calls __setstate__, __init__ is never called + # Initialize _attr_ignore_set and add dec_obj to it (as done in __init__) + if not hasattr(self, '_attr_ignore_set'): + self._attr_ignore_set = set() + self._attr_ignore_set.add("dec_obj") + + # XXX: this is a backport of the old state format. Remove this after a few releases. + if "metadata" in state: + metadata: Dict = state.pop("metadata") + metadata.update(state) + state = metadata + + # Pop nested object data and reconstruct in local variables + header_dat = state.pop("header", None) + if header_dat: + header = FunctionHeader() + header.__setstate__(header_dat) + else: + header = None + + # Handle name/type aliases that override header values + # We modify the header object directly instead of using property setters + # to avoid accessing self.header and self.addr before they're initialized + name_override = state.pop("name", None) + type_override = state.pop("type", None) + + if name_override is not None and header is not None: + header.name = name_override + if type_override is not None and header is not None: + header.type = type_override + + stack_vars_dat = state.pop("stack_vars", {}) + stack_vars = {} + if stack_vars_dat: + for off, stack_var in stack_vars_dat.items(): + sv = StackVariable() + sv.__setstate__(stack_var) + stack_vars[int(off, 0)] = sv + + # Put reconstructed objects back in state + state["header"] = header + state["stack_vars"] = stack_vars + + # Let super set all attributes at once + super().__setstate__(state) + + # dec_obj is intentionally excluded from serialization (it's in + # _attr_ignore_set), so it is never present in `state`. Because Function + # uses __slots__, the attribute would otherwise be unset after a + # deserialization round-trip and any access to `func.dec_obj` (e.g. in + # get_dependencies or a backend's rename path) would raise AttributeError. + if not hasattr(self, "dec_obj"): + self.dec_obj = None + + def diff(self, other, **kwargs) -> Dict: + diff_dict = {} + if not isinstance(other, Function): + # metadata + for k in ["addr", "size"]: + diff_dict[k] = { + "before": getattr(self, k), + "after": None + } + + # header + diff_dict["header"] = self.header.diff(other.header) + # args + diff_dict["stack_vars"] = {off: var.diff(None) for off, var in self.stack_vars.items()} + return diff_dict + + # metadata + for k in ["addr", "size"]: + if getattr(self, k) == getattr(other, k): + continue + + diff_dict[k] = { + "before": getattr(self, k), + "after": getattr(other, k) + } + + # header + if self.header: + diff_dict["header"] = self.header.diff(other.header) + elif other.header: + diff_dict["header"] = self.invert_diff(other.header.diff(None)) + else: + diff_dict["header"] = {"before": None, "after": None} + + # stack vars + diff_dict["stack_vars"] = {} + for off, self_var in self.stack_vars.items(): + try: + other_var = other.stack_vars[off] + except KeyError: + other_var = None + + diff_dict["stack_vars"][off] = self_var.diff(other_var) + + for off, other_var in other.stack_vars.items(): + if off in diff_dict["stack_vars"]: + continue + + diff_dict["stack_vars"][off] = self.invert_diff(other_var.diff(None)) + + return diff_dict + + def reset_last_change(self): + if self.header: + self.header.reset_last_change() + + if self.stack_vars: + for sv in self.stack_vars.values(): + sv.reset_last_change() + + def overwrite_merge(self, obj2: "Artifact", **kwargs): + func2: "Function" = obj2 + merged_func: "Function" = self.copy() + if not func2 or self == func2: + return merged_func + + if merged_func.header is None: + merged_func.header = func2.header.copy() if func2.header else None + + if merged_func.header: + merged_func.header = merged_func.header.overwrite_merge(func2.header) + + for off, var in func2.stack_vars.items(): + if var is not None: + if off in merged_func.stack_vars: + merged_var = merged_func.stack_vars[off].overwrite_merge(var) + else: + merged_var = var + + merged_func.stack_vars[off] = merged_var + + return merged_func + + def nonconflict_merge(self, func2: "Artifact", **kwargs): + func1: "Function" = self.copy() + + if not func2 or func1 == func2: + return func1 + + merge_func: "Function" = func1.copy() + + if merge_func.header is None: + merge_func.header = func2.header.copy() if func2.header else None + elif func2.header is not None: + merge_func.header = merge_func.header.nonconflict_merge(func2.header) + + # stack vars + for off, var in func2.stack_vars.items(): + merge_var = func1.stack_vars[off].copy() if off in func1.stack_vars else var + merge_var = StackVariable.nonconflict_merge(merge_var, var) + + merge_func.stack_vars[off] = merge_var + + return merge_func + + # + # Property Shortcuts (Alias) + # + + @property + def name(self): + return self.header.name if self.header else None + + @name.setter + def name(self, value): + # create a header if one does not exist for this function + if not self.header: + self.header = FunctionHeader(name=None, addr=self.addr) + self.header.name = value + + @property + def type(self): + return self.header.type if self.header else None + + @type.setter + def type(self, value): + # create a header if one does not exist for this function + if not self.header: + self.header = FunctionHeader(name=None, addr=self.addr) + self.header.type = value + + @property + def args(self): + return self.header.args if self.header else {} diff --git a/declib/artifacts/global_variable.py b/declib/artifacts/global_variable.py new file mode 100644 index 00000000..e08ae6e3 --- /dev/null +++ b/declib/artifacts/global_variable.py @@ -0,0 +1,31 @@ +from typing import Optional + +import toml + +from .artifact import Artifact + + +class GlobalVariable(Artifact): + __slots__ = Artifact.__slots__ + ( + "addr", + "name", + "type", + "size" + ) + + def __init__( + self, + addr: int = None, + name: str = None, + type_: Optional[str] = None, + size: int = None, + **kwargs + ): + super().__init__(**kwargs) + self.addr = addr + self.name = name + self.type = type_ + self.size = size + + def __str__(self): + return f"" diff --git a/declib/artifacts/patch.py b/declib/artifacts/patch.py new file mode 100644 index 00000000..3c8f8d95 --- /dev/null +++ b/declib/artifacts/patch.py @@ -0,0 +1,49 @@ +import codecs + +import toml + +from .artifact import Artifact + + +class Patch(Artifact): + """ + Describes a patch on the binary code. + """ + __slots__ = Artifact.__slots__ + ( + "addr", + "name", + "bytes", + ) + + def __init__( + self, + addr: int = None, + bytes_: bytes = None, + name: str = None, + **kwargs + ): + super(Patch, self).__init__(**kwargs) + self.addr = addr + self.name = name + self.bytes = bytes_ + + def __str__(self): + return f"" + + def __getstate__(self): + data_dict = super().__getstate__() + data_dict["bytes"] = codecs.encode(self.bytes, "hex").decode() + return data_dict + + def __setstate__(self, state): + # Pop and decode bytes data + bytes_dat = state.pop("bytes", None) + decoded_bytes = None + if bytes_dat: + decoded_bytes = codecs.decode(bytes_dat, "hex") + + # Put decoded bytes back in state + state["bytes"] = decoded_bytes + + # Let super set all attributes at once + super().__setstate__(state) diff --git a/declib/artifacts/segment.py b/declib/artifacts/segment.py new file mode 100644 index 00000000..024d4edf --- /dev/null +++ b/declib/artifacts/segment.py @@ -0,0 +1,37 @@ +from typing import Optional + +from .artifact import Artifact + + +class Segment(Artifact): + __slots__ = Artifact.__slots__ + ( + "name", + "start_addr", + "end_addr", + "permissions" + ) + + def __init__( + self, + name: str = None, + start_addr: int = None, + end_addr: int = None, + permissions: Optional[str] = None, + **kwargs + ): + super().__init__(**kwargs) + self.name = name + self.start_addr = start_addr + self.end_addr = end_addr + self.permissions = permissions + + def __str__(self): + perms_str = f" [{self.permissions}]" if self.permissions else "" + return f"" + + @property + def size(self) -> Optional[int]: + """Returns the size of the segment in bytes.""" + if self.start_addr is not None and self.end_addr is not None: + return self.end_addr - self.start_addr + return None \ No newline at end of file diff --git a/declib/artifacts/stack_variable.py b/declib/artifacts/stack_variable.py new file mode 100644 index 00000000..782bb169 --- /dev/null +++ b/declib/artifacts/stack_variable.py @@ -0,0 +1,50 @@ +import toml + +from .artifact import Artifact + + +class StackVariable(Artifact): + """ + Describes a stack variable for a given function. + """ + + __slots__ = Artifact.__slots__ + ( + "offset", + "name", + "type", + "size", + "addr", + ) + + def __init__( + self, + stack_offset: int = None, + name: str = None, + type_: str = None, + size: int = None, + addr: int = None, + **kwargs + ): + super().__init__(**kwargs) + self.offset = stack_offset + self.name = name + self.type = type_ + self.size = size + self.addr = addr + + def __str__(self): + return f"" + + @classmethod + def load_many(cls, svs_toml): + for sv_toml in svs_toml.values(): + sv = StackVariable(None, None, None, None, None) + sv.__setstate__(sv_toml) + yield sv + + @classmethod + def dump_many(cls, svs): + d = { } + for v in sorted(svs.values(), key=lambda x: x.addr): + d[hex(v.addr)] = v.__getstate__() + return d diff --git a/declib/artifacts/struct.py b/declib/artifacts/struct.py new file mode 100644 index 00000000..236a11d3 --- /dev/null +++ b/declib/artifacts/struct.py @@ -0,0 +1,184 @@ +from typing import Dict, List, Optional + +import toml + +from .artifact import Artifact +from . import TomlHexEncoder + +import logging +l = logging.getLogger(name=__name__) + + +class StructMember(Artifact): + """ + Describes a struct member that corresponds to a struct. + Offset is the byte offset of the member from the start of the struct. + """ + + __slots__ = Artifact.__slots__ + ( + "name", + "offset", + "type", + "size", + ) + + def __init__( + self, + name: str = None, + offset: int = None, + type_: Optional[str] = None, + size: int = None, + **kwargs + ): + super().__init__(**kwargs) + self.name: str = name + self.offset: int = offset + self.type: str = type_ + self.size: int = size + + def __str__(self): + return f"" + + +class Struct(Artifact): + """ + Describes a struct. + All members are stored by their byte offset from the start of the struct. + """ + + __slots__ = Artifact.__slots__ + ( + "name", + "size", + "members", + ) + + def __init__( + self, + name: str = None, + size: int = None, + members: Dict[int, StructMember] = None, + **kwargs + ): + super().__init__(**kwargs) + self.name = name + self.size = size or 0 + self.members: Dict[int, StructMember] = members or {} + + def __str__(self): + scope_str = f" scope={self.scope}" if self.scope else "" + return f"" + + def __getstate__(self): + data_dict = super().__getstate__() + data_dict["members"] = { + hex(offset): member.__getstate__() for offset, member in self.members.items() + } + + return data_dict + + def __setstate__(self, state): + # XXX: this is a backport of the old state format. Remove this after a few releases. + if "metadata" in state: + metadata: Dict = state.pop("metadata") + metadata.update(state) + state = metadata + + # Pop nested object data and reconstruct in local variable + members_dat = state.pop("members", None) + members = {} + if members_dat: + for off, member in members_dat.items(): + sm = StructMember() + sm.__setstate__(member) + members[int(off, 0)] = sm + + # Put reconstructed objects back in state + state["members"] = members + + # Let super set all attributes at once + super().__setstate__(state) + + def add_struct_member(self, mname, moff, mtype, size): + self.members[moff] = StructMember(mname, moff, mtype, size) + + def append_struct_member(self, mname, mtype, size): + # first, find the next available offset + next_offset = 0 + for off in self.members.keys(): + if off >= next_offset: + next_offset = off + self.members[off].size + self.members[next_offset] = StructMember(mname, next_offset, mtype, size) + + def diff(self, other, **kwargs) -> Dict: + diff_dict = {} + if not isinstance(other, Struct): + return diff_dict + + for k in ["name", "size"]: + if getattr(self, k) == getattr(other, k): + continue + + diff_dict[k] = { + "before": getattr(self, k), + "after": getattr(other, k) + } + + # struct members + diff_dict["members"] = {} + for off, member in self.members.items(): + try: + other_mem = other.members[off] + except KeyError: + other_mem = None + + diff_dict["members"][off] = member.diff(other_mem) + + for off, other_mem in other.members.items(): + if off in diff_dict["members"]: + continue + + diff_dict["members"][off] = self.invert_diff(other_mem.diff(None)) + + return diff_dict + + def nonconflict_merge(self, struct2: "Struct", **kwargs) -> "Struct": + struct1: "Struct" = self.copy() + if not struct2 or struct1 == struct2: + return struct1 + + struct_diff = struct1.diff(struct2) + merge_struct = struct1 + + members_diff = struct_diff["members"] + for off, mem in struct2.members.items(): + # no difference + if off not in members_diff: + continue + + mem_diff = members_diff[off] + + # struct member is newly created + if "before" in mem_diff and mem_diff["before"] is None: + # check for overlap + new_mem_size = mem.size + new_mem_offset = mem.offset + + for off_check in range(new_mem_offset, new_mem_offset + new_mem_size): + if off_check in merge_struct.members: + break + else: + merge_struct.members[off] = mem.copy() + + continue + + # member differs + merge_mem = merge_struct.members.get(off, None) + if not merge_mem: + merge_mem = mem + + merge_mem = StructMember.nonconflict_merge(merge_mem, mem) + merge_struct.members[off] = merge_mem + + # compute the new size + merge_struct.size = sum(mem.size for mem in merge_struct.members.values()) + return merge_struct diff --git a/declib/artifacts/typedef.py b/declib/artifacts/typedef.py new file mode 100644 index 00000000..e4e9a509 --- /dev/null +++ b/declib/artifacts/typedef.py @@ -0,0 +1,59 @@ +from typing import Optional + +from .artifact import Artifact + + +class Typedef(Artifact): + """ + Describe a typedef. As an example: + typedef struct MyStruct { + int a; + int b; + } my_struct_t; + + name="my_struct_t" + type="MyStruct" + + Another example: + typedef int my_int_t; + + name="my_int_t" + type="int" + """ + + __slots__ = Artifact.__slots__ + ( + "name", + "type", + ) + + def __init__( + self, + name: str = None, + type_: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.name: str = name + self.type: str = type_ + + def __str__(self): + scope_str = f" scope={self.scope}" if self.scope else "" + return f"" + + def nonconflict_merge(self, typedef2: "Typedef", **kwargs): + typedef1: Typedef = self.copy() + if not typedef2 or typedef1 == typedef2: + return typedef1.copy() + + master_state = kwargs.get("master_state", None) + local_names = {typedef1.name} + if master_state: + for _, typedef in master_state.get_typedefs().items(): + local_names.add(typedef.name) + else: + local_names = {typedef1.name} + + if typedef2.name not in local_names: + typedef1.name = typedef2.name + typedef1.type = typedef2.type + return typedef1 diff --git a/declib/cli/__init__.py b/declib/cli/__init__.py new file mode 100644 index 00000000..0fe1a95a --- /dev/null +++ b/declib/cli/__init__.py @@ -0,0 +1,3 @@ +from declib.cli.decompiler_cli import main + +__all__ = ["main"] diff --git a/declib/cli/decompiler_cli.py b/declib/cli/decompiler_cli.py new file mode 100644 index 00000000..11acc236 --- /dev/null +++ b/declib/cli/decompiler_cli.py @@ -0,0 +1,1487 @@ +""" +The `decompiler` CLI: a simplified, LLM-friendly interface to declib. + +The CLI is a client that connects to a DecompilerServer. The first `load` of +a binary auto-starts a headless server in the background; subsequent CLI +invocations (including `load`s of other binaries) connect to the right server +via the shared server registry (see declib.api.server_registry). + +Subcommands implemented: +- load start a server on a binary +- list list running servers +- stop stop one or all servers +- list_functions list functions in the binary, optionally filtered by regex +- decompile decompile a function by name or address +- disassemble disassemble a function by name or address +- xref_to data + code references to a target +- xref_from things a function calls (callees) +- rename rename a function or local variable +- create-type define a new struct/enum/typedef from a C string +- retype change the type of a function's variable or argument +- sync copy work on a function from one server into another +- list_strings list strings in the binary, optionally filtered by regex +- get_callers functions (call sites only) that call a target +- read_memory read raw bytes from the binary at an address +- install-skill install the bundled Agent Skill so LLMs learn the CLI +""" +import argparse +import json +import logging +import os +import re +import shutil +import signal +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +# Standardized exit codes — keep these consistent across subcommands so that +# `&&` chaining and scripts have predictable behavior. +EXIT_OK = 0 +EXIT_USER_ERROR = 1 # user asked for something that didn't happen +EXIT_NOT_FOUND = 1 # missing function/name/binary +EXIT_RUNTIME_ERROR = 1 # unhandled/unknown failure + +from declib.api import server_registry +from declib.decompilers import SUPPORTED_DECOMPILERS +from declib import skills + +_l = logging.getLogger("declib.cli.decompiler") + +_SERVER_START_TIMEOUT = 300.0 # seconds; Ghidra initial analysis can be slow +_SERVER_POLL_INTERVAL = 0.25 + + +def _configure_logging(verbose: bool) -> None: + level = logging.DEBUG if verbose else logging.WARNING + logging.basicConfig(level=level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + # Keep declib chatter quiet unless --verbose; otherwise INFO logs clobber the CLI output. + if not verbose: + logging.getLogger("declib").setLevel(logging.WARNING) + + +def _parse_target(target: str) -> Tuple[Optional[int], Optional[str]]: + """Parse a user-supplied target into (addr, name). + + Accepts hex (0x...), decimal, or a symbol name. Returns (addr, None) if numeric, + otherwise (None, target). + """ + if target is None: + return None, None + t = target.strip() + if t.lower().startswith("0x"): + try: + return int(t, 16), None + except ValueError: + pass + if t.isdigit(): + try: + return int(t, 10), None + except ValueError: + pass + return None, t + + +def _resolve_function_addr(client, target: str) -> Optional[int]: + """Resolve a function reference to its address using a client. + + Names are resolved by scanning light artifacts. Addresses may be given in either + lifted (relative to base) or lowered (absolute/loaded) form; we match whichever + the server's artifact dict uses. + """ + addr, name = _parse_target(target) + if name is not None: + for _addr, func in client.functions.items(): + if func.name == name: + return _addr + return None + if addr is None: + return None + + # Addresses may be given as absolute; the server exposes lifted addresses. + known = set(client.functions.keys()) + if addr in known: + return addr + try: + base = client.binary_base_addr + except Exception: + base = 0 + if base and addr >= base and (addr - base) in known: + return addr - base + if base and (addr + base) in known: + return addr + base + return addr # let the caller raise if it's truly invalid + + +def _select_server( + server_id: Optional[str], + binary_path: Optional[str], + backend: Optional[str], +) -> Dict: + """Pick a server record from the registry, or error out with a helpful message.""" + records = server_registry.find_servers( + binary_path=binary_path, + backend=backend, + ) + if server_id: + records = [r for r in records if r.get("id") == server_id] + + if not records: + filters = {"id": server_id, "binary_path": binary_path, "backend": backend} + active = {k: v for k, v in filters.items() if v} + raise SystemExit( + "No running decompiler server matches " + f"{active or '(no filters)'}. Start one with `decompiler load `." + ) + if len(records) > 1 and not server_id: + lines = [ + f"{r['id']} backend={r.get('backend')} binary={r.get('binary_path')}" + for r in records + ] + raise SystemExit( + "Multiple servers match. Specify --id to disambiguate:\n " + + "\n ".join(lines) + ) + return records[0] + + +def _connect_client(record: Dict): + from declib.api.decompiler_client import DecompilerClient + + return DecompilerClient(socket_path=record["socket_path"]) + + +def _with_client(args): + """Resolve & connect to the selected server, returning the client.""" + record = _select_server( + server_id=getattr(args, "id", None), + binary_path=getattr(args, "binary", None), + backend=getattr(args, "backend", None), + ) + return _connect_client(record) + + +# --------------------------------------------------------------------------- +# load +# --------------------------------------------------------------------------- + +def _spawn_server( + binary_path: Path, + backend: str, + server_id: str, + project_dir: Optional[Path] = None, +) -> subprocess.Popen: + """Start a detached headless server process for the given binary.""" + cmd = [ + sys.executable, "-m", "declib", + "--server", + "--decompiler", backend, + "--headless", + "--binary-path", str(binary_path), + "--server-id", server_id, + ] + if project_dir is not None: + cmd.extend(["--project-dir", str(project_dir)]) + env = os.environ.copy() + # Inherit env so things like GHIDRA_INSTALL_DIR flow through. + + # Fully detach: new session so Ctrl-C in the CLI won't kill the server. + kwargs = { + "stdout": subprocess.DEVNULL, + "stderr": subprocess.DEVNULL, + "stdin": subprocess.DEVNULL, + "env": env, + "close_fds": True, + } + if os.name == "posix": + kwargs["start_new_session"] = True + else: + kwargs["creationflags"] = getattr(subprocess, "DETACHED_PROCESS", 0) | getattr( + subprocess, "CREATE_NEW_PROCESS_GROUP", 0 + ) + return subprocess.Popen(cmd, **kwargs) + + +def _wait_for_server(server_id: str, timeout: float = _SERVER_START_TIMEOUT) -> Dict: + """Block until a server with `server_id` appears in the registry or timeout.""" + deadline = time.time() + timeout + while time.time() < deadline: + record = server_registry.find_server(server_id=server_id) + if record and record.get("socket_path") and os.path.exists(record["socket_path"]): + return record + time.sleep(_SERVER_POLL_INTERVAL) + raise SystemExit( + f"Timed out waiting {timeout:.0f}s for server {server_id} to start. " + "Check backend dependencies (e.g. GHIDRA_INSTALL_DIR) and retry." + ) + + +def cmd_load(args) -> int: + binary_path = Path(args.binary).expanduser().resolve() + if not binary_path.exists(): + raise SystemExit(f"Binary not found: {binary_path}") + + backend = args.backend + if backend not in SUPPORTED_DECOMPILERS: + raise SystemExit( + f"Unsupported backend {backend!r}; pick one of: {sorted(SUPPORTED_DECOMPILERS)}" + ) + + # Existing server(s) for this binary+backend. + existing = server_registry.find_servers(binary_path=str(binary_path), backend=backend) + if existing and args.replace: + # --replace: tear the old one(s) down first, then start fresh. + for record in existing: + _stop_server_by_record(record) + existing = [] + if existing and not args.force: + record = existing[0] + _emit(args, { + "status": "already_loaded", + "id": record["id"], + "binary_path": record.get("binary_path"), + "backend": record.get("backend"), + "socket_path": record.get("socket_path"), + }) + return 0 + + server_id = args.id or server_registry.new_server_id() + # Default project/database location: a per-binary folder under the user + # cache dir so analysis artifacts don't pollute the binary's directory. + # Pass --project-dir "" to disable and let the backend drop files beside + # the binary (legacy behavior). + project_dir: Optional[Path] + if args.project_dir == "": + project_dir = None + elif args.project_dir is not None: + project_dir = Path(args.project_dir).expanduser().resolve() + else: + project_dir = _default_project_dir(binary_path, backend) + _spawn_server(binary_path, backend, server_id, project_dir=project_dir) + record = _wait_for_server(server_id) + _emit(args, { + "status": "started", + "id": record["id"], + "binary_path": record.get("binary_path"), + "backend": record.get("backend"), + "socket_path": record.get("socket_path"), + "project_dir": str(project_dir) if project_dir is not None else None, + }) + return 0 + + +def _default_project_dir(binary_path: Path, backend: str) -> Path: + """Return a stable per-binary cache dir under the user cache root. + + Keyed by binary name + short hash of the absolute path, so two binaries + with the same basename don't collide. The directory is created lazily + by the backend (Ghidra creates `/_ghidra/`; IDA writes its + `.id*` files directly into ``). + """ + from platformdirs import user_cache_dir + import hashlib + + path_hash = hashlib.sha1(str(binary_path).encode()).hexdigest()[:8] + return Path(user_cache_dir("declib")) / "projects" / f"{binary_path.name}-{path_hash}" + + +# --------------------------------------------------------------------------- +# list / stop +# --------------------------------------------------------------------------- + +def cmd_list(args) -> int: + records = server_registry.list_servers() + registry_dir = str(server_registry._registry_dir()) # type: ignore[attr-defined] + if args.show_registry and not args.json: + print(registry_dir) + return 0 + if args.json: + print(json.dumps({"registry_dir": registry_dir, "servers": records}, indent=2, default=str)) + return 0 + if not records: + print(f"No running decompiler servers. (registry: {registry_dir})") + return 0 + print(f"{'ID':<12} {'BACKEND':<8} {'PID':<8} BINARY") + for r in records: + print(f"{r.get('id',''):<12} {str(r.get('backend','')):<8} {str(r.get('pid','')):<8} {r.get('binary_path','')}") + print(f"\n(registry: {registry_dir})") + return 0 + + +def _stop_server_by_record(record: Dict) -> bool: + """Shut down the server process backing `record`. + + Asks the server to shut itself down gracefully, falling back to SIGTERM/SIGKILL + on the PID if the request fails. Returns True if we believe the process is + gone (or never existed) by the time we return. + """ + from declib.api.decompiler_client import DecompilerClient + + server_id = record.get("id") + pid = record.get("pid") + socket_path = record.get("socket_path") + graceful = False + try: + client = DecompilerClient(socket_path=socket_path) + except Exception as exc: + _l.warning("Could not connect to server %s: %s", server_id, exc) + client = None + if client is not None: + try: + client._send_request({"type": "shutdown_server"}) + graceful = True + except Exception as exc: + _l.debug("shutdown_server rejected by %s: %s", server_id, exc) + client.shutdown() + + if not _wait_for_process_exit(pid, timeout=3.0): + # Graceful request didn't land or server is stuck — escalate. + _signal_process(pid, signal.SIGTERM) + if not _wait_for_process_exit(pid, timeout=2.0): + _signal_process(pid, signal.SIGKILL) + _wait_for_process_exit(pid, timeout=1.0) + + server_registry.unregister_server(server_id) + return graceful or not _process_alive(pid) + + +def _process_alive(pid) -> bool: + if not pid: + return False + try: + import psutil + + return psutil.pid_exists(int(pid)) + except Exception: + return False + + +def _signal_process(pid, sig) -> None: + if not pid: + return + try: + os.kill(int(pid), sig) + except ProcessLookupError: + return + except Exception as exc: + _l.debug("Signal %s to pid %s failed: %s", sig, pid, exc) + + +def _wait_for_process_exit(pid, timeout: float) -> bool: + if not pid: + return True + deadline = time.time() + timeout + while time.time() < deadline: + if not _process_alive(pid): + return True + time.sleep(0.05) + return not _process_alive(pid) + + +def cmd_stop(args) -> int: + records = server_registry.list_servers() + if args.all: + targets = records + elif args.id: + targets = [r for r in records if r.get("id") == args.id] + elif args.binary: + targets = server_registry.find_servers(binary_path=args.binary) + else: + raise SystemExit("decompiler stop needs --id, --binary, or --all") + + if not targets: + raise SystemExit("No matching server to stop") + + results = [] + for record in targets: + ok = _stop_server_by_record(record) + results.append({"id": record.get("id"), "stopped": bool(ok)}) + _emit(args, {"stopped": results}) + return 0 + + +# --------------------------------------------------------------------------- +# decompile / disassemble +# --------------------------------------------------------------------------- + +def _known_function_addrs(client) -> set: + try: + return set(client.functions.keys()) + except Exception: + return set() + + +def cmd_decompile(args) -> int: + with _with_client(args) as client: + addr = _resolve_function_addr(client, args.target) + known = _known_function_addrs(client) + if addr is None: + raise SystemExit(f"Function not found: {args.target!r}") + if known and addr not in known: + raise SystemExit( + f"No function starts at 0x{addr:x}. " + f"Try `decompiler list_functions --filter '{args.target}'` or " + "pick a function-start address." + ) + dec = client.decompile(addr) + if dec is None: + raise SystemExit( + f"Decompiler engine returned no result for 0x{addr:x}. " + "The address is a known function start, but decompilation " + "failed — this usually means the backend can't handle this " + "function (unreachable code, ARM/x86 mode mismatch, etc.)." + ) + text = dec.text if hasattr(dec, "text") else str(dec) + if getattr(args, "raw", False): + # --raw: dump just the text body to stdout, regardless of --json. + print(text) + return 0 + out = { + "addr": addr, + "decompiler": dec.decompiler if hasattr(dec, "decompiler") else None, + "text": text, + } + _emit(args, out, text_field="text") + return 0 + + +def cmd_disassemble(args) -> int: + with _with_client(args) as client: + addr = _resolve_function_addr(client, args.target) + known = _known_function_addrs(client) + if addr is None: + raise SystemExit(f"Function not found: {args.target!r}") + if known and addr not in known: + raise SystemExit( + f"No function starts at 0x{addr:x}. " + f"Try `decompiler list_functions --filter '{args.target}'` or " + "pick a function-start address." + ) + text = client.disassemble(addr) + if text is None: + raise SystemExit( + f"Disassembler returned no instructions for 0x{addr:x} " + "(likely a function too small to disassemble or a backend bug)." + ) + if getattr(args, "raw", False): + print(text) + return 0 + _emit(args, {"addr": addr, "text": text}, text_field="text") + return 0 + + +def cmd_list_functions(args) -> int: + with _with_client(args) as client: + pattern = re.compile(args.filter) if args.filter else None + entries: List[Dict] = [] + for addr, func in sorted(client.functions.items(), key=lambda kv: kv[0]): + name = getattr(func, "name", None) or "" + if pattern and not pattern.search(name): + continue + size = getattr(func, "size", 0) or 0 + entries.append({"addr": addr, "size": int(size), "name": name}) + + if args.json: + _emit_list(args, entries) + else: + if not entries: + print("No functions matched.") + return 0 + print(f"{'ADDR':<12} {'SIZE':<8} NAME") + for e in entries: + print(f"0x{e['addr']:<10x} {e['size']:<8} {e['name']}") + return 0 + + +# --------------------------------------------------------------------------- +# xrefs +# --------------------------------------------------------------------------- + +def _format_xref(artifact) -> Dict: + """Render any artifact (Function, GlobalVariable, etc.) as a uniform dict. + + Unlike `_format_function`, this keeps the artifact kind so callers can + tell code refs apart from data refs. + """ + return { + "kind": type(artifact).__name__, + "addr": getattr(artifact, "addr", None), + "name": getattr(artifact, "name", None), + } + + +def cmd_xref_to(args) -> int: + """All references — code and data — to the target. + + Note: distinct from `get_callers`, which is call-sites only. `xref_to` + here asks the backend for *every* artifact that points at the target, + including globals, strings, and non-call code references. + + Resolution order for ``target``: + 1. Function name or address that matches a known function — use the + function-level xref path (entry-point references). + 2. A raw numeric address or a string literal surfaced by `list_strings` + — use the raw-address xref path (data refs to strings, globals, etc.). + """ + from declib.artifacts import Function + + with _with_client(args) as client: + parsed_addr, parsed_name = _parse_target(args.target) + func_addr = _resolve_function_addr(client, args.target) + known = _known_function_addrs(client) + is_function_target = func_addr is not None and (not known or func_addr in known) + + resolved_addr: Optional[int] = None + target_kind: str # "function" | "address" | "string" + + if is_function_target: + resolved_addr = func_addr + target_kind = "function" + elif parsed_addr is not None: + # Raw address that isn't a function start — try data xrefs. + resolved_addr = parsed_addr + target_kind = "address" + elif parsed_name is not None: + # Treat as a string literal: find that string and xref its address. + match = _find_string_addr(client, parsed_name) + if match is None: + raise SystemExit( + f"Not found: {args.target!r} is not a function, address, " + "or known string. Try `decompiler list_strings --filter " + f"'{parsed_name}'` to search." + ) + resolved_addr = match + target_kind = "string" + else: + raise SystemExit(f"Function not found: {args.target!r}") + + xrefs: List = [] + if target_kind == "function": + func_stub = Function(resolved_addr, 0) + try: + xrefs = client.xrefs_to(func_stub, decompile=bool(args.decompile)) + except Exception as exc: + _l.debug("xrefs_to raised %s; falling back to get_callers", exc) + xrefs = client.get_callers(resolved_addr) + else: + try: + xrefs = client.xrefs_to_addr(resolved_addr) + except Exception as exc: + _l.debug("xrefs_to_addr raised %s; returning empty", exc) + xrefs = [] + + # Enrich Function entries with names from the light artifact cache, + # since some backends only return (addr, 0) stubs from xrefs_to. + light_funcs = dict(client.functions.items()) + data: List[Dict] = [] + for x in xrefs: + entry = _format_xref(x) + if entry["kind"] == "Function" and not entry.get("name"): + func = light_funcs.get(entry.get("addr")) + if func is not None: + entry["name"] = getattr(func, "name", None) + data.append(entry) + _emit_xrefs(args, resolved_addr, data, direction="to", target_kind=target_kind) + return 0 + + +def _find_string_addr(client, value: str) -> Optional[int]: + """Look up the address of a string literal (exact match, then substring).""" + try: + strings = client.list_strings() or [] + except Exception: + return None + exact = [addr for addr, text in strings if text == value] + if exact: + return exact[0] + contains = [addr for addr, text in strings if value in text] + if contains: + return contains[0] + return None + + +def cmd_xref_from(args) -> int: + """Return the callees of a function (what the function calls). + + Implementation: + 1. Use the backend's native per-function callee query (`xrefs_from`). + 2. Fall back to parsing `call 0x…` from disassembly when the backend + returns nothing. + """ + with _with_client(args) as client: + addr = _resolve_function_addr(client, args.target) + if addr is None: + raise SystemExit(f"Function not found: {args.target!r}") + + callees: List[Dict] = [] + seen = set() + try: + for callee in client.xrefs_from(addr): + callee_addr = getattr(callee, "addr", None) + if callee_addr in seen: + continue + seen.add(callee_addr) + callees.append(_format_xref(callee)) + except Exception as exc: + _l.debug("xrefs_from failed (%s); falling back to disasm scan.", exc) + + if not callees: + # Fallback: parse `call 0x...` from disassembly. + disasm = client.disassemble(addr) or "" + call_re = re.compile(r"\bcall\b[^0-9]*0x([0-9a-fA-F]+)") + functions_by_addr = dict(client.functions.items()) + for match in call_re.finditer(disasm): + try: + callee_addr = int(match.group(1), 16) + except ValueError: + continue + if callee_addr in seen: + continue + seen.add(callee_addr) + func = functions_by_addr.get(callee_addr) + callees.append({ + "kind": "Function", + "addr": callee_addr, + "name": func.name if func else None, + }) + + # Enrich entries that came back without a name but whose addr is known + # from the light artifact cache. + if callees: + light_funcs = dict(client.functions.items()) + for entry in callees: + if entry.get("kind") == "Function" and not entry.get("name"): + func = light_funcs.get(entry.get("addr")) + if func is not None: + entry["name"] = getattr(func, "name", None) + + _emit_xrefs(args, addr, callees, direction="from") + return 0 + + +def _emit_xrefs( + args, + addr: int, + xrefs: List[Dict], + *, + direction: str, + target_kind: Optional[str] = None, +) -> None: + payload: Dict = {"addr": addr, "direction": direction, "xrefs": xrefs} + if target_kind is not None: + payload["target_kind"] = target_kind + if args.json: + print(json.dumps(_annotate_addrs(payload), indent=2, default=str)) + return + if not xrefs: + print(f"No xrefs {direction} {_format_addr_hex(addr)}") + return + for x in xrefs: + a = x.get("addr") + n = x.get("name") or "" + kind = x.get("kind") or "" + a_str = _format_addr_hex(a) if isinstance(a, int) else "?" + if kind: + print(f"{a_str}\t{kind}\t{n}") + else: + print(f"{a_str}\t{n}") + + +# --------------------------------------------------------------------------- +# rename +# --------------------------------------------------------------------------- + +def cmd_rename(args) -> int: + kind = args.kind + with _with_client(args) as client: + if kind == "func": + addr = _resolve_function_addr(client, args.target) + if addr is None: + raise SystemExit(f"Function not found: {args.target!r}") + func = client.functions[addr] + if not func: + raise SystemExit(f"Could not load function at 0x{addr:x}") + func.name = args.new_name + if func.header is not None: + func.header.name = args.new_name + ok = bool(client.set_artifact(func)) + _emit(args, {"kind": "func", "addr": addr, "new_name": args.new_name, "success": ok}) + return EXIT_OK if ok else EXIT_USER_ERROR + elif kind == "var": + if not args.function: + raise SystemExit("--function is required when renaming a variable") + func_addr = _resolve_function_addr(client, args.function) + if func_addr is None: + raise SystemExit(f"Function not found: {args.function!r}") + func = client.functions[func_addr] + if not func: + raise SystemExit(f"Could not load function at 0x{func_addr:x}") + name_map = {args.target: args.new_name} + ok = bool(client.rename_local_variables_by_names(func, name_map)) + _emit(args, {"kind": "var", "function_addr": func_addr, + "old_name": args.target, "new_name": args.new_name, + "success": ok}) + return EXIT_OK if ok else EXIT_USER_ERROR + raise SystemExit(f"Unknown rename kind: {kind}") + + +# --------------------------------------------------------------------------- +# create-type / retype +# --------------------------------------------------------------------------- + +def cmd_create_type(args) -> int: + """Define a new struct/enum/typedef from a C string and apply it. + + The definition is parsed (client-side, decompiler-free) into a declib + artifact and pushed through the normal type-setting path, which works + across every backend. + """ + from declib.api.type_definition_parser import ( + parse_type_definition, TypeDefinitionParseError, + ) + + try: + artifact = parse_type_definition(args.definition) + except TypeDefinitionParseError as exc: + # Fail before connecting — nothing about a server changes the parse. + raise SystemExit(f"Could not parse type definition: {exc}") + + with _with_client(args) as client: + ok = bool(client.set_artifact(artifact)) + _emit(args, { + "kind": type(artifact).__name__, + "name": artifact.name, + "size": getattr(artifact, "size", None), + "members": len(artifact.members) if hasattr(artifact, "members") else None, + "success": ok, + }) + return EXIT_OK if ok else EXIT_USER_ERROR + + +def _find_variable(func, var_name: str): + """Locate a variable by name in a function. Returns (kind, var) or (None, None). + + Stack variables are checked before arguments. ``kind`` is "stack" or "arg". + """ + for svar in func.stack_vars.values(): + if svar.name == var_name: + return "stack", svar + if func.header is not None: + for arg in func.header.args.values(): + if arg.name == var_name: + return "arg", arg + return None, None + + +def _compute_type_size(client, type_str: str) -> int: + """Best-effort byte size for a (possibly user-defined) type string.""" + ctype = client.type_parser.parse_type(type_str) + if ctype is not None and ctype.size: + return ctype.size + # Unknown/user-defined non-pointer type (e.g. a struct by value): ask the + # backend what it already knows about this type. + try: + defined = client.get_defined_type(type_str) + except Exception: + defined = None + size = getattr(defined, "size", None) + if size: + return size + # 0 means "let the backend infer the size". + return ctype.size if ctype is not None else 0 + + +def cmd_retype(args) -> int: + """Change the type of a local variable or argument of a function.""" + with _with_client(args) as client: + func_addr = _resolve_function_addr(client, args.function) + if func_addr is None: + raise SystemExit(f"Function not found: {args.function!r}") + func = client.functions[func_addr] + if not func: + raise SystemExit(f"Could not load function at 0x{func_addr:x}") + + kind, var = _find_variable(func, args.variable) + if var is None: + raise SystemExit( + f"Variable {args.variable!r} not found in {args.function!r}. " + "Check the name (it is case-sensitive)." + ) + + var.type = args.new_type + var.size = _compute_type_size(client, args.new_type) + + ok = bool(client.set_artifact(func)) + if not ok: + raise SystemExit( + f"Backend rejected retype of {args.variable!r} to {args.new_type!r}." + ) + + # Re-read so the caller can see what the backend actually stored. + refreshed = client.functions[func_addr] + _, new_var = _find_variable(refreshed, args.variable) + _emit(args, { + "function_addr": func_addr, + "variable": args.variable, + "kind": kind, + "new_type": args.new_type, + "applied_type": getattr(new_var, "type", None) if new_var else None, + "success": ok, + }) + return EXIT_OK + + +# --------------------------------------------------------------------------- +# sync +# --------------------------------------------------------------------------- + +def cmd_sync(args) -> int: + """Copy work on a function from one running server into another. + + Source is selected by --from-id; destination by the usual + --id/--binary/--backend. Syncs the function's referenced user types + (struct/enum/typedef) first, then the function header (name/return/args) + and stack variables. Addresses and stack offsets are canonical in lifted + form, so they re-key correctly on the destination even if it names the + function differently. + """ + from declib.artifacts import Struct, Enum, Typedef + + src_record = _select_server(server_id=args.from_id, binary_path=None, backend=None) + dst_record = _select_server( + server_id=getattr(args, "id", None), + binary_path=getattr(args, "binary", None), + backend=getattr(args, "backend", None), + ) + if src_record.get("id") == dst_record.get("id"): + raise SystemExit( + f"Source and destination are the same server (id={src_record.get('id')}). " + "Pick two different servers." + ) + + src = _connect_client(src_record) + dst = None + try: + dst = _connect_client(dst_record) + + addr = _resolve_function_addr(src, args.target) + known = _known_function_addrs(src) + if addr is None or (known and addr not in known): + raise SystemExit(f"Function not found on source: {args.target!r}") + + src_func = src.functions[addr] + if not src_func: + raise SystemExit(f"Could not load function at 0x{addr:x} on source") + + # 1) Sync referenced user types first so retypes resolve on the dest. + synced_types, failed_types = [], [] + try: + deps = src.get_dependencies(src_func, decompile=True) + except Exception as exc: + _l.debug("get_dependencies failed: %s", exc) + deps = [] + for dep in deps: + if isinstance(dep, (Struct, Enum, Typedef)): + name = getattr(dep, "name", None) + try: + ok_dep = bool(dst.set_artifact(dep)) + except Exception as exc: + _l.debug("type sync failed for %s: %s", name, exc) + ok_dep = False + (synced_types if ok_dep else failed_types).append(name) + + # 2) Sync the function header + stack vars in one shot. + func_ok = bool(dst.set_artifact(src_func)) + + synced_vars = sorted( + (sv.offset, getattr(sv, "name", None)) + for sv in (src_func.stack_vars or {}).values() + ) + _emit(args, { + "target": args.target, + "addr": addr, + "from_id": src_record.get("id"), + "to_id": dst_record.get("id"), + "function_name": src_func.name, + "synced_types": synced_types, + "failed_types": failed_types, + "synced_stack_vars": [{"offset": off, "name": nm} for off, nm in synced_vars], + "success": func_ok, + }) + return EXIT_OK if func_ok else EXIT_USER_ERROR + finally: + src.shutdown() + if dst is not None: + dst.shutdown() + + +# --------------------------------------------------------------------------- +# list_strings / get_callers (new core APIs) +# --------------------------------------------------------------------------- + +def cmd_list_strings(args) -> int: + """List strings the decompiler has identified in the binary. + + This surfaces exactly what the backend's own string detector produced — + nothing more, nothing less. Decompilers miss things (angr in particular + is thin on `.rodata`), so if this looks sparse, reach for an external + tool (`strings(1)`, `rabin2 -z`, `readelf -p .rodata`) to get the + complete picture. + """ + with _with_client(args) as client: + native = client.list_strings(filter=args.filter) or [] + + results: List[Dict] = [] + for addr, s in native: + if len(s) < args.min_length: + continue + results.append({"addr": addr, "string": s}) + + # Sort by addr. + results.sort(key=lambda e: e.get("addr", 0)) + + if args.json: + _emit_list(args, results) + else: + for entry in results: + print(f"{_format_addr_hex(entry['addr'])}\t{entry['string']}") + return 0 + + +def cmd_get_callers(args) -> int: + """Functions that contain a call to the target (call-sites only). + + Distinct from `xref_to`, which returns every reference (code *or* data). + If you want the full reference set, use `xref_to` instead. + """ + with _with_client(args) as client: + # Reuse the resolver so absolute addresses get normalized to the lifted + # form the server expects. + resolved = _resolve_function_addr(client, args.target) + if resolved is None: + raise SystemExit(f"Function not found: {args.target!r}") + try: + callers = client.get_callers(resolved) + except ValueError as exc: + raise SystemExit(str(exc)) + data = [_format_xref(c) for c in callers] + if args.json: + _emit(args, {"target": args.target, "target_addr": resolved, "callers": data}) + else: + if not data: + print(f"No callers found for {args.target!r}") + else: + for entry in data: + a = entry.get("addr") + n = entry.get("name") or "" + print(f"0x{a:x}\t{n}" if a is not None else f"?\t{n}") + return 0 + + +# --------------------------------------------------------------------------- +# read_memory +# --------------------------------------------------------------------------- + +def cmd_read_memory(args) -> int: + """Read ``size`` bytes from the binary starting at ``addr``. + + Address accepts hex (``0x...``) or decimal. Output defaults to a hex+ascii + dump; use ``--format hex`` for a single hex blob, ``--format raw`` to write + raw bytes to stdout, or ``--json`` for a JSON envelope with the bytes + base64-encoded. + """ + import base64 + + addr_value, name = _parse_target(args.addr) + if addr_value is None: + raise SystemExit( + f"Invalid address {args.addr!r}; expected hex (0x..) or decimal." + ) + if args.size <= 0: + raise SystemExit(f"--size must be > 0 (got {args.size})") + + with _with_client(args) as client: + data = client.read_memory(addr_value, args.size) + if data is None: + raise SystemExit( + f"Backend could not read 0x{args.size:x} bytes at " + f"{_format_addr_hex(addr_value)}. The address may be " + "uninitialized, unmapped, or outside any loaded segment." + ) + # Some backends return short reads when the request straddles the + # end of a mapped region; surface that in the JSON output and warn + # in text mode so the caller knows. + actual_size = len(data) + + if args.format == "raw" and not args.json: + sys.stdout.buffer.write(data) + return 0 + + if args.json: + payload = { + "addr": addr_value, + "size": actual_size, + "requested_size": args.size, + "bytes_b64": base64.b64encode(data).decode("ascii"), + "hex": data.hex(), + } + print(json.dumps(_annotate_addrs(payload), indent=2, default=str)) + return 0 + + if args.format == "hex": + print(data.hex()) + return 0 + + # Default: hexdump-style output. + for line in _hexdump(data, base_addr=addr_value): + print(line) + if actual_size < args.size: + print( + f"# short read: got {actual_size} of {args.size} requested bytes", + file=sys.stderr, + ) + return 0 + + +def _hexdump(data: bytes, *, base_addr: int = 0, width: int = 16) -> List[str]: + """Return a list of hexdump lines like ``addr: hh hh ... |ascii|``.""" + lines: List[str] = [] + for offset in range(0, len(data), width): + chunk = data[offset:offset + width] + hex_part = " ".join(f"{b:02x}" for b in chunk) + # Pad short final lines so the ASCII column stays aligned. + hex_part = hex_part.ljust(width * 3 - 1) + ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk) + lines.append(f"{_format_addr_hex(base_addr + offset)}: {hex_part} |{ascii_part}|") + return lines + + +# --------------------------------------------------------------------------- +# install-skill +# --------------------------------------------------------------------------- + +_SKILL_AGENT_CHOICES = ("claude", "codex", "all") + + +def _codex_skill_dest() -> Path: + codex_home = os.environ.get("CODEX_HOME") + if codex_home: + return Path(codex_home).expanduser() / "skills" + return Path(os.path.expanduser("~/.codex/skills")) + + +def _skill_dest_for_agent(agent: str) -> Path: + if agent == "claude": + return Path(os.path.expanduser("~/.claude/skills")) + if agent == "codex": + return _codex_skill_dest() + raise ValueError(f"Unknown skill agent: {agent!r}") + + +def _default_skill_agents() -> List[str]: + # Codex sets CODEX_* env vars in its execution environment. Prefer its + # skill directory there, while preserving Claude as the normal shell default. + if any(key.startswith("CODEX_") for key in os.environ): + return ["codex"] + return ["claude"] + + +def _selected_skill_agents(raw_agents: Optional[List[str]]) -> List[str]: + agents = raw_agents or _default_skill_agents() + if "all" in agents: + agents = ["claude", "codex"] + + selected: List[str] = [] + for agent in agents: + if agent not in ("claude", "codex"): + raise SystemExit( + f"Unsupported skill agent {agent!r}; pick one of: claude, codex, all" + ) + if agent not in selected: + selected.append(agent) + return selected + + +def _skill_destinations(args) -> List[Tuple[str, Path]]: + if args.dest: + if args.agent: + raise SystemExit("--dest cannot be combined with --agent") + return [("custom", Path(args.dest).expanduser().resolve())] + + return [ + (agent, _skill_dest_for_agent(agent).expanduser().resolve()) + for agent in _selected_skill_agents(args.agent) + ] + + +def cmd_install_skill(args) -> int: + names = args.names or skills.available_skills() + if not names: + raise SystemExit("No bundled skills to install") + + installed: List[Dict] = [] + for agent, dest_root in _skill_destinations(args): + dest_root.mkdir(parents=True, exist_ok=True) + for name in names: + src = skills.skill_path(name) + dest = dest_root / name + if dest.exists() and not args.force: + raise SystemExit( + f"Skill already exists at {dest}. Pass --force to overwrite." + ) + if dest.exists() and args.force: + shutil.rmtree(dest) + shutil.copytree(src, dest) + installed.append({"name": name, "agent": agent, "path": str(dest)}) + + if args.json: + print(json.dumps({"installed": installed}, indent=2, default=str)) + else: + for entry in installed: + agent = "" if entry["agent"] == "custom" else f" ({entry['agent']})" + print(f"installed {entry['name']}{agent} -> {entry['path']}") + return 0 + + +# --------------------------------------------------------------------------- +# shared helpers +# --------------------------------------------------------------------------- + +def _annotate_addrs(payload): + """Recursively add `*_hex` siblings for every `*addr` integer field. + + JSON historically emitted addresses as decimals; feedback was that this + is awkward when copying from one command to another. Instead of breaking + existing int fields, we add a sibling hex-string field so both forms + are available. A key named `addr` gets `addr_hex`, `target_addr` gets + `target_addr_hex`, `function_addr` gets `function_addr_hex`, etc. + """ + if isinstance(payload, dict): + for key in list(payload.keys()): + value = payload[key] + if ( + (key == "addr" or key.endswith("_addr")) + and isinstance(value, int) + and f"{key}_hex" not in payload + ): + payload[f"{key}_hex"] = _format_addr_hex(value) + for v in payload.values(): + _annotate_addrs(v) + elif isinstance(payload, list): + for item in payload: + _annotate_addrs(item) + return payload + + +def _format_addr_hex(value: int) -> str: + """Format an address as `0x`, normalizing negatives to unsigned 64-bit. + + Some backends (Ghidra in particular) can surface java-signed long values + for synthetic addresses. Emitting `0x-100000` in JSON is useless — render + those as their unsigned-64 equivalent so downstream consumers always see + a well-formed hex address. + """ + if value < 0: + value &= (1 << 64) - 1 + return f"0x{value:x}" + + +def _emit(args, payload: Dict, *, text_field: Optional[str] = None) -> None: + """Emit a response either as JSON or as a human-readable block.""" + if args.json: + print(json.dumps(_annotate_addrs(payload), indent=2, default=str)) + return + if text_field and text_field in payload: + print(payload[text_field]) + return + # Default: key: value lines + for k, v in payload.items(): + print(f"{k}: {v}") + + +def _emit_list(args, payload): + """Same as _emit but for a top-level list payload (JSON arrays).""" + if args.json: + print(json.dumps(_annotate_addrs(payload), indent=2, default=str)) + return + # Fallback: print each item on its own line as "key: value" pairs if + # it's a dict; otherwise str(item). + for item in payload: + if isinstance(item, dict): + print(" ".join(f"{k}={v}" for k, v in item.items())) + else: + print(item) + + +def _format_function(func) -> Dict: + return { + "addr": getattr(func, "addr", None), + "name": getattr(func, "name", None), + } + + +# --------------------------------------------------------------------------- +# argparse plumbing +# --------------------------------------------------------------------------- + +def _add_server_filter_args(p: argparse.ArgumentParser) -> None: + p.add_argument("--id", dest="id", help="Server ID to target (see `decompiler list`).") + p.add_argument("--binary", dest="binary", help="Match server by binary path.") + p.add_argument("--backend", dest="backend", choices=sorted(SUPPORTED_DECOMPILERS), help="Match server by backend.") + + +def _add_output_args(p: argparse.ArgumentParser) -> None: + p.add_argument("--json", action="store_true", help="Emit JSON output instead of text.") + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="decompiler", + description=( + "LLM-friendly decompiler CLI powered by DecLib. " + "Load a binary once, then run decompile/disassemble/xref/rename " + "commands. Multiple binaries/backends can run concurrently." + ), + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging.") + sub = parser.add_subparsers(dest="cmd", required=True) + + # load + p_load = sub.add_parser("load", help="Load a binary, starting a server if needed.") + p_load.add_argument("binary", help="Path to the binary to analyze.") + p_load.add_argument("--backend", default="angr", choices=sorted(SUPPORTED_DECOMPILERS), + help="Backend decompiler to use (default: angr).") + p_load.add_argument("--id", dest="id", help="Explicit server ID (otherwise auto-generated).") + p_load.add_argument("--force", action="store_true", + help="Start a new server even if one already exists for this binary.") + p_load.add_argument("--replace", action="store_true", + help="Stop the existing server for this binary+backend (if any) before starting.") + p_load.add_argument( + "--project-dir", + dest="project_dir", + help=( + "Where the backend should store its project/database files " + "(Ghidra project, IDA .id*, etc.). Default: a per-binary folder " + "under the user cache dir. Pass '' to drop files next to the binary." + ), + ) + _add_output_args(p_load) + p_load.set_defaults(func=cmd_load) + + # list + p_list = sub.add_parser("list", help="List running decompiler servers.") + p_list.add_argument("--show-registry", action="store_true", + help="Print just the registry directory path and exit.") + _add_output_args(p_list) + p_list.set_defaults(func=cmd_list) + + # list_functions + p_lf = sub.add_parser("list_functions", help="List functions in the binary.") + p_lf.add_argument("--filter", dest="filter", help="Regex to filter function names.") + _add_server_filter_args(p_lf) + _add_output_args(p_lf) + p_lf.set_defaults(func=cmd_list_functions) + + # stop + p_stop = sub.add_parser("stop", help="Stop a running server.") + p_stop.add_argument("--id", dest="id", help="Server ID to stop.") + p_stop.add_argument("--binary", dest="binary", help="Stop servers for this binary.") + p_stop.add_argument("--all", action="store_true", help="Stop every running server.") + _add_output_args(p_stop) + p_stop.set_defaults(func=cmd_stop) + + # decompile + p_dec = sub.add_parser("decompile", help="Decompile a function by name or address.") + p_dec.add_argument("target", help="Function name or address (hex/decimal).") + p_dec.add_argument("--raw", action="store_true", + help="Print the decompilation text directly (no JSON or header wrapping).") + _add_server_filter_args(p_dec) + _add_output_args(p_dec) + p_dec.set_defaults(func=cmd_decompile) + + # disassemble + p_dis = sub.add_parser("disassemble", help="Disassemble a function by name or address.") + p_dis.add_argument("target", help="Function name or address (hex/decimal).") + p_dis.add_argument("--raw", action="store_true", + help="Print the disassembly text directly (no JSON or header wrapping).") + _add_server_filter_args(p_dis) + _add_output_args(p_dis) + p_dis.set_defaults(func=cmd_disassemble) + + # xref_to + p_xto = sub.add_parser( + "xref_to", + help=( + "Every reference (code AND data) to a target. " + "For call-sites only, see `get_callers`." + ), + ) + p_xto.add_argument("target", help="Function name or address (hex/decimal).") + p_xto.add_argument("--decompile", action="store_true", + help="Ask the backend to decompile first (picks up more refs on Ghidra).") + _add_server_filter_args(p_xto) + _add_output_args(p_xto) + p_xto.set_defaults(func=cmd_xref_to) + + # xref_from + p_xfrom = sub.add_parser("xref_from", help="Things a function calls (callees).") + p_xfrom.add_argument("target", help="Function name or address (hex/decimal).") + _add_server_filter_args(p_xfrom) + _add_output_args(p_xfrom) + p_xfrom.set_defaults(func=cmd_xref_from) + + # rename + p_ren = sub.add_parser("rename", help="Rename a function or a local variable.") + p_ren.add_argument("kind", choices=["func", "var"], help="What to rename.") + p_ren.add_argument("target", help="Function name/address (for `func`) or variable name (for `var`).") + p_ren.add_argument("new_name", help="New name.") + p_ren.add_argument("--function", help="When renaming a variable, the containing function.") + _add_server_filter_args(p_ren) + _add_output_args(p_ren) + p_ren.set_defaults(func=cmd_rename) + + # create-type + p_ct = sub.add_parser( + "create-type", + help=( + "Define a new struct, enum, or typedef from a C definition string " + "and apply it to the binary's type database." + ), + ) + p_ct.add_argument( + "definition", + help=( + 'C type definition, e.g. "struct Point { int x; int y; }", ' + '"enum Color { RED, GREEN, BLUE }", or "typedef int my_int_t".' + ), + ) + _add_server_filter_args(p_ct) + _add_output_args(p_ct) + p_ct.set_defaults(func=cmd_create_type) + + # retype + p_rt = sub.add_parser( + "retype", + help="Change the type of a function's local variable or argument.", + ) + p_rt.add_argument("function", help="Function name or address (hex/decimal).") + p_rt.add_argument("variable", help="Variable (stack var or arg) name to retype.") + p_rt.add_argument("new_type", help='New C type, e.g. "int", "double", "Point *".') + _add_server_filter_args(p_rt) + _add_output_args(p_rt) + p_rt.set_defaults(func=cmd_retype) + + # sync + p_sync = sub.add_parser( + "sync", + help=( + "Copy work on a function (name, return/arg types, stack vars, and " + "referenced user types) from one running server into another for " + "the same binary. Source = --from-id; destination = --id/--binary/--backend." + ), + ) + p_sync.add_argument("target", help="Function name or address (hex/decimal) on the source.") + p_sync.add_argument( + "--from-id", dest="from_id", required=True, + help="Source server ID to copy work FROM (see `decompiler list`).", + ) + _add_server_filter_args(p_sync) + _add_output_args(p_sync) + p_sync.set_defaults(func=cmd_sync) + + # list_strings + p_ls = sub.add_parser( + "list_strings", + help=( + "List strings the decompiler identified in the binary. " + "Fidelity varies by backend (angr < ghidra < ida) and may be " + "incomplete — use external tools (strings(1), rabin2 -z, " + "readelf -p) for an exhaustive scan." + ), + ) + p_ls.add_argument("--filter", dest="filter", help="Regex to filter strings.") + p_ls.add_argument("--min-length", dest="min_length", type=int, default=4, + help="Minimum string length to keep (default: 4).") + _add_server_filter_args(p_ls) + _add_output_args(p_ls) + p_ls.set_defaults(func=cmd_list_strings) + + # get_callers + p_gc = sub.add_parser( + "get_callers", + help=( + "Functions that call a target (call-sites only). " + "For every reference (code AND data), see `xref_to`." + ), + ) + p_gc.add_argument("target", help="Function name or address (hex/decimal).") + _add_server_filter_args(p_gc) + _add_output_args(p_gc) + p_gc.set_defaults(func=cmd_get_callers) + + # read_memory + p_rm = sub.add_parser( + "read_memory", + help=( + "Read raw bytes from the binary at an address. " + "Default output is a hexdump; pass --format hex for a single hex " + "string, --format raw for binary stdout, or --json for a JSON " + "envelope with base64-encoded bytes." + ), + ) + p_rm.add_argument("addr", help="Address to start reading from (hex 0x.. or decimal).") + p_rm.add_argument("size", type=lambda x: int(x, 0), + help="Number of bytes to read (decimal or 0x-prefixed hex).") + p_rm.add_argument("--format", choices=("hexdump", "hex", "raw"), default="hexdump", + help="Text-mode output format. Ignored when --json is set.") + _add_server_filter_args(p_rm) + _add_output_args(p_rm) + p_rm.set_defaults(func=cmd_read_memory) + + # install-skill + p_sk = sub.add_parser( + "install-skill", + help="Install the bundled Agent Skill (SKILL.md) for Claude Code or Codex.", + ) + p_sk.add_argument("names", nargs="*", + help="Specific skill names to install (default: all bundled).") + p_sk.add_argument( + "--agent", + action="append", + choices=_SKILL_AGENT_CHOICES, + help=( + "Agent skill directory to install into. Repeat for multiple agents, " + "or use 'all'. Default: codex when CODEX_* env vars are present, " + "otherwise claude." + ), + ) + p_sk.add_argument( + "--dest", + help="Install destination override. Cannot be combined with --agent.", + ) + p_sk.add_argument("--force", action="store_true", + help="Overwrite an existing skill directory.") + _add_output_args(p_sk) + p_sk.set_defaults(func=cmd_install_skill) + + return parser + + +def main(argv: Optional[List[str]] = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + _configure_logging(getattr(args, "verbose", False)) + try: + return args.func(args) or EXIT_OK + except SystemExit: + raise + except Exception as exc: # noqa: BLE001 + _l.exception("Unhandled error: %s", exc) + print(f"Error: {exc}", file=sys.stderr) + return EXIT_RUNTIME_ERROR + + +if __name__ == "__main__": # pragma: no cover + sys.exit(main()) diff --git a/declib/configuration.py b/declib/configuration.py new file mode 100644 index 00000000..571c356f --- /dev/null +++ b/declib/configuration.py @@ -0,0 +1,184 @@ +import builtins +from typing import Optional, Dict + +from declib.decompilers import SUPPORTED_DECOMPILERS, GHIDRA_DECOMPILER, IDA_DECOMPILER, ANGR_DECOMPILER, \ + BINJA_DECOMPILER +from platformdirs import user_config_dir +from filelock import FileLock +import pathlib +import logging +import toml +import os + +_l = logging.getLogger(__name__) + + +class DLConfig: + __slots__ = ( + "save_location", + "_config_lock", + ) + + def __init__(self, save_location: Optional[str] = None): + if not save_location: + save_location = user_config_dir("declib") + self.save_location = _create_path(save_location) + self._config_lock = FileLock(save_location + f"/{self.__class__.__name__}.lock", timeout=-1) + + def save(self): + self.save_location = _create_path(self.save_location) + if not self.save_location.parent.exists(): + self.save_location.parent.mkdir(parents=True, exist_ok=True) + + dump_dict = {} + for attr in self.__slots__: + if attr == '_config_lock': + continue + attr_val = getattr(self, attr) + if isinstance(attr_val, pathlib.Path): + attr_val = str(attr_val) + + if isinstance(attr_val, dict): + attr_val = {k: str(v) if isinstance(v, pathlib.Path) else v for k, v in attr_val.items()} + + dump_dict[attr] = attr_val + + with self._config_lock: + with open(self.save_location, "w") as fp: + toml.dump(dump_dict, fp) + + _l.debug("Saved config to %s", self.save_location) + return True + + def load(self): + self.save_location = _create_path(self.save_location) + if not self.save_location.exists(): + return None + + with self._config_lock: + with open(self.save_location, "r") as fp: + load_dict = toml.load(fp) + + for attr in self.__slots__: + if attr == '_config_lock': + continue + setattr(self, attr, load_dict.get(attr, None)) + + return self + + @classmethod + def load_from_file(cls, save_location=None): + config = cls(save_location) + return config.load() + + @classmethod + def update_or_make(cls, save_location=None, **attrs_to_update): + exists = False + if save_location: + save_location = _create_path(save_location) + exists = save_location.exists() + + if not exists: + config = cls(save_location) + else: + config = cls.load_from_file(save_location) + + for attr, val in attrs_to_update.items(): + if attr in config.__slots__: + setattr(config, attr, val) + + config.save() + return config + + +class DecLibConfig(DLConfig): + __slots__ = ( + "save_location", + "plugins_paths", + "headless_binary_paths", + "gdbinit_path", + ) + + def __init__(self, + save_location: Optional[str] = None, + plugins_paths: Optional[Dict] = {}, + headless_binary_paths: Optional[Dict] = {}, + gdbinit_path: Optional[str] = None + ): + super().__init__(save_location) + self.save_location = self.save_location / f"{__class__.__name__}.toml" + self.gdbinit_path = gdbinit_path + self.plugins_paths = {} + self.headless_binary_paths = {} + + @classmethod + def update_or_make(cls, save_location=None, **attrs_to_update): + exists = False + if save_location: + save_location = _create_path(save_location) + exists = save_location.exists() + + if not exists: + config = cls(save_location) + else: + config = cls.load_from_file(save_location) + + for attr, val in attrs_to_update.items(): + if attr in config.__slots__: + setattr(config, attr, val) + + for decompiler in SUPPORTED_DECOMPILERS: + plugins_path = config.plugins_paths[decompiler] if decompiler in config.plugins_paths else None + headless_path = config.headless_binary_paths[ + decompiler] if decompiler in config.headless_binary_paths else None + # Attempt to find default plugins_path if not given + if not plugins_path: + plugins_path = _infer_plugins_path(decompiler) + # Check if only plugins path exists and attempt to infer headless path + if plugins_path and not headless_path: + headless_path = _infer_headless_path(plugins_path, decompiler) + config.plugins_paths[decompiler] = plugins_path + config.headless_binary_paths[decompiler] = headless_path + + config.save() + return config + + +def _create_path(path_str): + return pathlib.Path(path_str).expanduser().absolute() + + +def _infer_headless_path(plugins_path, decompiler): + if decompiler == GHIDRA_DECOMPILER: + # Infer ghidra headless + plugins_path = _create_path(plugins_path) + install_root = plugins_path.parent + headless_path = install_root / "support" / ("analyzeHeadless.bat" if os.name == 'nt' else "analyzeHeadless") + return headless_path if headless_path.exists() else None + + if decompiler == IDA_DECOMPILER: + # Infer ida headless + plugins_path = _create_path(plugins_path) + install_root = plugins_path.parent.parent + headless_path = install_root / "idat64" + return headless_path if headless_path.exists() else None + + return None + + +def _infer_plugins_path(decompiler): + home = _create_path(os.getenv("HOME") or "~/") + if decompiler == GHIDRA_DECOMPILER: + # Ghidra plugins isn't in install root, so just attempt to use default + default_path = home / "ghidra_scripts" + return default_path if default_path.exists() else None + + if decompiler == IDA_DECOMPILER: + default_path = home / ".idapro" / "plugins" + return default_path if default_path.exists() else None + + if decompiler == BINJA_DECOMPILER: + default_path = home / ".binaryninja" / "plugins" + return default_path if default_path.exists() else None + + return None diff --git a/declib/decompiler_stubs/__init__.py b/declib/decompiler_stubs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/declib/decompiler_stubs/angr_declib/__init__.py b/declib/decompiler_stubs/angr_declib/__init__.py new file mode 100644 index 00000000..39a0a2af --- /dev/null +++ b/declib/decompiler_stubs/angr_declib/__init__.py @@ -0,0 +1,4 @@ +try: + from declib.decompilers.angr import * +except ImportError: + print("[!] declib is not installed, please `pip install declib` for THIS python interpreter") diff --git a/declib/decompiler_stubs/angr_declib/plugin.toml b/declib/decompiler_stubs/angr_declib/plugin.toml new file mode 100644 index 00000000..f7d586bc --- /dev/null +++ b/declib/decompiler_stubs/angr_declib/plugin.toml @@ -0,0 +1,13 @@ +[meta] +plugin_metadata_version = 0 + +[plugin] +name = "declib" +shortname = "declib" +version = "0.0.0" +description = "" +long_description = "" +platforms = ["windows", "linux", "macos"] +min_angr_version = "9.0.0.0" +author = "The BinSync Team" +entrypoints = ["__init__.py"] \ No newline at end of file diff --git a/declib/decompiler_stubs/binja_declib/README.md b/declib/decompiler_stubs/binja_declib/README.md new file mode 100644 index 00000000..cc187cc4 --- /dev/null +++ b/declib/decompiler_stubs/binja_declib/README.md @@ -0,0 +1,8 @@ + +## declib + +

+ declib Logo +

+ +Your Only Decompiler API diff --git a/declib/decompiler_stubs/binja_declib/__init__.py b/declib/decompiler_stubs/binja_declib/__init__.py new file mode 100644 index 00000000..b68df6da --- /dev/null +++ b/declib/decompiler_stubs/binja_declib/__init__.py @@ -0,0 +1,4 @@ +try: + from declib.decompilers.binja import * +except ImportError: + print("[!] declib is not installed, please `pip install declib` for THIS python interpreter") diff --git a/declib/decompiler_stubs/binja_declib/plugin.json b/declib/decompiler_stubs/binja_declib/plugin.json new file mode 100644 index 00000000..50742b08 --- /dev/null +++ b/declib/decompiler_stubs/binja_declib/plugin.json @@ -0,0 +1,21 @@ +{ + "pluginmetadataversion" : 2, + "name": "DecLib", + "type": ["core", "ui"], + "api": ["python3"], + "description": "Adds support for cross-decompiler collab in the declib supported decompilers.", + "longdescription": "Adds support for cross-decompiler collab in the declib supported decompilers.", + "license": { + "name": "BSD 2-clause", + "text": "" + }, + "platforms" : ["Darwin", "Linux", "Windows"], + "installinstructions" : { + "Darwin" : "Install through the Binja Plugin Manager. To update do `pip install -U declib`", + "Linux" : "Install through the Binja Plugin Manager. To update do `pip install -U declib`", + "Windows" : "Install through the Binja Plugin Manager. To update do `pip install -U declib`" + }, + "version": "0.0.0", + "author": "BinSync Team", + "minimumbinaryninjaversion": 1200 +} diff --git a/declib/decompiler_stubs/binja_declib/requirements.txt b/declib/decompiler_stubs/binja_declib/requirements.txt new file mode 100644 index 00000000..5685a451 --- /dev/null +++ b/declib/decompiler_stubs/binja_declib/requirements.txt @@ -0,0 +1 @@ +declib \ No newline at end of file diff --git a/declib/decompiler_stubs/ida_declib.py b/declib/decompiler_stubs/ida_declib.py new file mode 100644 index 00000000..64ef639f --- /dev/null +++ b/declib/decompiler_stubs/ida_declib.py @@ -0,0 +1,8 @@ +def PLUGIN_ENTRY(*args, **kwargs): + try: + from declib.decompilers.ida import DecLibPlugin + except ImportError: + print("[!] declib is not installed, please `pip install declib` for THIS python interpreter") + return None + + return DecLibPlugin(*args, **kwargs) diff --git a/declib/decompilers/__init__.py b/declib/decompilers/__init__.py new file mode 100644 index 00000000..20068b34 --- /dev/null +++ b/declib/decompilers/__init__.py @@ -0,0 +1,8 @@ +ANGR_DECOMPILER = "angr" +IDA_DECOMPILER = "ida" +BINJA_DECOMPILER = "binja" +GHIDRA_DECOMPILER = "ghidra" + +SUPPORTED_DECOMPILERS = { + ANGR_DECOMPILER, IDA_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER +} diff --git a/declib/decompilers/angr/__init__.py b/declib/decompilers/angr/__init__.py new file mode 100644 index 00000000..ca1ce6b4 --- /dev/null +++ b/declib/decompilers/angr/__init__.py @@ -0,0 +1,11 @@ +try: + import angrmanagement + AM_PRESENT = True +except ImportError: + AM_PRESENT = False + +if AM_PRESENT: + try: + from .compat import * + except ImportError: + pass diff --git a/declib/decompilers/angr/artifact_lifter.py b/declib/decompilers/angr/artifact_lifter.py new file mode 100644 index 00000000..a6c4c1ad --- /dev/null +++ b/declib/decompilers/angr/artifact_lifter.py @@ -0,0 +1,46 @@ +import typing + +from declib.api import ArtifactLifter + +if typing.TYPE_CHECKING: + from .interface import AngrInterface + +class AngrArtifactLifter(ArtifactLifter): + """ + TODO: finish me + """ + def __init__(self, interface: "AngrInterface"): + super(AngrArtifactLifter, self).__init__(interface) + + def is_arm(self) -> bool: + if self.deci.binary_arch is not None: + return "ARM" in self.deci.binary_arch + return False + + + def lift_type(self, type_str: str) -> str: + return type_str + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + return offset + + def lower_type(self, type_str: str) -> str: + return type_str + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + return offset + + def lower_addr(self, addr: int) -> int: + new_addr = super().lower_addr(addr) + if self.is_arm() and not self.deci.addr_starts_instruction(addr): + new_addr += 1 + + return new_addr + + + def lift_addr(self, addr: int) -> int: + new_addr = super().lift_addr(addr) + if self.is_arm() and new_addr % 2 == 1: + new_addr -= 1 + + return new_addr \ No newline at end of file diff --git a/declib/decompilers/angr/compat.py b/declib/decompilers/angr/compat.py new file mode 100644 index 00000000..74d6e34b --- /dev/null +++ b/declib/decompilers/angr/compat.py @@ -0,0 +1,262 @@ +# pylint: disable=wrong-import-position,wrong-import-order +import logging +from collections import defaultdict +from typing import Optional + +from angrmanagement.plugins import BasePlugin +from angrmanagement.ui.workspace import Workspace +from angrmanagement.ui.views.view import BaseView + +from declib.artifacts import ( + StackVariable, FunctionHeader, Enum, Struct, GlobalVariable, Comment, FunctionArgument +) +from declib.decompilers.angr.interface import AngrInterface + +l = logging.getLogger(__name__) + + +class GenericDLAngrManagementPlugin(BasePlugin): + def __init__(self, workspace: Workspace, interface: Optional[AngrInterface] = None, context_menu_items=None): + super().__init__(workspace) + # (name, action_string, callback_func, category) + self.context_menu_items = context_menu_items or [] + # Keep strong refs to QShortcut objects so Qt doesn't GC them + self._qshortcuts: list = [] + if interface is None: + from declib.decompilers.angr.interface import AngrInterface + self.interface = AngrInterface( + workspace, + init_plugin=True, + ) + else: + self.interface = interface + + def teardown(self): + pass + + def register_shortcut(self, name: str, shortcut: str, callback_func, deci=None) -> bool: + """ + Register a keyboard shortcut bound to ``callback_func`` on the angr-management + main window. The shortcut is an application-wide QShortcut so it fires from + any focused widget. + """ + from PySide6.QtGui import QShortcut, QKeySequence + from PySide6.QtCore import Qt + from PySide6.QtWidgets import QApplication + + parent = getattr(self.workspace, "main_window", None) + if parent is None: + app = QApplication.instance() + if app is not None: + for w in app.topLevelWidgets(): + if w.isWindow(): + parent = w + break + if parent is None: + l.warning("No Qt main window available; cannot bind shortcut %s", shortcut) + return False + + qsc = QShortcut(QKeySequence(shortcut), parent) + qsc.setContext(Qt.ApplicationShortcut) + qsc.activated.connect(lambda: callback_func(None, deci=deci)) + self._qshortcuts.append(qsc) + return True + + # + # Context Menus + # + + @staticmethod + def build_nested_structure_from_ctx_items(context_items): + def insert(categories, action, func, node): + if categories: + category = categories[0] + next_node = None + for child in node: + if child[0] == category: + next_node = child[1] + break + if not next_node: + next_node = [] + node.append((category, next_node)) + insert(categories[1:], action, func, next_node) + else: + node.append((action, func)) + + root = [] + for path, action, func in context_items: + categories = path.strip('/').split('/') + insert(categories, action, func, root) + + return root[0] + + def build_context_menu_node(self, node): + """ + The context menu triggered on right-click on a node in the decompilation view. + If used agnostic to the node type, this will always be on the context menu + """ + try: + func_addr = node.codegen.cfunc.addr + except AttributeError: + func_addr = None + + # only add the context menu items if we are in a function + if func_addr is not None: + # collect all the context menu items into a single list + ctx_items = [ + (category if category else "", action_string, callback_func) + for name, action_string, callback_func, category in self.context_menu_items + ] + if ctx_items: + nested_structure = GenericDLAngrManagementPlugin.build_nested_structure_from_ctx_items(ctx_items) + if not nested_structure[0][0]: + root_items = nested_structure[0][1] + categorized_items = nested_structure[1] + for item in root_items: + yield item + else: + categorized_items = nested_structure + + yield categorized_items + + # + # Decompiler Hooks + # + + # pylint: disable=unused-argument + def handle_stack_var_renamed(self, func, offset, old_name, new_name): + if func is None: + return False + + decompilation = self.interface.decompile_function(func).codegen + stack_var = self.interface.find_stack_var_in_codegen(decompilation, offset) + self.interface.stack_variable_changed(StackVariable(offset, new_name, None, stack_var.size, func.addr)) + return True + + # pylint: disable=unused-argument + def handle_stack_var_retyped(self, func, offset, old_type, new_type): + decompilation = self.interface.decompile_function(func).codegen + stack_var = self.interface.find_stack_var_in_codegen(decompilation, offset) + var_type = AngrInterface.stack_var_type_str(decompilation, stack_var) + self.interface.stack_variable_changed(StackVariable(offset, stack_var.name, var_type, stack_var.size, func.addr)) + return True + + # pylint: disable=unused-argument + def handle_func_arg_renamed(self, func, offset, old_name, new_name): + decompilation = self.interface.decompile_function(func).codegen + func_args = AngrInterface.func_args_as_declib_args(decompilation) + self.interface.function_header_changed( + FunctionHeader( + name=None, + addr=func.addr, + type_=None, + args={ + offset: FunctionArgument(offset=offset, name=new_name, type_=None, size=func_args[offset].size) + }, + ) + ) + + return True + + # pylint: disable=unused-argument + def handle_func_arg_retyped(self, func, offset, old_type, new_type): + decompilation = self.interface.decompile_function(func).codegen + func_args = AngrInterface.func_args_as_declib_args(decompilation) + self.interface.function_header_changed( + FunctionHeader( + name=None, + addr=func.addr, + type_=None, + args={ + offset: FunctionArgument(offset=offset, name=None, type_=new_type, size=func_args[offset].size) + }, + ) + ) + + return True + + # pylint: disable=unused-argument,no-self-use + def handle_global_var_renamed(self, address, old_name, new_name): + self.interface.global_variable_changed( + GlobalVariable(addr=address, name=new_name, type_=None) + ) + return True + + # pylint: disable=unused-argument,no-self-use + def handle_global_var_retyped(self, address, old_type, new_type): + self.interface.global_variable_changed( + GlobalVariable(addr=address, name=None, type_=new_type) + ) + return True + + # pylint: disable=unused-argument + def handle_function_renamed(self, func, old_name, new_name): + if func is None: + return False + + self.interface.function_header_changed(FunctionHeader(name=new_name, addr=func.addr)) + return True + + # pylint: disable=unused-argument,no-self-use + def handle_function_retyped(self, func, old_type, new_type): + if func is None: + return False + + self.interface.function_header_changed(FunctionHeader(name=None, addr=func.addr, type_=new_type)) + return True + + # pylint: disable=unused-argument + def handle_comment_changed(self, address, old_cmt, new_cmt, created: bool, decomp: bool): + # comments are only possible in functions in AM + func_addr = self.interface.get_closest_function(address) + if func_addr is None: + return False + + self.interface.comment_changed( + Comment(addr=address, comment=new_cmt, func_addr=func_addr, decompiled=True), deleted=not new_cmt + ) + return True + +class AngrWidgetWrapper(BaseView): + """ + The class for the window that shows changes/info to BinSync data. This includes things like + changes to functions or structs. + """ + + def __init__(self, workspace, default_docking_position, qt_cls, window_name: str, *args, **kwargs): + # hacky imports to avoid ui + from declib.ui.version import set_ui_version + set_ui_version("PySide6") + from declib.ui.qt_objects import QVBoxLayout + + super().__init__(window_name.replace(" ", "_"), workspace, default_docking_position) + self.base_caption = window_name + self.widget = qt_cls(*args, **kwargs) + + main_layout = QVBoxLayout() + main_layout.addWidget(self.widget) + self.setLayout(main_layout) + self.width_hint = 300 + + def closeEvent(self, event): + self.widget.close() + + +def attach_qt_widget(workspace: Workspace, qt_cls, window_name: str, default_docking_position=None, *args, **kwargs): + from PySide6QtAds import SideBarRight, CDockWidget, CDockManager + + wrapper = AngrWidgetWrapper(workspace, default_docking_position, qt_cls, window_name, *args, **kwargs) + if not wrapper.widget: + l.error("Failed to create widget %s", window_name) + return False + + workspace.add_view(wrapper) + dock = workspace.view_manager.view_to_dock[wrapper] + dock.setAutoHide(False, SideBarRight) + dock.closed.disconnect() + dock.setFeature(CDockWidget.DockWidgetDeleteOnClose, False) + # grab the dock manager by climbing up parents, probably a better way to directly grab it + dm = dock.parent().parent().parent() + assert (isinstance(dm, CDockManager)) + dm.setAutoHideConfigFlags(CDockManager.AutoHideHasCloseButton, False) + return True diff --git a/declib/decompilers/angr/interface.py b/declib/decompilers/angr/interface.py new file mode 100644 index 00000000..75a70562 --- /dev/null +++ b/declib/decompilers/angr/interface.py @@ -0,0 +1,949 @@ +import logging +import os +import re +from collections import defaultdict +from functools import lru_cache +from typing import Optional, Dict, List, Tuple +from pathlib import Path + +import angr +from angr.analyses.decompiler.structured_codegen import DummyStructuredCodeGenerator + +from declib.api.decompiler_interface import ( + DecompilerInterface, +) +from declib.artifacts import ( + Function, FunctionHeader, Comment, StackVariable, FunctionArgument, Artifact, Decompilation, Context, + Struct, StructMember +) +from .artifact_lifter import AngrArtifactLifter + +l = logging.getLogger(__name__) + +try: + from angrmanagement.ui.views import CodeView +except ImportError: + l.debug("angr-management module not found... likely running headless.") + +logging.getLogger("angr").setLevel(logging.ERROR) +logging.getLogger("cle").setLevel(logging.ERROR) + + +class AngrInterface(DecompilerInterface): + """ + The class used for all pushing/pulling and merging based actions with BinSync artifacts. + This class is responsible for handling callbacks that are done by changes from the local user + and responsible for running a thread to get new changes from other users. + """ + + def __init__(self, workspace=None, **kwargs): + self.workspace = workspace + self.main_instance = workspace.main_instance if workspace else self + self._ctx_menu_items = [] + self._am_logger = None + self._cfg = None + self._binary_arch = None + super().__init__(name="angr", artifact_lifter=AngrArtifactLifter(self), **kwargs) + + def _init_headless_components(self, *args, **kwargs): + super()._init_headless_components(*args, **kwargs) + self.project = angr.Project(str(self._binary_path), auto_load_libs=False) + # cross_references=True populates kb.xrefs so xrefs_to_addr (e.g. + # "who references this string constant?") works. + self._cfg = self.project.analyses.CFG( + show_progressbar=False, normalize=True, data_references=True, cross_references=True, + ) + self.project.analyses.CompleteCallingConventions(cfg=self._cfg, recover_variables=True, analyze_callsites=True) + + def _init_gui_components(self, *args, **kwargs): + super()._init_gui_components(*args, **kwargs) + if self.workspace is None: + raise ValueError("The workspace provided is None, which will result in a broken BinSync.") + + self._am_logger = logging.getLogger(f"angrmanagement.{self._plugin_name or 'generic_plugin'}") + self._am_logger.setLevel(logging.INFO) + + # + # Decompiler API + # + + @property + def binary_base_addr(self) -> int: + for seg in self.main_instance.project.loader.main_object.segments: + return seg.min_addr + + # fallback + return self.main_instance.project.loader.main_object.mapped_base + + @property + def binary_hash(self) -> str: + return self.main_instance.project.loader.main_object.md5.hex() + + @property + def binary_path(self) -> Optional[str]: + try: + return self.main_instance.project.loader.main_object.binary + # pylint: disable=broad-except + except Exception: + return None + + def fast_get_function(self, func_addr) -> Optional[Function]: + lowered_addr = self.art_lifter.lower_addr(func_addr) + try: + _func = self.main_instance.project.kb.functions[lowered_addr] + except KeyError: + #self.warning(f"Function at {hex(func_addr)} not found.") + return None + + func = Function(addr=_func.addr, size=_func.size, name=_func.name) + if not _func or not _func.prototype: + type_ = None + else: + type_ = _func.prototype.returnty.c_repr() if _func.prototype.returnty else None + func.header.type = type_ + return self.art_lifter.lift(func) + + def get_func_size(self, func_addr) -> int: + func_addr = self.art_lifter.lower_addr(func_addr) + try: + func = self.main_instance.project.kb.functions[func_addr] + return func.size + except KeyError: + return 0 + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + if not isinstance(artifact, Function): + l.warning("xrefs_to is only implemented for functions.") + return [] + if only_code: + l.warning("only_code is not supported in angr.") + + function: Function = self.art_lifter.lower(artifact) + program_cfg = self.main_instance.project.kb.cfgs.get_most_accurate() + if program_cfg is None: + return [] + + func_node = program_cfg.get_any_node(function.addr) + if func_node is None: + return [] + + xrefs = [] + seen_callers = set() + for node in program_cfg.graph.predecessors(func_node): + func_addr = node.function_address + if func_addr is None or func_addr == function.addr: + continue + if func_addr in seen_callers: + continue + seen_callers.add(func_addr) + xrefs.append(self.art_lifter.lift(Function(func_addr, 0))) + + return xrefs + + def xrefs_from(self, func_addr: int) -> List[Function]: + """angr callees: use the kb.callgraph successor set. + + ``kb.callgraph`` is a NetworkX digraph populated during CFG analysis; + its successors of a function address are the direct callees, which is + what we want. Unlike ``Function.transition_graph``, these are + deduplicated per target and come with kb function lookups. + """ + lowered = self.art_lifter.lower_addr(func_addr) + project = self.main_instance.project + callgraph = getattr(project.kb, "callgraph", None) + if callgraph is None or lowered not in callgraph: + return [] + + kb_functions = project.kb.functions + callees: List[Function] = [] + seen = set() + for succ_addr in callgraph.successors(lowered): + if succ_addr in seen: + continue + seen.add(succ_addr) + func_obj = kb_functions.get(succ_addr, None) + name = getattr(func_obj, "name", None) if func_obj is not None else None + header = FunctionHeader(name=name, addr=succ_addr) if name else None + callees.append(self.art_lifter.lift(Function(succ_addr, 0, header=header))) + return callees + + def xrefs_to_addr(self, addr: int, only_code: bool = False) -> List[Artifact]: + """angr data-xref lookup: look up kb.xrefs references to ``addr``. + + Falls back to the default (empty) if the xref manager isn't populated. + """ + lowered = self.art_lifter.lower_addr(addr) + project = self.main_instance.project + xref_manager = getattr(project.kb, "xrefs", None) + if xref_manager is None: + return [] + + try: + xref_set = xref_manager.get_xrefs_by_dst(lowered) + except Exception: + return [] + if not xref_set: + return [] + + program_cfg = project.kb.cfgs.get_most_accurate() + if program_cfg is None: + return [] + + results: List[Artifact] = [] + seen = set() + for xref in xref_set: + node = program_cfg.get_any_node(xref.ins_addr, anyaddr=True) + if node is None or node.function_address is None: + continue + func_addr = node.function_address + if func_addr in seen: + continue + seen.add(func_addr) + name = None + try: + name = project.kb.functions[func_addr].name + except Exception: + pass + header = FunctionHeader(name=name, addr=func_addr) if name else None + results.append(self.art_lifter.lift(Function(func_addr, 0, header=header))) + return results + + def list_strings(self, filter: Optional[str] = None) -> List[Tuple[int, str]]: + pattern = re.compile(filter) if filter else None + try: + cfg = self.main_instance.project.kb.cfgs.get_most_accurate() + except Exception: + cfg = None + results: List[Tuple[int, str]] = [] + seen = set() + if cfg is not None: + for addr, mem_data in cfg.memory_data.items(): + if mem_data.sort != "string" or not mem_data.content: + continue + try: + text = mem_data.content.decode("utf-8", errors="replace") + except Exception: + continue + lifted_addr = self.art_lifter.lift_addr(addr) + if lifted_addr in seen: + continue + seen.add(lifted_addr) + if pattern is None or pattern.search(text): + results.append((lifted_addr, text)) + results.sort(key=lambda item: item[0]) + return results + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + if size <= 0: + return b"" + lowered = self.art_lifter.lower_addr(addr) + loader_memory = self.main_instance.project.loader.memory + try: + data = loader_memory.load(lowered, size) + except (KeyError, ValueError): + # cle's Clemory raises when the address isn't backed by a segment. + return None + return bytes(data) + + def disassemble(self, addr: int, **kwargs) -> Optional[str]: + lowered = self.art_lifter.lower_addr(addr) + func = self.main_instance.project.kb.functions.get(lowered, None) + if func is None: + for _addr, _func in self.main_instance.project.kb.functions.items(): + if _addr <= lowered < (_addr + (_func.size or 0)): + func = _func + break + if func is None: + return None + + try: + base_addr = self.binary_base_addr + except Exception: + base_addr = 0 + hex_re = re.compile(r"0x([0-9a-fA-F]+)") + + def _rewrite_operands(op_str: str) -> str: + # Rewrite absolute addresses in operands to their lifted form so the + # output is consistent across decompilers (e.g. ghidra lifts addresses). + def _sub(match: "re.Match[str]") -> str: + try: + raw = int(match.group(1), 16) + except ValueError: + return match.group(0) + if base_addr and raw >= base_addr: + return f"0x{raw - base_addr:x}" + return match.group(0) + + return hex_re.sub(_sub, op_str) + + lines: List[str] = [] + try: + blocks = sorted(func.blocks, key=lambda b: b.addr) + except Exception: + blocks = list(func.blocks) + for block in blocks: + try: + for insn in block.capstone.insns: + lifted = self.art_lifter.lift_addr(insn.address) + op_str = _rewrite_operands(insn.op_str) + lines.append(f"0x{lifted:x}:\t{insn.mnemonic}\t{op_str}".rstrip()) + except Exception: + continue + return "\n".join(lines) if lines else None + + def _decompile(self, function: Function, map_lines=False, **kwargs) -> Optional[Decompilation]: + if function.dec_obj is None: + function.dec_obj = self.get_decompilation_object(function, do_lower=False) + + if function.dec_obj is None: + return None + + codegen = function.dec_obj.codegen + if codegen is None or not codegen.text: + return None + + decompilation = Decompilation(addr=function.addr, text=codegen.text, decompiler=self.name) + if map_lines: + if self.headless: + decompilation.line_map = self.line_map_from_decompilation(function.dec_obj) + else: + self.warning("Mapping lines is only supported in headless mode.") + decompilation.line_map = {} + + return decompilation + + def get_decompilation_object(self, function: Function, do_lower=True, **kwargs) -> Optional[object]: + func_addr = self.art_lifter.lower_addr(function.addr) if do_lower else function.addr + func = self.main_instance.project.kb.functions.get(func_addr, None) + if func is None: + return None + + try: + decomp = self.decompile_function(func) + except Exception as e: + l.warning("Failed to decompile %s because %s", func, e) + decomp = None + + return decomp + + def local_variable_names(self, func: Function) -> List[str]: + codegen = self.decompile_function( + self.main_instance.project.kb.functions[self.art_lifter.lower_addr(func.addr)] + ).codegen + if not codegen or not codegen.cfunc or not codegen.cfunc.variable_manager: + return [] + + return [v.name for v in codegen.cfunc.variable_manager._unified_variables] + + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + codegen = self.decompile_function( + self.main_instance.project.kb.functions[self.art_lifter.lower_addr(func.addr)] + ).codegen + if not codegen or not codegen.cfunc or not codegen.cfunc.variable_manager: + return False + + changed = False + for v in codegen.cfunc.variable_manager._unified_variables: + if v.name in name_map and v.name != name_map[v.name]: + v.name = name_map[v.name] + changed = True + + if not self.headless: + self.refresh_decompilation(func.addr) + return changed + + @property + def binary_arch(self) -> str | None: + if self._binary_arch is None: + if self.main_instance.project.arch: + self._binary_arch = self.main_instance.project.arch.name + + return self._binary_arch + + + # + # GUI API + # + + def _init_gui_plugin(self, *args, **kwargs): + from .compat import GenericDLAngrManagementPlugin + self.gui_plugin = GenericDLAngrManagementPlugin(self.workspace, interface=self) + self.workspace.plugins.register_active_plugin(self._plugin_name, self.gui_plugin) + return self.gui_plugin + + def gui_goto(self, func_addr): + self.workspace.jump_to(self.art_lifter.lower_addr(func_addr)) + + def gui_register_ctx_menu(self, name, action_string, callback_func, category=None, shortcut=None) -> bool: + if self.gui_plugin is None: + l.critical("Cannot register context menu item without a GUI plugin.") + return False + + self._ctx_menu_items.append((name, action_string, callback_func, category)) + self.gui_plugin.context_menu_items = self._ctx_menu_items + + if shortcut: + try: + self.gui_plugin.register_shortcut(name, shortcut, callback_func, deci=self) + except Exception as e: + l.warning("Failed to register angr shortcut %r for %s: %s", shortcut, name, e) + return True + + def gui_active_context(self) -> Optional[Context]: + curr_view = self.workspace.view_manager.current_tab + # current view is non-existent or does not support a "function" view type of context + if not curr_view or not hasattr(curr_view, "function"): + return None + + try: + func = curr_view.function + except NotImplementedError: + return None + + # TODO: support addr and screen_name for Context + if func is None or func.am_obj is None: + return None + + context = Context(addr=None, func_addr=func.addr) + return self.art_lifter.lift(context) + + def gui_attach_qt_window(self, qt_window: type["QWidgt"], title: str, target_window=None, position=None, *args, **kwargs) -> bool: + from .compat import attach_qt_widget + if self.workspace is None: + l.warning("Cannot attach a Qt window without a workspace.") + return False + + return attach_qt_widget(self.workspace, qt_window, title, *args, **kwargs) + + + # + # Artifact API + # + + def _set_function(self, func: Function, **kwargs) -> bool: + angr_func = self.main_instance.project.kb.functions[func.addr] + + # re-decompile a function if needed + decompilation = self.decompile_function(angr_func).codegen + changes = super()._set_function(func, decompilation=decompilation, **kwargs) + if not self.headless and changes: + # Use "retype_variable" event to trigger proper UI refresh including type reflow + self.refresh_decompilation(func.addr, event="retype_variable") + + return changes + + def _get_function(self, addr, **kwargs) -> Optional[Function]: + try: + _func = self.main_instance.project.kb.functions[addr] + except KeyError: + return None + + func = Function(_func.addr, _func.size) + if not _func or not _func.prototype: + type_ = None + else: + type_ = _func.prototype.returnty.c_repr() if _func.prototype.returnty else None + func.header = FunctionHeader( + _func.name, _func.addr, type_=type_ + ) + + try: + decompilation = self.decompile_function(_func).codegen + except Exception as e: + l.warning("Failed to decompile function %s: %s", hex(_func.addr), e) + decompilation = None + + if not decompilation: + return func + + func.header.args = self.func_args_as_declib_args(decompilation) + # overwrite type again since it can change with decompilation + functy = decompilation.cfunc.functy if decompilation.cfunc else None + if functy and functy.returnty: + func.header.type = decompilation.cfunc.functy.returnty.c_repr() + + stack_vars = { + angr_sv.offset: StackVariable( + angr_sv.offset, angr_sv.name, self.stack_var_type_str(decompilation, angr_sv), angr_sv.size, func.addr + ) + for angr_sv in self.stack_vars_in_dec(decompilation) + } + func.stack_vars = stack_vars + + return func + + def _functions(self) -> Dict[int, Function]: + funcs = {} + for addr, func in self.main_instance.project.kb.functions.items(): + funcs[addr] = Function(addr, func.size) + + # syscalls and simprocedures are not real funcs to sync + if func.is_syscall or func.is_simprocedure or not func.name: + continue + + funcs[addr].name = func.name + + return funcs + + def _set_function_header(self, fheader: FunctionHeader, decompilation=None, **kwargs) -> bool: + angr_func = self.main_instance.project.kb.functions[fheader.addr] + changes = False + if not fheader: + return changes + + if fheader.name and fheader.name != angr_func.name: + angr_func.name = fheader.name + decompilation.cfunc.name = fheader.name + decompilation.cfunc.demangled_name = fheader.name + changes = True + + if fheader.args: + for i, arg in fheader.args.items(): + if not arg: + continue + + if i >= len(decompilation.cfunc.arg_list): + break + + dec_arg = decompilation.cfunc.arg_list[i].variable + # TODO: set the types of the args + if arg.name and arg.name != dec_arg.name: + dec_arg.name = arg.name + changes = True + + return changes + + def _set_stack_variable(self, svar: StackVariable, decompilation=None, **kwargs) -> bool: + changed = False + if not svar or not decompilation: + return changed + + dec_svar = AngrInterface.find_stack_var_in_codegen(decompilation, svar.offset) + if not dec_svar: + return changed + + # Set the name if provided and different + if svar.name and svar.name != dec_svar.name: + dec_svar.name = svar.name + dec_svar.renamed = True + changed = True + + # Set the type if provided + if svar.type: + try: + from angr.sim_type import parse_type + types_store = self.main_instance.project.kb.types + arch = self.main_instance.project.arch + + # Parse the type string into a SimType + sim_type = parse_type(svar.type, predefined_types=types_store, arch=arch) + sim_type = sim_type.with_arch(arch) + + # Get the variable manager and set the type + variable_kb = decompilation._variable_kb if hasattr(decompilation, '_variable_kb') else self.main_instance.project.kb + variable_manager = variable_kb.variables[svar.addr] + variable_manager.set_variable_type(dec_svar, sim_type, all_unified=True, mark_manual=True) + changed = True + except Exception as e: + l.warning(f"Failed to set stack variable type for {svar.name}: {e}") + + return changed + + def _set_comment(self, comment: Comment, decompilation=None, **kwargs) -> bool: + changed = False + if not comment or not comment.comment: + return changed + + if comment.decompiled and comment.addr != comment.func_addr: + try: + pos = decompilation.map_addr_to_pos.get_nearest_pos(comment.addr) + corrected_addr = decompilation.map_pos_to_addr.get_node(pos).tags['ins_addr'] + # pylint: disable=broad-except + except Exception: + return changed + + dec_cmt = decompilation.stmt_comments.get(corrected_addr, None) + if dec_cmt != comment.comment: + decompilation.stmt_comments[corrected_addr] = comment.comment + changed |= True + else: + kb_cmt = self.main_instance.project.kb.comments.get(comment.addr, None) + if kb_cmt != comment.comment: + self.main_instance.project.kb.comments[comment.addr] = comment.comment + changed |= True + + func_addr = comment.func_addr or self.get_closest_function(comment.addr) + return changed & self.refresh_decompilation(func_addr) + + # structs + def _structs(self) -> Dict[str, Struct]: + """ + Returns a dict of declib.Struct that contain the name and size of each struct in the decompiler. + """ + from angr.sim_type import SimStruct, TypeRef + structs = {} + types_store = self.main_instance.project.kb.types + + for type_ref in types_store.iter_own(): + if not isinstance(type_ref, TypeRef): + continue + sim_type = type_ref.type + if isinstance(sim_type, SimStruct): + structs[type_ref.name] = Struct(type_ref.name, sim_type.size // 8 if sim_type.size else 0, {}) + + return structs + + def _get_struct(self, name) -> Optional[Struct]: + """ + Get a struct by name from the TypesStore. + """ + from angr.sim_type import SimStruct, TypeRef + types_store = self.main_instance.project.kb.types + + try: + type_ref = types_store[name] + except KeyError: + return None + + if not isinstance(type_ref, TypeRef): + return None + + sim_struct = type_ref.type + if not isinstance(sim_struct, SimStruct): + return None + + return self._angr_struct_to_declib(name, sim_struct) + + def _set_struct(self, struct: Struct, header=True, members=True, **kwargs) -> bool: + """ + Create or update a struct in the TypesStore. + """ + from angr.sim_type import SimStruct, TypeRef, parse_type + from collections import OrderedDict + + types_store = self.main_instance.project.kb.types + arch = self.main_instance.project.arch + + # Build the fields OrderedDict from DecLib struct members + fields = OrderedDict() + if members and struct.members: + sorted_members = sorted(struct.members.items(), key=lambda x: x[0]) + for offset, member in sorted_members: + # Parse the member type string into a SimType + try: + sim_type = parse_type(member.type, predefined_types=types_store, arch=arch) + except Exception: + # Fallback to a simple int type with the right size if parsing fails + from angr.sim_type import SimTypeInt + sim_type = SimTypeInt(signed=False).with_arch(arch) + + fields[member.name] = sim_type.with_arch(arch) + + # Create the SimStruct + sim_struct = SimStruct(fields, name=struct.name, pack=True) + sim_struct = sim_struct.with_arch(arch) + + # Wrap it in a TypeRef and store it + type_ref = TypeRef(struct.name, sim_struct) + types_store[struct.name] = type_ref + + return True + + def _del_struct(self, name) -> bool: + """ + Delete a struct from the TypesStore. + """ + types_store = self.main_instance.project.kb.types + + if name in types_store.data: + del types_store.data[name] + return True + + return False + + @staticmethod + def _angr_struct_to_declib(name: str, sim_struct: "angr.sim_type.SimStruct") -> Struct: + """ + Convert an angr SimStruct to a DecLib Struct. + """ + members = {} + if sim_struct._arch is not None: + offsets = sim_struct.offsets + for field_name, sim_type in sim_struct.fields.items(): + offset = offsets.get(field_name, 0) + type_str = sim_type.c_repr() if sim_type else None + size = sim_type.size // 8 if sim_type and sim_type.size else 0 + members[offset] = StructMember(field_name, offset, type_str, size) + + size = sim_struct.size // 8 if sim_struct.size else 0 + return Struct(name, size, members) + + # + # Utils + # + + def info(self, msg: str, **kwargs): + if self._am_logger is not None: + self._am_logger.info(msg) + + def debug(self, msg: str, **kwargs): + if self._am_logger is not None: + self._am_logger.debug(msg) + + def warning(self, msg: str, **kwargs): + if self._am_logger is not None: + self._am_logger.warning(msg) + + def error(self, msg: str, **kwargs): + if self._am_logger is not None: + self._am_logger.error(msg) + + def print(self, msg: str, **kwargs): + if self.headless: + print(msg) + else: + self.info(msg) + + # + # angr-management specific helpers + # + + # TODO: add LRU back one day + #@lru_cache(maxsize=1024) + def addr_starts_instruction(self, addr) -> bool: + """ + Returns True when the provided address maps to a valid instruction address in the binary and that address + is at the start of the instruction (not in the middle). Useful for checking if an instruction is + incorrectly computed due to ARM THUMB. + """ + cfg = self.main_instance.project.kb.cfgs.get_most_accurate() + if cfg is None: + l.warning("Unable load CFG from angr. Other operations may be wrong.") + return False + + node = cfg.get_any_node(addr, anyaddr=True) + if node is None: + return False + + return addr in node.instruction_addrs + + def refresh_decompilation(self, func_addr, event=None): + if self.headless: + return False + + self.workspace.jump_to(func_addr) + view = self.workspace._get_or_create_view("pseudocode", CodeView) + if event: + view.codegen.am_event(event=event) + else: + view.codegen.am_event() + view.focus() + return True + + def _headless_decompile(self, func): + if not func.normalized: + func.normalize() + + return self.main_instance.project.analyses.Decompiler(func, cfg=self._cfg, flavor='pseudocode', preset="full") + + def _angr_management_decompile(self, func): + # recover direct pseudocode + self.main_instance.project.analyses.Decompiler(func, flavor='pseudocode') + + # attempt to get source code if its available + source_root = None + if self.main_instance.original_binary_path: + source_root = os.path.dirname(self.main_instance.original_binary_path) + self.main_instance.project.analyses.ImportSourceCode(func, flavor='source', source_root=source_root) + + def decompile_function(self, func, refresh_gui=False): + # check for known decompilation + available = self.main_instance.project.kb.decompilations.available_flavors(func.addr) + should_decompile = False + if self.headless or 'pseudocode' not in available: + should_decompile = True + else: + cached = self.main_instance.project.kb.decompilations[(func.addr, 'pseudocode')] + if isinstance(cached, DummyStructuredCodeGenerator): + should_decompile = True + + decomp = None + if should_decompile: + if not self.headless: + self._angr_management_decompile(func) + else: + decomp = self._headless_decompile(func) + + # grab newly cached pseudocode + if not self.headless: + decomp = self.main_instance.project.kb.decompilations[(func.addr, 'pseudocode')] + + # refresh the UI after decompiling + if refresh_gui and not self.headless: + self.workspace.reload() + + # re-decompile current view to cause a refresh + current_tab = self.workspace.view_manager.current_tab + if isinstance(current_tab, CodeView) and current_tab.function == func: + self.workspace.decompile_current_function() + + return decomp + + @staticmethod + def find_stack_var_in_codegen(decompilation, stack_offset: int) -> Optional[angr.sim_variable.SimStackVariable]: + for var in decompilation.cfunc.variable_manager._unified_variables: + if hasattr(var, "offset") and var.offset == stack_offset: + return var + + return None + + @staticmethod + def stack_var_type_str(decompilation, stack_var: angr.sim_variable.SimStackVariable): + try: + var_type = decompilation.cfunc.variable_manager.get_variable_type(stack_var) + # pylint: disable=broad-except + except Exception: + return None + + return var_type.c_repr() if var_type is not None else None + + @staticmethod + def stack_vars_in_dec(decompilation): + for var in decompilation.cfunc.variable_manager._unified_variables: + if hasattr(var, "offset"): + yield var + + @staticmethod + def func_args_as_declib_args(decompilation) -> Dict[int, FunctionArgument]: + args = {} + if not decompilation.cfunc.arg_list: + return args + + for idx, arg in enumerate(decompilation.cfunc.arg_list): + type_ = arg.variable_type.c_repr() if arg.variable_type is not None else None + args[idx] = FunctionArgument( + idx, arg.variable.name, type_, arg.variable.size + ) + + return args + + @staticmethod + def func_insn_addrs(func: angr.knowledge_plugins.Function): + insn_addrs = set() + for block in func.blocks: + insn_addrs.update(block.instruction_addrs) + + return insn_addrs + + def get_closest_function(self, addr): + try: + func_addr = self.workspace.main_instance.project.kb.cfgs.get_most_accurate()\ + .get_any_node(addr, anyaddr=True)\ + .function_address + except AttributeError: + func_addr = None + + return func_addr + + @staticmethod + def line_map_from_decompilation(dec): + import ailment + from angr.analyses.decompiler.structured_codegen.c import CStructuredCodeWalker, CFunctionCall, CIfElse, CIfBreak + + if dec is None or dec.codegen is None: + return None + + codegen = dec.codegen + base_addr = dec.project.loader.main_object.image_base_delta + if hasattr(dec, "unoptimized_ail_graph"): + nodes = dec.unoptimized_ail_graph.nodes + else: + l.critical(f"You are likely using an older version of angr that has no unoptimized_ail_graph." + f" Using clinic_graph instead, results will be less accurate...") + nodes = dec.clinic.cc_graph.nodes + + # get the mapping of the original AIL graph + mapping = defaultdict(set) + ail_node_addr_map = { + node.addr: node for node in nodes + } + for addr, ail_block in ail_node_addr_map.items(): + # get instructions of this block + try: + vex_block = dec.project.factory.block(addr) + except Exception: + continue + + ail_block_stmts = [stmt for stmt in ail_block.statements if not isinstance(stmt, ailment.statement.Label)] + if not ail_block_stmts: + continue + + next_ail_stmt_idx = 0 + for ins_addr in vex_block.instruction_addrs: + next_ail_stmt_addr = ail_block_stmts[next_ail_stmt_idx].ins_addr + mapping[next_ail_stmt_addr].add(ins_addr) + if ins_addr == next_ail_stmt_addr: + next_ail_stmt_idx += 1 + if next_ail_stmt_idx >= len(ail_block_stmts): + break + + # node to addr map + ailaddr_to_addr = defaultdict(set) + for k, v in mapping.items(): + for v_ in v: + ailaddr_to_addr[k - base_addr].add(v_ - base_addr) + + codegen.show_externs = False + codegen.regenerate_text() + + decompilation = codegen.text + if not decompilation: + return + + try: + first_code_pos = codegen.map_pos_to_addr.items()[0][0] + except Exception: + return + + # map the position start to an address + pos_addr_map = defaultdict(set) + for start, pos_map in codegen.map_pos_to_addr.items(): + obj = pos_map.obj + if not hasattr(obj, "tags"): + continue + + # leads to mapping at the beginning of loops, so skip. + # see kill.o binary for send_signals + if isinstance(obj, CIfElse): + continue + + ins_addr = obj.tags.get("ins_addr", None) + if ins_addr: + pos_addr_map[start].add(ins_addr - base_addr) + + # find every line + line_end_pos = [i for i, x in enumerate(decompilation) if x == "\n"] + line_to_addr = defaultdict(set) + last_pos = len(decompilation) - 1 + line_to_addr[1].add(codegen.cfunc.addr - base_addr) + for i, pos in enumerate(line_end_pos[:-1]): + if pos == last_pos: + break + + curr_end = line_end_pos[i+1] - 1 + # check if this is the variable decs and header + if curr_end < first_code_pos: + line_to_addr[i+2].add(codegen.cfunc.addr - base_addr) + continue + + # not header, real code + for p_idx in range(pos+1, curr_end+1): + if p_idx in pos_addr_map: + # line_to_addr[str(i+1)].update(pos_addr_map[p_idx]) + for ail_ins_addr in pos_addr_map[p_idx]: + if ail_ins_addr in ailaddr_to_addr: + line_to_addr[i+2].update(ailaddr_to_addr[ail_ins_addr]) + else: + line_to_addr[i+2].add(ail_ins_addr) + + return line_to_addr diff --git a/declib/decompilers/binja/__init__.py b/declib/decompilers/binja/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/declib/decompilers/binja/artifact_lifter.py b/declib/decompilers/binja/artifact_lifter.py new file mode 100644 index 00000000..110416c2 --- /dev/null +++ b/declib/decompilers/binja/artifact_lifter.py @@ -0,0 +1,32 @@ +from declib.api import ArtifactLifter + + +class BinjaArtifactLifter(ArtifactLifter): + lift_map = { + "int64_t": "long long", + "uint64_t": "unsigned long", + "int32_t": "int", + "uint32_t": "unsigned int", + "int16_t": "short", + "uint16_t": "unsigned short", + "int8_t": "char", + "uint8_t": "unsigned char", + } + + def __init__(self, deci): + super(BinjaArtifactLifter, self).__init__(deci) + + def lift_type(self, type_str: str) -> str: + for bn_t, bs_t in self.lift_map.items(): + type_str = type_str.replace(bn_t, bs_t) + + return type_str + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + return offset + + def lower_type(self, type_str: str) -> str: + return type_str + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + return offset diff --git a/declib/decompilers/binja/hooks.py b/declib/decompilers/binja/hooks.py new file mode 100644 index 00000000..12c3c4cb --- /dev/null +++ b/declib/decompilers/binja/hooks.py @@ -0,0 +1,201 @@ +from collections import defaultdict +from typing import Dict +import logging + +from .interface import BinjaInterface, BN_AVAILABLE, VALID_FUNC_SYM_TYPES +if BN_AVAILABLE: + import binaryninja + from binaryninja.types import StructureType, EnumerationType + from binaryninja import SymbolType + from binaryninja.binaryview import BinaryDataNotification + +from declib.artifacts import ( + FunctionHeader, FunctionArgument, GlobalVariable, StackVariable, Comment +) + +l = logging.getLogger(__name__) + + +# +# Hooks (callbacks) +# + +class DataMonitor(BinaryDataNotification): + def __init__(self, view, interface): + super().__init__() + self._bv = view + self._interface: BinjaInterface = interface + self._changing_func_addr = None + self._changing_func_pre_change = None + self._seen_comments = defaultdict(dict) + + def function_updated(self, view, func_): + # Updates that occur without a service request are requests for comment changes + if self._changing_func_pre_change is None: + # + # comments + # + + func_addr = func_.start + current_comments = dict(func_.comments) + prev_comments = self._seen_comments[func_addr] + # Changes have only occurred when the comments we see before the change request are different + # from the comments we see now (after the change request) + if current_comments != prev_comments: + + # Find all the comments that may have been: + # 1. Updated in-place + # 2. Deteted + for addr, prev_comment in prev_comments.items(): + curr_comment = current_comments.get(addr, None) + # no change for this comment + if curr_comment == prev_comment: + continue + + self._interface.comment_changed( + Comment( + addr, + str(curr_comment) if curr_comment else "", + decompiled=True, + func_addr=func_addr + ), + deleted=curr_comment is None, + ) + + # Find any comment which was newly added in this change + for addr, curr_comment in current_comments.items(): + if addr in prev_comments: + continue + + if curr_comment: + self._interface.comment_changed( + Comment(addr, str(curr_comment), decompiled=True, func_addr=func_addr) + ) + + self._seen_comments[func_addr] = current_comments + + # service requested function only + if self._changing_func_pre_change is not None and self._changing_func_addr == func_.start: + l.debug("Update on %s being processed...", hex(self._changing_func_addr)) + self._changing_func_addr = None + + # convert to declib Function type for diffing + bn_func = view.get_function_at(func_.start) + bs_func = BinjaInterface.bn_func_to_bs(bn_func) + current_comments = dict(bn_func.comments) + + # + # header + # + + # check if the headers differ + # NOTE: function name done inside symbol update hook + if self._changing_func_pre_change.header.diff(bs_func.header): + old_header: FunctionHeader = self._changing_func_pre_change.header + new_header: FunctionHeader = bs_func.header + + old_args = old_header.args or {} + for off, old_arg in old_args.items(): + new_arg = new_header.args.get(off, None) + if new_arg is None: + # TODO: support deleting args + continue + + if old_arg == new_arg: + continue + + diff_arg = FunctionArgument(off, None, None, None) + if old_arg.name != new_arg.name: + diff_arg.name = str(new_arg.name) + + if old_arg.type != new_arg.type: + diff_arg.type = str(new_arg.type) + + if old_arg.size != new_arg.size: + diff_arg.size = int(new_arg.size) + + self._interface.function_header_changed( + FunctionHeader(None, old_header.addr, args={off: diff_arg}) + ) + + # new func args added to header + for off, new_arg in bs_func.args.items(): + if off in old_args: + continue + + self._interface.function_header_changed( + FunctionHeader(None, old_header.addr, args={ + off: FunctionArgument(off, str(new_arg.name), str(new_arg.type), int(new_arg.size)) + }) + ) + + # + # stack vars + # + + header_args_names = set([arg.name for arg in bs_func.header.args.values()]) + if self._changing_func_pre_change.stack_vars != bs_func.stack_vars: + old_svs: Dict[int, StackVariable] = self._changing_func_pre_change.stack_vars + new_svs: Dict[int, StackVariable] = bs_func.stack_vars + + for off, old_sv in old_svs.items(): + new_sv = new_svs.get(off, None) + if new_sv is None or new_sv.name in header_args_names: + continue + + if old_sv == new_sv: + continue + + diff_sv = StackVariable(off, None, None, old_sv.size, bs_func.addr) + if old_sv.name != new_sv.name: + diff_sv.name = str(new_sv.name) + + if old_sv.type != new_sv.type: + diff_sv.type = str(new_sv.type) + + self._interface.stack_variable_changed(diff_sv) + + for off, new_sv in new_svs.items(): + if off in old_svs or new_sv.name in header_args_names: + continue + + self._interface.stack_variable_changed( + StackVariable(off, str(new_sv.name), str(new_sv.type), new_sv.size, bs_func.addr) + ) + + self._changing_func_pre_change = None + + def function_update_requested(self, view, func): + if self._changing_func_addr is None: + l.debug("Update on %s requested...", func) + self._changing_func_addr = func.start + self._changing_func_pre_change = BinjaInterface.bn_func_to_bs(func) + + def symbol_updated(self, view, sym): + l.debug("Symbol update Requested on %s...", sym) + if sym.type in VALID_FUNC_SYM_TYPES: + l.debug(" -> Function Symbol") + func = view.get_function_at(sym.address) + bs_func = BinjaInterface.bn_func_to_bs(func) + self._interface.function_header_changed( + FunctionHeader(bs_func.name, bs_func.addr) + ) + elif sym.type == SymbolType.DataSymbol: + l.debug(" -> Data Symbol") + var: binaryninja.DataVariable = view.get_data_var_at(sym.address) + self._interface.global_variable_changed( + GlobalVariable(int(sym.address), str(var.name), type_=str(var.type), size=int(var.type.width)) + ) + else: + print(f" -> Other Symbol: {sym.type}") + pass + + def type_defined(self, view, name, type_): + l.debug("Type Defined: %s %s", name, type_) + name = str(name) + if isinstance(type_, StructureType): + bs_struct = BinjaInterface.bn_struct_to_bs(name, type_) + self._interface.struct_changed(bs_struct) + elif isinstance(type_, EnumerationType): + bs_enum = BinjaInterface.bn_enum_to_bs(name, type_) + self._interface.enum_changed(bs_enum) diff --git a/declib/decompilers/binja/interface.py b/declib/decompilers/binja/interface.py new file mode 100644 index 00000000..75db4d08 --- /dev/null +++ b/declib/decompilers/binja/interface.py @@ -0,0 +1,795 @@ +import threading +import functools +from collections import defaultdict +from typing import Dict, Optional, Any, List +import hashlib +import logging + +BN_AVAILABLE = True +try: + import binaryninja +except ImportError: + BN_AVAILABLE = False + +BN_UI_AVAILABLE = True +try: + import binaryninjaui +except Exception: + BN_UI_AVAILABLE = False + +if BN_AVAILABLE: + from binaryninja import SymbolType, PluginCommand, lineardisassembly + from binaryninja.function import DisassemblySettings + from binaryninja.enums import DisassemblyOption, LinearDisassemblyLineType, InstructionTextTokenType + from binaryninja.enums import VariableSourceType + from binaryninja.types import StructureType, EnumerationType +if BN_UI_AVAILABLE: + from binaryninjaui import UIContext + + +import declib +from declib.api.decompiler_interface import DecompilerInterface +from declib.artifacts import ( + Function, FunctionHeader, StackVariable, + Comment, GlobalVariable, Patch, StructMember, FunctionArgument, + Enum, Struct, Artifact, Decompilation, Context, Typedef +) + +from .artifact_lifter import BinjaArtifactLifter + +l = logging.getLogger(__name__) + +# +# Helpers +# + +VALID_FUNC_SYM_TYPES = {SymbolType.FunctionSymbol, SymbolType.LibraryFunctionSymbol} + +def background_and_wait(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + output = [None] + + def thunk(): + output[0] = func(*args, **kwargs) + return 1 + + thread = threading.Thread(target=thunk) + thread.start() + thread.join() + + return output[0] + return wrapper + + +class BinjaInterface(DecompilerInterface): + def __init__(self, bv=None, **kwargs): + self._bv: "binaryninja.BinaryView" = bv + self._data_monitor = None + super(BinjaInterface, self).__init__(name="binja", artifact_lifter=BinjaArtifactLifter(self), **kwargs) + + @property + def bv(self): + if self._bv is None: + l.warning("The BinaryView is not initialized. You may need to pass 'bv=' to the constructor call or discover call.") + + return self._bv + + @bv.setter + def bv(self, bv: "binaryninja.BinaryView"): + if not isinstance(bv, binaryninja.BinaryView): + raise TypeError("The bv must be a BinaryView instance.") + + self._bv = bv + + def _init_headless_components(self, *args, **kwargs): + super()._init_headless_components(*args, **kwargs) + if not BN_AVAILABLE: + raise ImportError("Unable to import binaryninja module. Are you sure you have it installed with an enterprise license?") + + self.bv = binaryninja.load(str(self._binary_path)) + + def _init_gui_components(self, *args, **kwargs): + if binaryninja.core_ui_enabled(): + super()._init_gui_components(*args, **kwargs) + return True + else: + return False + + def _init_gui_plugin(self, *args, **kwargs): + return self + + def __del__(self): + if self.headless and BN_AVAILABLE: + self.bv.file.close() + + # + # GUI + # + + def gui_active_context(self) -> Optional[Context]: + all_contexts = UIContext.allContexts() + if not all_contexts: + return None + + ctx = all_contexts[0] + handler = ctx.contentActionHandler() + if handler is None: + return None + + actionContext = handler.actionContext() + if actionContext is None: + return None + + func_addr = actionContext.function.start if actionContext.function is not None else None + addr = actionContext.address if actionContext.address is not None else None + # TODO: support screen_name + context = Context(addr=addr, func_addr=func_addr) + return self.art_lifter.lift(context) + + def gui_goto(self, func_addr) -> None: + func_addr = self.art_lifter.lower_addr(func_addr) + self.bv.offset = func_addr + + def gui_register_ctx_menu(self, name, action_string, callback_func, category=None, shortcut=None) -> bool: + # TODO: this needs to have a wrapper function that passes the bv to the current deci + # correct name, category, and action_string for Binja + action_string = action_string.replace("/", "\\") + category = category.replace("/", "\\") if category else "" + + PluginCommand.register_for_address( + f"{category}\\{action_string}", + action_string, + callback_func, + is_valid=self.is_bn_func + ) + + if shortcut and BN_UI_AVAILABLE: + try: + from binaryninjaui import UIAction, UIActionHandler + action_name = f"{category}\\{action_string}" if category else action_string + UIAction.registerAction(action_name, shortcut) + # UIAction expects a callable taking a UIActionContext + UIActionHandler.globalActions().bindAction( + action_name, UIAction(lambda ctx: callback_func(None)) + ) + except Exception as e: + l.warning(f"Failed to register Binja shortcut {shortcut!r} for {name}: {e}") + + return True + + def gui_ask_for_string(self, question, title="Plugin Question", default="") -> str: + resp = binaryninja.get_text_line_input(question, title) + return resp.decode() if resp else "" + + def gui_ask_for_choice(self, question: str, choices: list, title="Plugin Question") -> str: + choice_idx = binaryninja.get_choice_input(question, title, choices) + return choices[choice_idx] if choice_idx is not None else "" + + # + # Public API + # + + @property + def binary_base_addr(self) -> int: + return self._get_first_segment_base() + + @property + def binary_hash(self) -> str: + hash_ = "" + try: + hash_ = hashlib.md5(self.bv.file.raw[:]).hexdigest() + except Exception: + pass + + return hash_ + + @property + def binary_path(self) -> Optional[str]: + try: + return self.bv.file.original_filename + except Exception: + return None + + def fast_get_function(self, func_addr) -> Optional[Function]: + func_addr = self.art_lifter.lower_addr(func_addr) + func = self.bv.get_function_at(func_addr) + if not func: + return None + + return self.art_lifter.lift(self.bn_func_to_bs(func)) + + def get_func_size(self, func_addr) -> int: + func_addr = self.art_lifter.lower_addr(func_addr) + func = self.bv.get_function_at(func_addr) + if not func: + return 0 + + return func.highest_address - func.start + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + if not isinstance(artifact, Function): + l.warning("xrefs_to is only implemented for functions.") + return [] + + function: Function = self.art_lifter.lower(artifact) + if not function: + return [] + + bn_xrefs = list(self.bv.get_code_refs(function.addr)) + if not only_code: + bn_xrefs.extend(self.bv.get_data_refs(function.addr)) + + xrefs = [] + for bn_xref in bn_xrefs: + if bn_xref.function is None: + continue + + xrefs.append(Function(bn_xref.function.start, 0)) + + return xrefs + + def get_func_containing(self, addr: int) -> Optional[Function]: + addr = self.art_lifter.lower_addr(addr) + funcs = self.bv.get_functions_containing(addr) + if not funcs: + return None + + if len(funcs) > 1: + l.warning("More than one function contains the the address %s", addr) + + bn_func = funcs[0] + return self._get_function(bn_func.start) + + def _decompile(self, function: Function, map_lines=False, **kwargs) -> Optional[Decompilation]: + bv = self.bv + if bv is None: + return + + bn_func = self.addr_to_bn_func(bv, function.addr) + if bn_func is None: + return None + + settings = DisassemblySettings() + settings.set_option(DisassemblyOption.ShowVariableTypesWhenAssigned) + settings.set_option(DisassemblyOption.GroupLinearDisassemblyFunctions) + settings.set_option(DisassemblyOption.WaitForIL) + + decomp_text = "" + obj = lineardisassembly.LinearViewObject.single_function_language_representation(bn_func, settings) + cursor = obj.cursor + line_map = defaultdict(set) + while True: + for ln, line in enumerate(cursor.lines): + if line.type in [ + LinearDisassemblyLineType.FunctionHeaderStartLineType, + LinearDisassemblyLineType.FunctionHeaderEndLineType, + LinearDisassemblyLineType.AnalysisWarningLineType, + ]: + continue + + for i in line.contents.tokens: + if i.type == InstructionTextTokenType.TagToken: + continue + + decomp_text += str(i) + decomp_text += "\n" + if line.contents and line.contents.address is not None: + line_map[ln].add(int(line.contents.address)) + + if not cursor.next(): + break + + decompilation = Decompilation( + addr=function.addr, + text=decomp_text, + decompiler=self.name + ) + if map_lines: + # TODO: make this more accurate! + decompilation.line_map = dict(line_map) + + return decompilation + + def local_variable_names(self, func: Function) -> List[str]: + bn_func = self.addr_to_bn_func(self.bv, self.art_lifter.lower_addr(func.addr)) + if bn_func is None: + return [] + + return [str(var.name) for var in bn_func.vars] + + @background_and_wait + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + bn_func = self.addr_to_bn_func(self.bv, self.art_lifter.lower_addr(func.addr)) + if bn_func is None: + return False + + lvars = { + lvar.name: lvar for lvar in bn_func.vars if lvar.name + } + update = False + for name, lvar in lvars.items(): + new_name = name_map.get(name, None) + if new_name is None: + continue + + lvar.name = new_name + update |= True + + if update: + bn_func.reanalyze() + + return update + + def get_decompilation_object(self, function: Function, **kwargs) -> Optional[object]: + """ + Binary Ninja has no internal object that needs to be refreshed. + """ + return None + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + if size <= 0: + return b"" + lowered = self.art_lifter.lower_addr(addr) + try: + data = self.bv.read(lowered, size) + except Exception: + return None + if data is None: + return None + return bytes(data) + + def start_artifact_watchers(self): + if not self.artifact_watchers_started: + from .hooks import DataMonitor + if self.bv is None: + raise RuntimeError("Cannot start artifact watchers without a BinaryView.") + + self._data_monitor = DataMonitor(self.bv, self) + self.bv.register_notification(self._data_monitor) + super().start_artifact_watchers() + + def stop_artifact_watchers(self): + if self.artifact_watchers_started: + self.bv.unregister_notification(self._data_monitor) + self._data_monitor = None + super().stop_artifact_watchers() + + # + # Artifact API + # + + # functions + def _set_function(self, func: Function, **kwargs) -> bool: + bn_func = self.bv.get_function_at(func.addr) + if bn_func is None: + return False + + return super()._set_function(func, bn_func=bn_func, **kwargs) + + def _get_function(self, addr, **kwargs) -> Optional[Function]: + bn_func = self.bv.get_function_at(addr) + if bn_func is None: + return None + + return self.bn_func_to_bs(bn_func) + + def _functions(self) -> Dict[int, Function]: + funcs = {} + for bn_func in self.bv.functions: + if not bn_func.symbol.type in VALID_FUNC_SYM_TYPES: + continue + + funcs[bn_func.start] = Function(bn_func.start, bn_func.total_bytes) + funcs[bn_func.start].name = bn_func.name + + return funcs + + # function header + def _set_function_header(self, fheader: FunctionHeader, bn_func=None, **kwargs) -> bool: + updates = False + if not fheader: + return updates + + # func name + if fheader.name and fheader.name != bn_func.name: + bn_func.name = fheader.name + updates |= True + + # ret type + if fheader.type and \ + fheader.type != bn_func.return_type.get_string_before_name(): + + try: + new_type, _ = self.bv.parse_type_string(fheader.type) + except Exception: + new_type = None + + if new_type is not None: + bn_func.return_type = new_type + updates |= True + + # parameters + if not fheader.args: + return updates + + for i, bn_var in enumerate(bn_func.parameter_vars): + bs_var = fheader.args.get(i, None) + if bs_var is None: + continue + + # type + if bs_var.type and bs_var.type != self.art_lifter.lift_type(str(bn_var.type)): + bn_var.type = bs_var.type + updates |= True + # refresh + bn_var = bn_func.parameter_vars[i] + + # name + if bs_var.name and bs_var.name != str(bn_var.name): + bn_var.name = bs_var.name + updates |= True + + return updates + + def _valid_var_for_bn_set(self, bs_var: StackVariable): + # a stopgap for issue reported in: + # https://github.com/binsync/declib/issues/128 + # + # the real fix is likely on the binja side. + return bs_var.offset is not None and bs_var.name is not None + + # stack vars + def _set_stack_variable(self, svar: StackVariable, bn_func=None, **kwargs) -> bool: + updates = False + current_bn_vars: Dict[int, Any] = { + v.storage: v for v in bn_func.stack_layout + if v.source_type == VariableSourceType.StackVariableSourceType and v not in bn_func.parameter_vars + } + + bn_offset = svar.offset + if bn_offset in current_bn_vars: + # name + if svar.name and svar.name != str(current_bn_vars[bn_offset].name): + current_bn_vars[bn_offset].name = svar.name + updates |= True + + # type + if svar.type: + try: + bs_svar_type, _ = self.bv.parse_type_string(svar.type) + except Exception: + bs_svar_type = None + + if bs_svar_type is not None: + if self.art_lifter.lift_type(str(current_bn_vars[bn_offset].type)) != bs_svar_type: + current_bn_vars[bn_offset].type = bs_svar_type + + # this can cause a binja segfault, so we need to check if the var is valid before doing + # normal python try/except + if self._valid_var_for_bn_set(svar): + try: + bn_func.create_user_stack_var(bn_offset, bs_svar_type, svar.name) + bn_func.create_auto_stack_var(bn_offset, bs_svar_type, svar.name) + except Exception as e: + l.warning("BinSync could not sync stack variable at offset %s: %s", bn_offset, e) + + updates |= True + + return updates + + # global variables + def _set_global_variable(self, gvar: GlobalVariable, **kwargs) -> bool: + bn_gvar = self.bv.get_data_var_at(gvar.addr) + global_type = self.bv.parse_type_string(gvar.type) + changed = False + + if bn_gvar is None: + bn_gvar = self.bv.define_user_data_var(gvar.addr, global_type, gvar.name) + changed = True + + if bn_gvar: + self.bv.define_user_data_var(gvar.addr, global_type, gvar.name) + changed = True + + return changed + + def _get_global_var(self, addr) -> Optional[GlobalVariable]: + bn_gvar = self.bv.get_data_var_at(addr) + if bn_gvar is None: + return None + + return GlobalVariable( + addr, + self.bv.get_symbol_at(addr) or f"data_{addr:x}", + type_=str(bn_gvar.type) if bn_gvar.type is not None else None, + size=bn_gvar.type.width + ) + + def _global_vars(self, **kwargs) -> Dict[int, GlobalVariable]: + return { + addr: GlobalVariable(addr, var.name or f"data_{addr:x}") + for addr, var in self.bv.data_vars.items() + } + + # structs + def _set_struct(self, struct: Struct, header=True, members=True, **kwargs) -> bool: + if header: + self.bv.define_user_type(struct.name, binaryninja.Type.structure(packed=True)) + + if members: + # this scope assumes that the type is now defined... if it's not we will error + with binaryninja.Type.builder(self.bv, struct.name) as s: + s.width = struct.size + members = list() + for offset in sorted(struct.members.keys()): + bs_memb = struct.members[offset] + try: + bn_type = self.bv.parse_type_string(bs_memb.type)[0] if bs_memb.type else None + except Exception: + bn_type = None + finally: + if bn_type is None: + bn_type = binaryninja.Type.int(bs_memb.size) + + members.append((bn_type, bs_memb.name)) + s.members = members + + return True + + def _get_struct(self, name) -> Optional[Struct]: + bn_struct = self.bv.types.get(name, None) + if bn_struct is None or not isinstance(bn_struct, StructureType): + return None + + return self.bn_struct_to_bs(name, bn_struct) + + def _del_struct(self, name) -> bool: + return self.bv.undefine_user_type(name) + + def _structs(self) -> Dict[str, Struct]: + return { + name: Struct(''.join(name.name), t.width, {}) for name, t in self.bv.types.items() + if isinstance(t, StructureType) + } + + # enums + def _set_enum(self, enum: Enum, **kwargs) -> bool: + bn_members = list(enum.members.items()) + new_type = binaryninja.TypeBuilder.enumeration(self.bv.arch, bn_members) + self.bv.define_user_type(enum.name, new_type) + return True + + def _get_enum(self, name) -> Optional[Enum]: + bn_enum = self.bv.types.get(name, None) + if bn_enum is None: + return None + + if isinstance(bn_enum, EnumerationType): + return self.bn_enum_to_bs(name, bn_enum) + + return None + + def _enums(self) -> Dict[str, Enum]: + return { + name: self.bn_enum_to_bs(''.join(name.name), t) for name, t in self.bv.types.items() + if isinstance(t, EnumerationType) + } + + # typedef + def _set_typedef(self, typedef: Typedef, **kwargs) -> bool: + base_type = self.bv.parse_type_string(typedef.type)[0] + if base_type is None: + raise ValueError(f"Could not parse the type {typedef.type}") + + # handle primitive types + try: + base_type_name = str(base_type.name) + except NotImplementedError: + base_type_name = str(base_type) + + base_type_ref = binaryninja.TypeBuilder.named_type_reference( + binaryninja.NamedTypeReferenceClass.TypedefNamedTypeClass, base_type_name, base_type_name, + 0, base_type.width + ) + self.bv.define_user_type(typedef.name, base_type_ref) + return True + + def _get_typedef(self, name) -> Optional[Typedef]: + bn_typedef = self.bv.types.get(name, None) + if bn_typedef is None: + return None + + if isinstance(bn_typedef, binaryninja.NamedTypeReferenceType): + return self.bn_typedef_to_bs(name, bn_typedef) + + return None + + def _typedefs(self) -> Dict[str, Typedef]: + return { + name: self.bn_typedef_to_bs(''.join(name.name), t) for name, t in self.bv.types.items() + if isinstance(t, binaryninja.NamedTypeReferenceType) + } + + # patches + def _set_patch(self, patch: Patch, **kwargs) -> bool: + l.warning("Patch setting is unimplemented in Binja") + return False + + def _get_patch(self, addr) -> Optional[Patch]: + l.warning("Patch getting is unimplemented in Binja") + return None + + def _patches(self) -> Dict[int, Patch]: + l.warning("Patch listing is unimplemented in Binja") + return {} + + # comments + def _set_comment(self, comment: Comment, **kwargs) -> bool: + # search for the right function + declib_func = self.get_func_containing(comment.addr) + if declib_func is None: + # in the case of the function not existing, just comment in addr space + self.bv.set_comment_at(comment.addr, comment.comment) + return True + + # func exists for commenting + bn_func = self.addr_to_bn_func(self.bv, comment.addr) + bn_func.set_comment_at(comment.addr, comment.comment) + + def _get_comment(self, addr) -> Optional[Comment]: + non_func_cmt = self.bv.get_comment_at(addr) + if non_func_cmt: + return Comment(addr, non_func_cmt) + + # search for the right function + funcs = self.bv.get_functions_containing(addr) + if not funcs: + return None + + bn_func = funcs[0] + + for _addr, cmt in bn_func.comments.items(): + if addr == _addr: + return Comment( + addr, + cmt, + func_addr=bn_func.start, + decompiled=True + ) + + return None + + def _comments(self) -> Dict[int, Comment]: + # search every single function for comments + comments = {} + for bn_func in self.bv.functions: + if not bn_func.symbol.type in VALID_FUNC_SYM_TYPES: + continue + + comments.update(bn_func.comments) + + # TODO: show non-function based comments + return comments + + # + # Helper converter functions + # + + @staticmethod + def bn_struct_to_bs(name, bn_struct): + members = { + member.offset: StructMember(str(member.name), member.offset, str(member.type), member.type.width) + for member in bn_struct.members if member.offset is not None + } + + return Struct( + str(name), + bn_struct.width if bn_struct.width is not None else 0, + members + ) + + @staticmethod + def bn_func_to_bs(bn_func): + # + # header: name, ret type, args + # + + args = { + i: FunctionArgument(i, parameter.name, parameter.type.get_string_before_name(), parameter.type.width) + for i, parameter in enumerate(bn_func.parameter_vars) + } + # XXX: this a hack to fix the void (*arg) issue + for i, arg in args.items(): + # notice the missing end parenthesis + if arg.type.endswith("(*"): + arg.type = arg.type.replace("(*", "*") + + sync_header = FunctionHeader( + bn_func.name, + bn_func.start, + type_=bn_func.return_type.get_string_before_name(), + args=args + ) + + # + # stack vars + # + + binja_stack_vars = { + v.storage: v for v in bn_func.stack_layout + if v.source_type == VariableSourceType.StackVariableSourceType and v not in bn_func.parameter_vars + } + sorted_stack = sorted(bn_func.stack_layout, key=lambda x: x.storage) + var_sizes = {} + + for off, var in binja_stack_vars.items(): + i = sorted_stack.index(var) + if i + 1 >= len(sorted_stack): + var_sizes[var] = 0 + else: + var_sizes[var] = var.storage - sorted_stack[i].storage + + bs_stack_vars = { + off: declib.artifacts.StackVariable( + off, + var.name, + var.type.get_string_before_name(), + var_sizes[var], + bn_func.start + ) + for off, var in binja_stack_vars.items() + } + + try: + size = bn_func.highest_address - bn_func.start + except Exception as e: + size = 0 + l.critical(f"Failed to grab the size of function because {e}. It's possible the function " + f"is not yet known to Binary Ninja.") + + return Function(bn_func.start, size, header=sync_header, stack_vars=bs_stack_vars) + + @staticmethod + def bn_enum_to_bs(name: str, bn_enum: "binaryninja.EnumerationType"): + members = {} + + for enum_member in bn_enum.members: + if isinstance(enum_member, binaryninja.EnumerationMember) and isinstance(enum_member.value, int): + members[enum_member.name] = enum_member.value + + return Enum(name, members) + + @staticmethod + def bn_typedef_to_bs(name: str, bn_typedef: "binaryninja.NamedTypeReferenceType"): + return Typedef(name, str(bn_typedef.name)) + + @staticmethod + def addr_to_bn_func(bv, address): + funcs = bv.get_functions_containing(address) + try: + func = funcs[0] + except IndexError: + return None + + return func + + def is_bn_func(self, bv, address): + # HACK: update the BV whenever this is used in a context menu + self.bv = bv + func = self.addr_to_bn_func(bv, address) + return func is not None + + def _get_first_segment_base(self) -> int: + """ + Get the virtual address of the first segment. + """ + if self.bv is None: + return None + + # First, try to find a code/executable segment + for segment in self.bv.segments: + return segment.start + + # Fallback to bv.start if no segments found + return self.bv.start diff --git a/declib/decompilers/ghidra/__init__.py b/declib/decompilers/ghidra/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/declib/decompilers/ghidra/artifact_lifter.py b/declib/decompilers/ghidra/artifact_lifter.py new file mode 100644 index 00000000..f0b6176c --- /dev/null +++ b/declib/decompilers/ghidra/artifact_lifter.py @@ -0,0 +1,60 @@ +import logging +import typing + +from declib.api import ArtifactLifter + +_l = logging.getLogger(name=__name__) + +if typing.TYPE_CHECKING: + from .interface import GhidraDecompilerInterface + + +class GhidraArtifactLifter(ArtifactLifter): + lift_map = { + "undefined64": "long long", + "undefined32": "int", + "undefined16": "short", + "undefined8": "char", + "undefined": "char", + "char8": "char[8]", + "char4": "char[4]", + "char2": "char[2]", + "char1": "char", + #"sqword": "long long", + #"qword": "long long", + #"sdword": "int", + #"dword": "int", + #"word": "short", + #"byte": "char", + } + + def lift_type(self, type_str: str) -> str: + og_type_str = type_str + # convert to simple C when possible + for ghidra_t, bs_t in self.lift_map.items(): + type_str = type_str.replace(ghidra_t, bs_t) + + # parse out type decls if needed + type_str = self.type_parser.extract_type_name(type_str) + if type_str is None: + self.deci.error(f"Failed to extract type name from {og_type_str}, defaulting to void *") + type_str = "void *" + + scope_count = type_str.count("/") + if scope_count: + name, scope = self.deci._gscoped_type_to_bs(type_str) + type_str = self.scoped_type_to_str(name, scope=scope) + + return type_str + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + return offset + + def lower_type(self, type_str: str) -> str: + if self.SCOPE_DELIMITER in type_str: + type_str = self.deci._bs_scoped_type_to_g(type_str) + + return type_str + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + return offset diff --git a/declib/decompilers/ghidra/compat/__init__.py b/declib/decompilers/ghidra/compat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/declib/decompilers/ghidra/compat/headless.py b/declib/decompilers/ghidra/compat/headless.py new file mode 100644 index 00000000..d29af73d --- /dev/null +++ b/declib/decompilers/ghidra/compat/headless.py @@ -0,0 +1,156 @@ +import logging +from pathlib import Path +from typing import Union, Optional, Tuple + +from pyghidra.core import _analyze_program, _get_language, _get_compiler_spec +from jpype import JClass + +_l = logging.getLogger(__name__) + + +def open_program( + binary_path: Optional[Union[str, Path]] = None, + project_location: Union[str, Path] = None, + project_name: str = None, + program_name: str = None, + analyze=True, + language: str = None, + compiler: str = None, + loader: Union[str, JClass] = None +): + """ + Taken from Pyhidra, but updated to also return the project associated with the program: + https://github.com/dod-cyber-crime-center/pyhidra/blob/c878e91b53498f65f2eb0255e22189a6d172917c/pyhidra/core.py#L178 + """ + from pyghidra.launcher import PyGhidraLauncher, HeadlessPyGhidraLauncher + if binary_path is None and project_location is None: + raise ValueError("You must provide either a binary path or a project location.") + + if not PyGhidraLauncher.has_launched(): + launcher = HeadlessPyGhidraLauncher() + # Force the JVM into AWT-headless mode. The "headless" launcher does not + # set this itself, so if a (possibly stale) DISPLAY is exported Ghidra's + # JVM tries to reach an X server and dies with: + # java.awt.AWTError: Can't connect to X11 window server ... + launcher.add_vmargs("-Djava.awt.headless=true") + launcher.start() + + from ghidra.app.script import GhidraScriptUtil + from ghidra.program.flatapi import FlatProgramAPI + project, program = _setup_project( + binary_path=binary_path, + project_location=project_location, + project_name=project_name, + program_name=program_name if program_name else project_name, + language=language, + compiler=compiler, + loader=loader + ) + GhidraScriptUtil.acquireBundleHostReference() + flat_api = FlatProgramAPI(program) + if analyze: + _analyze_program(flat_api, program) + + return flat_api, project, program + + +def _setup_project( + binary_path: Optional[Union[str, Path]] = None, + project_location: Union[str, Path] = None, + project_name: str = None, + program_name: str = None, + language: str = None, + compiler: str = None, + loader: Union[str, JClass] = None +) -> Tuple["GhidraProject", "Program"]: + from ghidra.base.project import GhidraProject + from ghidra.util.exception import NotFoundException + from java.lang import ClassLoader + from java.io import IOException + + if binary_path is not None: + binary_path = Path(binary_path) + if project_location: + project_location = Path(project_location) + else: + project_location = binary_path.parent + if not project_name: + project_name = f"{binary_path.name}_ghidra" + project_location /= project_name + + # Ensure the project location directory exists + project_location.mkdir(exist_ok=True, parents=True) + + if isinstance(loader, str): + from java.lang import ClassNotFoundException + try: + gcl = ClassLoader.getSystemClassLoader() + loader = JClass(loader, gcl) + except (TypeError, ClassNotFoundException) as e: + raise ValueError from e + + if isinstance(loader, JClass): + from ghidra.app.util.opinion import Loader + if not Loader.class_.isAssignableFrom(loader): + raise TypeError(f"{loader} does not implement ghidra.app.util.opinion.Loader") + + # Open/Create project + program: "Program" = None + try: + project = GhidraProject.openProject(project_location, project_name, True) + # XXX: binsync patch added here: + if binary_path is not None or program_name is not None: + if program_name is None: + program_name = binary_path.name + if project.getRootFolder().getFile(program_name): + program = project.openProgram("/", program_name, False) + except (IOException, NotFoundException): + project = GhidraProject.createProject(project_location, project_name, False) + + # NOTE: GhidraProject.importProgram behaves differently when a loader is provided + # loaderClass may not be null so we must use the correct method override + + if binary_path is not None and program is None: + if language is None: + if loader is None: + program = project.importProgram(binary_path) + else: + program = project.importProgram(binary_path, loader) + if program is None: + raise RuntimeError(f"Ghidra failed to import '{binary_path}'. Try providing a language manually.") + else: + lang = _get_language(language) + comp = _get_compiler_spec(lang, compiler) + if loader is None: + program = project.importProgram(binary_path, lang, comp) + else: + program = project.importProgram(binary_path, loader, lang, comp) + if program is None: + message = f"Ghidra failed to import '{binary_path}'. " + if compiler: + message += f"The provided language/compiler pair ({language} / {compiler}) may be invalid." + else: + message += f"The provided language ({language}) may be invalid." + raise ValueError(message) + if program_name: + program.setName(program_name) + project.saveAs(program, "/", program.getName(), True) + + return project, program + +def close_program(program, project) -> bool: + """ + Returns true if closing was successful, false otherwise. + + """ + from ghidra.app.script import GhidraScriptUtil + + try: + GhidraScriptUtil.releaseBundleHostReference() + project.save(program) + project.close() + return True + except Exception as e: + _l.critical("Failed to close project: %s", e) + + return False diff --git a/declib/decompilers/ghidra/compat/imports.py b/declib/decompilers/ghidra/compat/imports.py new file mode 100644 index 00000000..68580dc6 --- /dev/null +++ b/declib/decompilers/ghidra/compat/imports.py @@ -0,0 +1,78 @@ +import logging + +_l = logging.getLogger(__name__) + + +def get_private_class(path: str): + from java.lang import ClassLoader + from jpype import JClass + + gcl = ClassLoader.getSystemClassLoader() + return JClass(path, loader=gcl) + +from ghidra.framework.model import DomainObjectListener +from ghidra.program.model.symbol import SourceType, SymbolType +from ghidra.program.model.pcode import HighFunctionDBUtil +from ghidra.program.model.data import ( + DataTypeConflictHandler, StructureDataType, ByteDataType, EnumDataType, CategoryPath, TypedefDataType +) +from ghidra.program.util import ChangeManager, ProgramChangeRecord, FunctionChangeRecord +from ghidra.program.database.function import VariableDB, FunctionDB +from ghidra.program.database.symbol import CodeSymbol, FunctionSymbol +from ghidra.program.model.listing import CodeUnit +from ghidra.app.cmd.comments import SetCommentCmd +from ghidra.app.cmd.label import RenameLabelCmd +from ghidra.app.context import ProgramLocationContextAction, ProgramLocationActionContext +from ghidra.app.decompiler import DecompInterface +from ghidra.app.plugin.core.analysis import AutoAnalysisManager +from ghidra.app.util.cparser.C import CParserUtils +from ghidra.app.decompiler import PrettyPrinter +from ghidra.util.task import ConsoleTaskMonitor +from ghidra.util.data import DataTypeParser +from ghidra.util.exception import CancelledException +from docking.action import MenuData +from docking.action.builder import ActionBuilder + +EnumDB = get_private_class("ghidra.program.database.data.EnumDB") +StructureDB = get_private_class("ghidra.program.database.data.StructureDB") +TypedefDB = get_private_class("ghidra.program.database.data.TypedefDB") + +__all__ = [ + # forcefully imported objects + "DomainObjectListener", + "SourceType", + "ChangeManager", + "ProgramChangeRecord", + "FunctionChangeRecord", + "VariableDB", + "FunctionDB", + "CodeSymbol", + "FunctionSymbol", + "ProgramLocationContextAction", + "ProgramLocationActionContext", + "MenuData", + "ActionBuilder", + "HighFunctionDBUtil", + "DataTypeConflictHandler", + "StructureDataType", + "ByteDataType", + "CodeUnit", + "SetCommentCmd", + "EnumDataType", + "CategoryPath", + "TypedefDataType", + "EnumDB", + "RenameLabelCmd", + "SymbolType", + "StructureDB", + "PrettyPrinter", + "ConsoleTaskMonitor", + "DecompInterface", + "AutoAnalysisManager", + "DataTypeParser", + "CParserUtils", + "CancelledException", + "EnumDB", + "StructureDB", + "TypedefDB" +] diff --git a/declib/decompilers/ghidra/compat/state.py b/declib/decompilers/ghidra/compat/state.py new file mode 100644 index 00000000..cab98bb3 --- /dev/null +++ b/declib/decompilers/ghidra/compat/state.py @@ -0,0 +1,54 @@ +import logging + +_l = logging.getLogger(__name__) + + +def _get_python_plugin(flat_api=None): + if flat_api is not None: + state = flat_api.getState() + else: + _l.warning("Using internal ghidra functions without a distinct FlatAPI is likely dangerous!") + # assume it must be either in the globals or __this__ object, but this will likley crash if we are here + gvs = dict(globals()) + state = gvs.get("getState", None) or gvs.get("__this__", None).getState + + tool = state.getTool() + api = None + if tool is not None: + for plugin in state.getTool().getManagedPlugins(): + if plugin.name == "PyGhidraPlugin": + api = plugin + break + else: + raise RuntimeError("PyGhidraPlugin not found") + else: + # This is s special case: semi-headless + # we started ghidra with something like pyhidra.run_script, which causes us to run the current instance + # as if it were a script, not a single service inside ghidra + api = state + + return api + + +def _in_headless_mode(flat_api): + return flat_api is not None and not hasattr(flat_api, "getState") + +# +# Public API for interacting with the Ghidra state +# + + +def get_current_program(flat_api=None) -> "ProgramDB": + api = _get_python_plugin(flat_api=flat_api) if not _in_headless_mode(flat_api) else flat_api + return api.getCurrentProgram() + + +def get_current_address(flat_api=None) -> int: + if _in_headless_mode(flat_api): + raise RuntimeError("Cannot get current address in headless mode") + + addr = _get_python_plugin(flat_api=flat_api).getProgramLocation().getAddress().offset + if addr is not None: + addr = int(addr) + + return addr \ No newline at end of file diff --git a/declib/decompilers/ghidra/compat/transaction.py b/declib/decompilers/ghidra/compat/transaction.py new file mode 100644 index 00000000..f9f8f6ff --- /dev/null +++ b/declib/decompilers/ghidra/compat/transaction.py @@ -0,0 +1,30 @@ +from functools import wraps +import typing + +if typing.TYPE_CHECKING: + from ..interface import GhidraDecompilerInterface + + +class Transaction: + def __init__(self, flat_api, msg="BinSync transaction"): + self._trans_msg = msg + self._flat_api = flat_api + self.trans_id = None + + def __enter__(self): + self.trans_id = self._flat_api.currentProgram.startTransaction(self._trans_msg) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._flat_api.currentProgram.endTransaction(self.trans_id, True) + + +def ghidra_transaction(f): + @wraps(f) + def _ghidra_transaction(self: "GhidraDecompilerInterface", *args, **kwargs): + with Transaction(flat_api=self.flat_api, msg=f"BS::{f.__name__}(args={args})"): + ret_val = f(self, *args, **kwargs) + + return ret_val + + return _ghidra_transaction + diff --git a/declib/decompilers/ghidra/hooks.py b/declib/decompilers/ghidra/hooks.py new file mode 100644 index 00000000..316a4715 --- /dev/null +++ b/declib/decompilers/ghidra/hooks.py @@ -0,0 +1,242 @@ +import logging +import typing +from typing import Tuple, Optional +import threading + +from ...artifacts import FunctionHeader, Function, FunctionArgument, StackVariable, GlobalVariable, Struct, Enum + +if typing.TYPE_CHECKING: + from declib.decompilers.ghidra.interface import GhidraDecompilerInterface + +_l = logging.getLogger(__name__) +from .compat.imports import ( + DomainObjectListener, ChangeManager, ProgramChangeRecord, VariableDB, FunctionDB, CodeSymbol, + FunctionSymbol, FunctionChangeRecord +) +from jpype import JImplements, JOverride + + +@JImplements(DomainObjectListener, deferred=False) +class DataMonitor: + @JOverride + def __init__(self, deci: "GhidraDecompilerInterface"): + self._deci = deci + # Init event lists + self.funcEvents = { + ChangeManager.DOCR_FUNCTION_CHANGED, + ChangeManager.DOCR_FUNCTION_BODY_CHANGED, + ChangeManager.DOCR_VARIABLE_REFERENCE_ADDED, + ChangeManager.DOCR_VARIABLE_REFERENCE_REMOVED + } + + self.symDelEvents = { + ChangeManager.DOCR_SYMBOL_REMOVED + } + + self.symChgEvents = { + ChangeManager.DOCR_SYMBOL_ADDED, + ChangeManager.DOCR_SYMBOL_RENAMED, + ChangeManager.DOCR_SYMBOL_DATA_CHANGED + } + + self.typeEvents = { + ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED, + ChangeManager.DOCR_DATA_TYPE_CHANGED, + ChangeManager.DOCR_DATA_TYPE_REPLACED, + ChangeManager.DOCR_DATA_TYPE_RENAMED, + ChangeManager.DOCR_DATA_TYPE_SETTING_CHANGED, + ChangeManager.DOCR_DATA_TYPE_MOVED, + ChangeManager.DOCR_DATA_TYPE_ADDED + } + + self.imageBaseEvents = { + ChangeManager.DOCR_IMAGE_BASE_CHANGED + } + + self.TrackedEvents = ( + self.funcEvents | self.symDelEvents | self.symChgEvents | self.typeEvents | self.imageBaseEvents + ) + + @JOverride + def domainObjectChanged(self, ev): + try: + self.do_change_handler(ev) + except Exception as e: + excep_str = str(e).replace('\n', ' ') + self._deci.error(f"Error in domainObjectChanged: {excep_str}") + + def do_change_handler(self, ev): + for record in ev: + if not isinstance(record, ProgramChangeRecord): + continue + + changeType = record.getEventType() + if changeType not in self.TrackedEvents: + # bail out early if we don't care about this event + continue + + new_value = record.getNewValue() + obj = record.getObject() + if changeType in self.funcEvents: + func_change_type = record.getSpecificChangeType() + if func_change_type == FunctionChangeRecord.FunctionChangeType.RETURN_TYPE_CHANGED: + # Function return type changed + header = FunctionHeader( + name=None, addr=obj.getEntryPoint().getOffset(), type_=str(obj.getReturnType()) + ) + self._deci.function_header_changed(header) + + elif changeType in self.typeEvents: + if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: + # stack variables change address when retyped! + if isinstance(obj, VariableDB): + parent_namespace = obj.getParentNamespace() + storage = obj.getVariableStorage() + if ( + (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) + and (parent_namespace is not None) + ): + sv = StackVariable( + int(storage.stackOffset), + None, + str(obj.getDataType()), + int(storage.size), + int(obj.parentNamespace.entryPoint.offset) + ) + self._deci.stack_variable_changed( + sv + ) + + else: + try: + struct = self._deci.structs[new_value.name] + # TODO: access old name indicate deletion + # self._deci.struct_changed(Struct(None, None, None), deleted=True) + self._deci.struct_changed(struct) + except KeyError: + pass + if changeType == ChangeManager.DOCR_SYMBOL_ADDRESS_CHANGED: + # stack variables change address when retyped! + if isinstance(obj, VariableDB): + parent_namespace = obj.getParentNamespace() + storage = obj.getVariableStorage() + if ( + (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) + and (parent_namespace is not None) + ): + self._deci.stack_variable_changed( + StackVariable( + int(storage.stackOffset), + None, + str(obj.getDataType()), + int(storage.size), + int(obj.parentNamespace.entryPoint.offset) + ) + ) + + else: + try: + struct = self._deci.structs[new_value.name] + # TODO: access old name indicate deletion + # self._deci.struct_changed(Struct(None, None, None), deleted=True) + self._deci.struct_changed(struct) + except KeyError: + pass + + try: + enum = self._deci.enums[new_value.name] + # self._deci.enum_changed(Enum(None, None), deleted=True) + self._deci.enum_changed(enum) + except KeyError: + pass + + elif changeType in self.symDelEvents: + # Globals are deleted first then recreated + if isinstance(obj, CodeSymbol): + removed = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) + # deleted kwarg not yet handled by global_variable_changed + self._deci.global_variable_changed(removed, deleted=True) + elif changeType in self.symChgEvents: + # For creation events, obj is stored in newValue + if obj is None and new_value is not None: + obj = new_value + + if changeType == ChangeManager.DOCR_SYMBOL_ADDED: + if isinstance(obj, CodeSymbol): + gvar = GlobalVariable(obj.getAddress().getOffset(), obj.getName()) + self._deci.global_variable_changed(gvar) + elif changeType == ChangeManager.DOCR_SYMBOL_RENAMED: + if isinstance(obj, CodeSymbol): + gvar = GlobalVariable(obj.getAddress().getOffset(), new_value) + self._deci.global_variable_changed(gvar) + if isinstance(obj, FunctionSymbol): + header = FunctionHeader(name=new_value, addr=int(obj.getAddress().offset)) + self._deci.function_header_changed(header) + elif isinstance(obj, VariableDB): + parent_namespace = obj.getParentNamespace() + storage = obj.getVariableStorage() + if ( + (new_value is not None) and (storage is not None) and bool(storage.isStackStorage()) + and (parent_namespace is not None) + ): + self._deci.stack_variable_changed( + StackVariable( + int(obj.variableStorage.stackOffset), + new_value, + None, + None, + int(obj.parentNamespace.entryPoint.offset) + ) + ) + elif isinstance(obj, FunctionDB): + # TODO: Fix argument name support + # changed_arg = FunctionArgument(None, newValue, None, None) + # header = FunctionHeader(None, None, args={None: changed_arg}) + # self._deci.function_header_changed(header) + pass + else: + continue + elif changeType in self.imageBaseEvents: + new_base_addr = int(new_value.getOffset()) if new_value is not None else None + if new_base_addr is not None: + self._deci._binary_base_addr = new_base_addr + + +def create_data_monitor(deci: "GhidraDecompilerInterface"): + data_monitor = DataMonitor(deci) + return data_monitor + + +def _qt_shortcut_to_ghidra(shortcut: str) -> str: + """Convert a Qt-style shortcut like "Ctrl+Shift+D" to Ghidra's "ctrl shift D".""" + if not shortcut: + return "" + parts = shortcut.split("+") + out = [] + for p in parts[:-1]: + out.append(p.strip().lower()) + key = parts[-1].strip() + out.append(key.upper() if len(key) == 1 else key) + return " ".join(out) + + +def create_context_action( + name, action_string, callback_func, category=None, + plugin_name="declib_ghidra", tool=None, shortcut=None, +): + from .compat.imports import ProgramLocationActionContext, ActionBuilder + def _invoke(ctx: ProgramLocationActionContext): + threading.Thread(target=callback_func, daemon=True).start() + + menu_path = [] + if category is not None and "/" in category: + menu_path.extend(category.split("/")) + menu_path.append(action_string) + + b = (ActionBuilder(name, plugin_name) + .popupMenuPath(list(menu_path)) + .withContext(ProgramLocationActionContext) + .validContextWhen(lambda ctx: ctx is not None and ctx.getAddress() is not None) + .onAction(_invoke)) + + return b.buildAndInstall(tool) diff --git a/declib/decompilers/ghidra/interface.py b/declib/decompilers/ghidra/interface.py new file mode 100644 index 00000000..4e546846 --- /dev/null +++ b/declib/decompilers/ghidra/interface.py @@ -0,0 +1,1433 @@ +import os +import re +import sys +import time +import typing +from collections import defaultdict +from pathlib import Path +from typing import Optional, Dict, List, Tuple, Union +import logging +import queue +import threading + +from declib.api import DecompilerInterface, CType +from declib.api.decompiler_interface import requires_decompilation +from declib.artifacts import ( + Function, FunctionHeader, StackVariable, Comment, FunctionArgument, GlobalVariable, Struct, StructMember, Enum, + Decompilation, Context, Artifact, Typedef +) + +from .artifact_lifter import GhidraArtifactLifter +from .compat.transaction import ghidra_transaction +from .compat.headless import close_program, open_program +from .compat.state import get_current_address + +if typing.TYPE_CHECKING: + from ghidra.program.model.listing import Function as GhidraFunction, Program + from ghidra.program.flatapi import FlatProgramAPI + from ghidra.program.model.pcode import HighSymbol + + + +_l = logging.getLogger(__name__) + + +class GhidraDecompilerInterface(DecompilerInterface): + CACHE_TIMEOUT = 5 + _program: Optional["Program"] + flat_api: "FlatProgramAPI" + + def __init__( + self, + flat_api=None, + loop_on_plugin=True, + start_headless_watchers=False, + analyze=True, + project_location: Optional[Union[str, Path]] = None, + project_name: Optional[str] = None, + program_name: Optional[str] = None, + program_obj: Optional["Program"] = None, + language: Optional[str] = None, + **kwargs + ): + self.loop_on_plugin = loop_on_plugin + self.flat_api = flat_api + + # headless-only attributes + self._start_headless_watchers = start_headless_watchers + self._headless_analyze = analyze + self._headless_project_location = project_location + self._headless_project_name = project_name + self._program_name = program_name + self._project = None + self._program = program_obj + self._language = language + + # ui-only attributes + self._data_monitor = None + + # cachable attributes + self._active_ctx = None + self._binary_base_addr = None + self._default_pointer_size = None + self._gsym_size = None + self._max_gsym_size = 50_000 + + # main thread queue + self._main_thread_queue = queue.Queue() + self._results_queue = queue.Queue() + + super().__init__( + name="ghidra", + artifact_lifter=GhidraArtifactLifter(self), + supports_undo=True, + supports_type_scopes=True, + default_func_prefix="FUN_", + **kwargs + ) + + def _init_gui_components(self, *args, **kwargs): + # XXX: yeah, this is bad naming! + if self._start_headless_watchers: + self.start_artifact_watchers() + + super()._init_gui_components(*args, **kwargs) + + def _deinit_headless_components(self): + if self._program is not None and self._project is not None: + close_program(self._program, self._project) + self._project = None + self._program = None + + def _init_headless_components(self, *args, **kwargs): + if self._program is not None: + # We were already provided a program object as part of the instantiation, so just use it + from ghidra.program.flatapi import FlatProgramAPI + self.flat_api = FlatProgramAPI(self._program) + return + + else: + # This interface was not explicitly initialized as part of a GhidraScript, do the setup on our own + if os.getenv("GHIDRA_INSTALL_DIR", None) is None: + raise RuntimeError("GHIDRA_INSTALL_DIR must be set in the environment to use Ghidra headless.") + + flat_api, project, program = open_program( + binary_path=self._binary_path, + analyze=self._headless_analyze, + project_location=self._headless_project_location, + project_name=self._headless_project_name, + program_name=self._program_name, + language=self._language, + ) + self._program = program + self._project = project + self.flat_api = flat_api + if flat_api is None: + raise RuntimeError("Failed to open program with Pyhidra") + + # + # GUI + # + + def start_artifact_watchers(self): + if self.headless: + _l.warning("Artifact watching is not supported in headless mode.") + return + + from .hooks import create_data_monitor + if not self.artifact_watchers_started: + if self.flat_api is None: + raise RuntimeError("Cannot start artifact watchers without FlatProgramAPI.") + + self._data_monitor = create_data_monitor(self) + self.currentProgram.addListener(self._data_monitor) + super().start_artifact_watchers() + + def stop_artifact_watchers(self): + if self.artifact_watchers_started: + self._data_monitor = None + # TODO: generalize superclass method? + super().stop_artifact_watchers() + + def gui_run_on_main_thread(self, func, *args, **kwargs): + self._main_thread_queue.put((func, args, kwargs)) + return self._results_queue.get() + + def gui_register_ctx_menu(self, name, action_string, callback_func, category=None, shortcut=None) -> bool: + from .hooks import create_context_action + + def callback_func_wrap(*args, **kwargs): + try: + callback_func(*args, **kwargs) + except Exception as e: + self.warning(f"Exception in ctx menu callback {name}: {e}") + raise + create_context_action( + name, action_string, callback_func_wrap, category=(category or "DecLib"), + tool=self.flat_api.getState().getTool(), shortcut=shortcut, + ) + return True + + def gui_ask_for_string(self, question, title="Plugin Question", default="") -> str: + answer = self.flat_api.askString(title, question, default) + return answer if answer else "" + + def gui_ask_for_choice(self, question: str, choices: list, title="Plugin Question") -> str: + answer = self.flat_api.askChoice(title, question, choices, choices[0]) + return answer if answer else "" + + def gui_active_context(self) -> Optional[Context]: + active_addr = get_current_address(flat_api=self.flat_api) + if (self._active_ctx is None) or (active_addr is not None and self._active_ctx.addr != active_addr): + gfuncs = self.__fast_function(active_addr) + gfunc = gfuncs[0] if gfuncs else None + # TODO: support scree_name + context = Context(addr=active_addr) + if gfunc is not None: + context.func_addr = int(gfunc.getEntryPoint().getOffset()) + + self._active_ctx = self.art_lifter.lift(context) + + return self._active_ctx + + def gui_goto(self, func_addr) -> None: + func_addr = self.art_lifter.lower_addr(func_addr) + self.flat_api.goTo(self._to_gaddr(func_addr)) + + # + # Mandatory API + # + + def fast_get_function(self, func_addr) -> Optional[Function]: + lowered_addr = self.art_lifter.lower_addr(func_addr) + gfuncs = self.__fast_function(lowered_addr) + gfunc = gfuncs[0] if gfuncs else None + if gfunc is None: + _l.error("Func does not exist at %s", lowered_addr) + + bs_func = self._gfunc_to_bsfunc(gfunc) + lifted_func = self.art_lifter.lift(bs_func) + return lifted_func + + @property + def binary_base_addr(self) -> int: + if self._binary_base_addr is None: + self._binary_base_addr = self._get_first_segment_base() + + return self._binary_base_addr + + @property + def binary_hash(self) -> str: + return self.currentProgram.executableMD5 + + @property + def binary_path(self) -> Optional[str]: + return self.currentProgram.executablePath + + def get_func_size(self, func_addr) -> int: + func_addr = self.art_lifter.lower_addr(func_addr) + gfunc = self._get_nearest_function(func_addr) + if gfunc is None: + _l.critical("Failed to get function size for %s, likely a lifting error, report!", func_addr) + return -1 + + return int(gfunc.getBody().getNumAddresses()) + + def _decompile(self, function: Function, map_lines=False, **kwargs) -> Optional[Decompilation]: + dec_obj = self.get_decompilation_object(function, do_lower=False) + if dec_obj is None: + return None + + dec_results = dec_obj + dec_func = dec_results.getDecompiledFunction() + if dec_func is None: + return None + + decompilation = Decompilation(addr=function.addr, text=str(dec_func.getC()), decompiler=self.name) + if map_lines: + from .compat.imports import PrettyPrinter + + g_func = dec_results.function + linenum_to_addr = defaultdict(set) + linenum_to_addr[1].add(function.addr) + pp = PrettyPrinter(g_func, dec_results.getCCodeMarkup(), None) + for line in pp.getLines(): + ln = line.getLineNumber() + for i in range(line.getNumTokens()): + min_addr = line.getToken(i).getMinAddress() + if min_addr is None: + continue + + linenum_to_addr[ln].add(min_addr.offset) + max_addr = line.getToken(i).getMaxAddress() + if max_addr is not None: + linenum_to_addr[ln].add(max_addr.offset) + + decompilation.line_map = { + k: list(v) for k, v in dict(linenum_to_addr).items() + } + + return decompilation + + def get_decompilation_object(self, function: Function, do_lower=True) -> Optional[object]: + lowered_addr = self.art_lifter.lower_addr(function.addr) if do_lower else function.addr + return self._ghidra_decompile(self._get_nearest_function(lowered_addr)) + + def xrefs_from(self, func_addr: int) -> List[Function]: + """Ghidra callees: use Function.getCalledFunctions for an O(1) hit per caller.""" + from .compat.imports import ConsoleTaskMonitor + + lowered = self.art_lifter.lower_addr(func_addr) + gfunc = self._get_nearest_function(lowered) + if gfunc is None: + return [] + callees: List[Function] = [] + seen = set() + try: + for called_gfunc in gfunc.getCalledFunctions(ConsoleTaskMonitor()): + entry_addr = int(called_gfunc.getEntryPoint().getOffset()) + if entry_addr in seen: + continue + seen.add(entry_addr) + func = Function( + addr=entry_addr, + size=int(called_gfunc.getBody().getNumAddresses()), + header=FunctionHeader(name=str(called_gfunc.getName()), addr=entry_addr), + ) + callees.append(self.art_lifter.lift(func)) + except Exception as exc: + _l.warning("Ghidra xrefs_from(0x%x) failed: %s", func_addr, exc) + return callees + + def xrefs_to_addr(self, addr: int, only_code: bool = False) -> List[Artifact]: + """Ghidra data-xref lookup: walk ReferenceManager refs to ``addr``. + + Backends' stock ``xrefs_to(Function)`` only fires on function entry + points, so it misses data refs to string constants, globals, etc. + This uses Ghidra's ReferenceManager directly and resolves each + referencing instruction back to its containing function. + """ + lowered = self.art_lifter.lower_addr(addr) + return self._ghidra_refs_to_address(lowered, only_code=only_code) + + def _ghidra_refs_to_address(self, lowered_addr: int, only_code: bool = False) -> List[Artifact]: + refs: List[Artifact] = [] + seen_funcs = set() + try: + gaddr = self._to_gaddr(lowered_addr) + reference_manager = self.currentProgram.getReferenceManager() + function_manager = self.currentProgram.getFunctionManager() + ref_iter = reference_manager.getReferencesTo(gaddr) + while ref_iter.hasNext(): + ref = ref_iter.next() + from_addr_g = ref.getFromAddress() + if only_code: + ref_type = ref.getReferenceType() + try: + is_data = ref_type.isData() + except Exception: + is_data = False + if is_data: + continue + gfunc = function_manager.getFunctionContaining(from_addr_g) + if gfunc is None: + continue + entry_addr = int(gfunc.getEntryPoint().getOffset()) + if entry_addr in seen_funcs: + continue + seen_funcs.add(entry_addr) + func = Function( + addr=entry_addr, + size=int(gfunc.getBody().getNumAddresses()), + header=FunctionHeader(name=str(gfunc.getName()), addr=entry_addr), + ) + refs.append(self.art_lifter.lift(func)) + except Exception as exc: + _l.warning("Ghidra reference lookup at 0x%x failed: %s", lowered_addr, exc) + return refs + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + if not isinstance(artifact, Function): + raise ValueError("Only functions are supported for xrefs_to") + + # Base function-level xref: who references the entry point. + # Without this, get_callgraph() + xref_from are empty on Ghidra + # because the base class returns `[]`. + lowered = self.art_lifter.lower(artifact) + xrefs = self._ghidra_refs_to_address(lowered.addr, only_code=only_code) + if not decompile: + return xrefs + + artifact: Function + if artifact.dec_obj is None: + artifact = self.functions[artifact.addr] + decompilation_results = self.get_decompilation_object(artifact, do_lower=True) + + high_function = decompilation_results.getHighFunction() + if high_function is None: + return xrefs + + new_xrefs = [] + for global_sym in high_function.getGlobalSymbolMap().getSymbols(): + sym_storage = global_sym.getStorage() + if not sym_storage.isMemoryStorage(): + continue + + gvar = GlobalVariable( + addr=int(sym_storage.getMinAddress().getOffset()), + name=str(global_sym.getName()), + type_=str(global_sym.getDataType().getPathName()) if global_sym.getDataType() else None, + size=int(global_sym.getSize()), + ) + new_xrefs.append(self.art_lifter.lift(gvar)) + + # xrefs are already lifted by _ghidra_refs_to_address; only new_xrefs need lifting. + return xrefs + new_xrefs + + def list_strings(self, filter: Optional[str] = None) -> List[Tuple[int, str]]: + pattern = re.compile(filter) if filter else None + found: Dict[int, str] = {} + try: + program = self.currentProgram + listing = program.getListing() + memory = program.getMemory() + base_addr = self.binary_base_addr + + def _record(gaddr, text: str) -> None: + if not text: + return + block = memory.getBlock(gaddr) if memory is not None else None + if block is None: + return + try: + if not block.isLoaded(): + return + except Exception: + pass + if gaddr.isNonLoadedMemoryAddress(): + return + addr = int(gaddr.getOffset()) + # Java signed longs can surface negative values for synthetic + # addresses (ELF section name tables, overlays, etc.). + if addr < base_addr: + return + if pattern is not None and not pattern.search(text): + return + # First writer wins — defined-data results carry the + # decompiler's own typing / encoding, so we prefer them + # over raw StringSearcher hits at the same address. + found.setdefault(addr, text) + + # Pass 1: strings the decompiler has already committed to a + # defined data type (char[], TerminatedCString, unicode). + data_iter = listing.getDefinedData(True) + while data_iter.hasNext(): + data = data_iter.next() + if not data.hasStringValue(): + continue + try: + raw = data.getValue() + text = str(raw) if raw is not None else "" + except Exception: + continue + _record(data.getAddress(), text) + + # Pass 2: ask Ghidra's own StringSearcher to scan initialized + # memory for ASCII runs. Ghidra's auto-analyzer misses sequences + # that it instead typed as `byte[N]` (e.g. a base64 alphabet + # stored as `uchar[64]`). This uses Ghidra's native detector — + # no parallel byte scanning. + self._scan_strings_via_searcher(program, memory, _record) + except Exception as exc: + _l.warning("Ghidra list_strings failed: %s", exc) + + results: List[Tuple[int, str]] = [ + (self.art_lifter.lift_addr(addr), text) + for addr, text in found.items() + ] + results.sort(key=lambda item: item[0]) + return results + + def _scan_strings_via_searcher(self, program, memory, record) -> None: + """Run Ghidra's StringSearcher over loaded memory. + + The searcher is the same component Ghidra's "Search > For Strings" + command uses. This catches ASCII runs that Ghidra auto-typed as + ``byte[N]`` / ``uchar[N]`` instead of promoting to a string (e.g. a + base64 alphabet stored as ``uchar[64]``). + """ + try: + from ghidra.program.util.string import StringSearcher, FoundStringCallback + from ghidra.util.task import TaskMonitor + from jpype import JImplements, JOverride + except Exception as exc: + _l.warning("StringSearcher unavailable, skipping supplemental scan: %s", exc) + return + + @JImplements(FoundStringCallback) + class _Collector: + def __init__(self, mem, on_string): + self._mem = mem + self._on_string = on_string + + @JOverride + def stringFound(self, found_string): + try: + text = found_string.getString(self._mem) + except Exception: + return + if text is None: + return + self._on_string(found_string.getAddress(), str(text)) + + try: + # ctor args: program, minStringSize, alignment, allCharSizes, + # requireNullTermination. allCharSizes=False keeps us on ASCII; + # the UTF variants would otherwise inflate results with noise. + searcher = StringSearcher(program, 4, 1, False, False) + scan_set = memory.getLoadedAndInitializedAddressSet() + # TaskMonitor.DUMMY is non-null but does nothing — passing None + # here crashes with NullPointerException inside AbstractStringSearcher. + searcher.search(scan_set, _Collector(memory, record), True, TaskMonitor.DUMMY) + except Exception as exc: + _l.warning("StringSearcher pass failed: %s", exc) + + def disassemble(self, addr: int, **kwargs) -> Optional[str]: + lowered = self.art_lifter.lower_addr(addr) + func = self._get_nearest_function(lowered) + if func is None: + return None + + lines: List[str] = [] + try: + listing = self.currentProgram.getListing() + body = func.getBody() + insn_iter = listing.getInstructions(body, True) + while insn_iter.hasNext(): + insn = insn_iter.next() + try: + insn_addr = int(insn.getAddress().getOffset()) + lifted = self.art_lifter.lift_addr(insn_addr) + lines.append(f"0x{lifted:x}:\t{str(insn)}") + except Exception: + continue + except Exception as exc: + _l.warning("Ghidra disassemble failed: %s", exc) + return None + return "\n".join(lines) if lines else None + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + if size <= 0: + return b"" + lowered = self.art_lifter.lower_addr(addr) + try: + import jpype + memory = self.currentProgram.getMemory() + gaddr = self._to_gaddr(lowered) + byte_array = jpype.JArray(jpype.JByte)(size) + # Memory.getBytes returns the count of bytes copied; on partial + # reads it raises MemoryAccessException, which we treat as the + # caller asked for memory we can't reach. + try: + read = int(memory.getBytes(gaddr, byte_array)) + except Exception as exc: + _l.debug("Ghidra read_memory at 0x%x size=%d failed: %s", lowered, size, exc) + return None + if read <= 0: + return b"" + # JByte values arrive as signed Python ints; mask back to unsigned + # so the resulting bytes match what the binary stores on disk. + return bytes(int(b) & 0xFF for b in byte_array[:read]) + except Exception as exc: + _l.warning("Ghidra read_memory failed: %s", exc) + return None + + # + # Extra API + # + + @property + def default_pointer_size(self) -> int: + if self._default_pointer_size is None: + self._default_pointer_size = int(self.currentProgram.getDefaultPointerSize()) + + return self._default_pointer_size + + def undo(self): + self.currentProgram.undo() + + @requires_decompilation + def local_variable_names(self, func: Function) -> List[str]: + symbols_by_name = self._get_local_variable_symbols(func) + return list(name for name, _ in symbols_by_name) + + @requires_decompilation + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + symbols_by_name = {name: sym for name, sym in self._get_local_variable_symbols(func)} + symbols_to_update = {} + for name, new_name in name_map.items(): + if name not in symbols_by_name or symbols_by_name[name].name == new_name or new_name in symbols_by_name: + continue + + sym: "HighSymbol" = symbols_by_name[name] + symbols_to_update[sym] = (new_name, None) + + return self._update_local_variable_symbols(symbols_to_update) if symbols_to_update else False + + # + # Private Artifact API + # + + def _set_function(self, func: Function, **kwargs) -> bool: + decompilation = self._ghidra_decompile(self._get_nearest_function(func.addr)) + changes = super()._set_function(func, decompilation=decompilation, **kwargs) + return changes + + def _get_function(self, addr, **kwargs) -> Optional[Function]: + func = self._get_nearest_function(addr) + if func is None: + return None + + dec = self._ghidra_decompile(func) + stack_variables = self._stack_variables(addr, decompilation=dec) + args = self._function_args(addr, decompilation=dec) + type_ = self._function_type(addr, decompilation=dec) + func_addr = int(func.getEntryPoint().getOffset()) + return Function( + addr=func_addr, + size=int(func.getBody().getNumAddresses()), + header=FunctionHeader(name=func.getName(), addr=func_addr, args=args, type_=type_), + stack_vars=stack_variables, dec_obj=dec + ) + + def _functions(self) -> Dict[int, Function]: + funcs = {} + func_info = self.__functions() + for addr, name, size in func_info: + funcs[addr] = Function( + addr=addr, size=size, header=FunctionHeader(name=name, addr=addr) + ) + + if not funcs: + _l.warning("Failed to get any functions from Ghidra. Did something break?") + + return funcs + + def _function_args(self, func_addr: int, decompilation=None) -> Dict[int, FunctionArgument]: + decompilation = decompilation or self._ghidra_decompile(self._get_nearest_function(func_addr)) + args = {} + for param_idx in range(decompilation.getHighFunction().getLocalSymbolMap().getNumParams()): + sym = decompilation.getHighFunction().getLocalSymbolMap().getParamSymbol(param_idx) + if not sym.isParameter(): + continue + + args[param_idx] = FunctionArgument( + offset=param_idx, name=str(sym.getName()), type_=str(sym.getDataType().getPathName()), size=int(sym.getSize()) + ) + + return args + + def _function_type(self, addr: int, decompilation=None) -> Optional[str]: + decompilation = decompilation or self._ghidra_decompile(self._get_nearest_function(addr)) + type_pathname = decompilation.getHighFunction().getFunctionPrototype().getReturnType().getPathName() + return type_pathname if type_pathname else None + + @ghidra_transaction + def _set_stack_variables(self, svars: List[StackVariable], **kwargs) -> bool: + from .compat.imports import SourceType + changes = False + if not svars: + return changes + + first_svar = svars[0] + func_addr = first_svar.addr + decompilation = kwargs.get('decompilation', None) or self._ghidra_decompile(self._get_function(func_addr)) + ghidra_func = decompilation.getFunction() if decompilation else self._get_nearest_function(func_addr) + gstack_vars = self.__get_decless_gstack_vars(ghidra_func) # this works because the func was already decompiled + #gstack_vars = self.__get_gstack_vars(decompilation.getHighFunction()) + if not gstack_vars: + return changes + + var_pairs = [] + for svar in svars: + for gstack_var in gstack_vars: + #if svar.offset == gstack_var.storage.stackOffset: + if svar.offset == gstack_var.getStackOffset(): + var_pairs.append((svar, gstack_var)) + break + + rename_pairs = [] + retype_pairs = [] + changes = False + #updates = {} + for svar, gstack_var in var_pairs: + #update_data = [gstack_var.name, None] + if svar.name and svar.name != gstack_var.name: + changes |= True + rename_pairs.append((gstack_var, svar.name)) + #update_data[0] = svar.name + + if svar.type: + parsed_type = self.typestr_to_gtype(svar.type) + if parsed_type is not None and parsed_type != str(gstack_var.getDataType().getPathName()): + changes |= True + retype_pairs.append((gstack_var, parsed_type)) + #update_data[1] = parsed_type + + #updates[gstack_var] = update_data + + self.__set_sym_names(rename_pairs, SourceType.USER_DEFINED) + self.__set_sym_types(retype_pairs, SourceType.USER_DEFINED) + #changes = self._update_local_variable_symbols(updates) + return changes + + def _get_stack_variable(self, addr: int, offset: int, **kwargs) -> Optional[StackVariable]: + gstack_var = self._get_gstack_var(addr, offset) + if gstack_var is None: + return None + + return self._gstack_var_to_bsvar(gstack_var) + + def _stack_variables(self, func_addr: int, decompilation=None) -> Dict[int, StackVariable]: + decompilation = decompilation or self._ghidra_decompile(self._get_nearest_function(func_addr)) + sv_info = self.__stack_variables(decompilation) + stack_variables = {} + for offset, name, type_, size in sv_info: + stack_variables[offset] = StackVariable( + stack_offset=offset, name=name, type_=type_, size=size, addr=func_addr + ) + + return stack_variables + + def _set_function_header(self, fheader: FunctionHeader, decompilation=None, **kwargs) -> bool: + from .compat.transaction import Transaction + from .compat.imports import SourceType, HighFunctionDBUtil + + changes = False + func_addr = fheader.addr + ghidra_func = decompilation.getFunction() if decompilation else self._get_nearest_function(func_addr) + + # func name + if fheader.name and fheader.name != ghidra_func.getName(): + with Transaction(self.flat_api, msg="BS::set_function_header::set_name"): + ghidra_func.setName(fheader.name, SourceType.USER_DEFINED) + changes = True + + # return type + if fheader.type and decompilation is not None: + parsed_type = self.typestr_to_gtype(fheader.type) + if parsed_type is not None and \ + parsed_type != str(decompilation.highFunction.getFunctionPrototype().getReturnType()): + with Transaction(self.flat_api, msg="BS::set_function_header::set_rettype"): + ghidra_func.setReturnType(parsed_type, SourceType.USER_DEFINED) + changes = True + + # args + # TODO: Only works for function arguments passed by register + if fheader.args and decompilation is not None: + params = ghidra_func.getParameters() + if len(params) == 0: + with Transaction(self.flat_api, msg="BS::set_function_header::update_params"): + HighFunctionDBUtil.commitParamsToDatabase( + decompilation.highFunction, + True, + HighFunctionDBUtil.ReturnCommitOption.COMMIT_NO_VOID, + SourceType.USER_DEFINED + ) + + with Transaction(self.flat_api, msg="BS::set_function_header::set_arguments"): + for offset, param in zip(fheader.args, params): + arg = fheader.args[offset] + gtype = self.typestr_to_gtype(arg.type) + param.setName(arg.name, SourceType.USER_DEFINED) + param.setDataType(gtype, SourceType.USER_DEFINED) + changes = True + + return changes + + @ghidra_transaction + def _set_struct(self, struct: Struct, header=True, members=True, **kwargs) -> bool: + from .compat.imports import DataTypeConflictHandler, StructureDataType, ByteDataType, CategoryPath + + data_manager = self.currentProgram.getDataTypeManager() + scope = struct.scope or "" + ghidra_struct = StructureDataType(CategoryPath("/" + scope), struct.name, 0) + for offset in struct.members: + member = struct.members[offset] + ghidra_struct.add(ByteDataType.dataType, 1, member.name, "") + ghidra_struct.growStructure(member.size - 1) + for dtc in ghidra_struct.getComponents(): + if dtc.getFieldName() == member.name: + gtype = self.typestr_to_gtype(member.type if member.type else 'undefined' + str(member.size)) + for i in range(offset, offset + member.size): + ghidra_struct.clearAtOffset(i) + ghidra_struct.replaceAtOffset(offset, gtype, member.size, member.name, "") + break + + # TODO: normalize the size of the struct if it did not grow enough + old_ghidra_struct = self._get_gtype_by_bs_name(struct.scoped_name, Struct) + try: + if old_ghidra_struct is not None: + data_manager.replaceDataType(old_ghidra_struct, ghidra_struct, True) + else: + data_manager.addDataType(ghidra_struct, DataTypeConflictHandler.DEFAULT_HANDLER) + return True + except Exception as ex: + print(f'Error filling struct {struct.name}: {ex}') + return False + + def _get_struct(self, name) -> Optional[Struct]: + ghidra_struct = self._get_gtype_by_bs_name(name, Struct) + if ghidra_struct is None: + return None + + full_struct_name = ghidra_struct.getPathName() + name, scope = self._gscoped_type_to_bs(full_struct_name) + size = 0 if ghidra_struct.isZeroLength() else ghidra_struct.getLength() + + return Struct( + name=name, size=size, members=self._struct_members_from_gstruct(ghidra_struct), scope=scope + ) + + @ghidra_transaction + def _del_struct(self, name) -> bool: + from .compat.imports import ConsoleTaskMonitor + data_manager = self.currentProgram.getDataTypeManager() + gstruct = self._get_gtype_by_bs_name(name, Struct) + try: + success = data_manager.remove(gstruct, ConsoleTaskMonitor()) + if success: + return True + else: + raise Exception('DataManager failed to remove struct') + except Exception as ex: + self.error(f"Failed to remove struct {name}: {ex}") + + + def _structs(self) -> Dict[str, Struct]: + structs = {} + gstructs = self.__gstructs() + for g_scoped_name, gstruct in gstructs: + name, scope = self._gscoped_type_to_bs(g_scoped_name) + size = 0 if gstruct.isZeroLength() else gstruct.getLength() + struct = Struct( + name=name, size=size, members=self._struct_members_from_gstruct(gstruct), scope=scope + ) + structs[struct.scoped_name] = struct + + return structs + + @ghidra_transaction + def _set_comment(self, comment: Comment, **kwargs) -> bool: + from .compat.imports import CodeUnit, SetCommentCmd + + cmt_type = CodeUnit.PRE_COMMENT if comment.decompiled else CodeUnit.EOL_COMMENT + if comment.addr == comment.func_addr: + cmt_type = CodeUnit.PLATE_COMMENT + + if comment.comment: + # TODO: check if comment already exists, and append? + return SetCommentCmd( + self._to_gaddr(comment.addr), cmt_type, comment.comment + ).applyTo(self.currentProgram) + return True + + def _get_comment(self, addr) -> Optional[Comment]: + # TODO: speedup needed here, see global vars for example + comments = self._comments() + return comments.get(addr, None) + + def _comments(self) -> Dict[int, Comment]: + comments = {} + funcs_code_units = self.__function_code_units() + for code_units in funcs_code_units: + for code_unit in code_units: + # TODO: this could be bad if we have multiple comments at the same address (pre and eol) + # eol comment + eol_cmt = code_unit.getComment(0) + if eol_cmt: + addr = int(code_unit.getAddress().getOffset()) + comments[addr] = Comment( + addr=addr, comment=str(eol_cmt) + ) + # pre comment + pre_cmt = code_unit.getComment(1) + if pre_cmt: + addr = int(code_unit.getAddress().getOffset()) + comments[addr] = Comment( + addr=addr, comment=str(pre_cmt), decompiled=True + ) + + return comments + + @ghidra_transaction + def _set_enum(self, enum: Enum, **kwargs) -> bool: + from .compat.imports import EnumDataType, CategoryPath, DataTypeConflictHandler + + data_manager = self.currentProgram.getDataTypeManager() + scope = enum.scope or "" + ghidra_enum = EnumDataType(CategoryPath("/" + scope), enum.name, 4) + for m_name, m_val in enum.members.items(): + ghidra_enum.add(m_name, m_val) + + old_ghidra_enum = self.currentProgram.getDataTypeManager().getDataType(ghidra_enum.getPathName()) + try: + if old_ghidra_enum: + data_manager.replaceDataType(old_ghidra_enum, ghidra_enum, True) + else: + data_manager.addDataType(ghidra_enum, DataTypeConflictHandler.DEFAULT_HANDLER) + return True + except Exception as ex: + self.error(f'Error adding enum {enum.name}: {ex}') + return False + + def _get_enum(self, name) -> Optional[Enum]: + g_enum = self._get_gtype_by_bs_name(name, Enum) + if g_enum is None: + return None + + name, scope = self._gscoped_type_to_bs(g_enum.getPathName()) + members = {_name: val for _name, val in self.__get_enum_members(g_enum)} + return Enum(name=name, members=members, scope=scope) + + def _enums(self) -> Dict[str, Enum]: + enums = {} + enums_by_name = self.__enum_names() + for g_enum_name, g_enum in enums_by_name: + name, scope = self._gscoped_type_to_bs(g_enum_name) + members = {_name: val for _name, val in self.__get_enum_members(g_enum)} + enum = Enum(name=name, members=members, scope=scope) + enums[enum.scoped_name] = enum + + return enums + + @ghidra_transaction + def _set_typedef(self, typedef: Typedef, **kwargs) -> bool: + from .compat.imports import TypedefDataType, CategoryPath, DataTypeConflictHandler + + # validate the typedef basetype + base_g_type = self.typestr_to_gtype(typedef.type) + if base_g_type is None: + raise ValueError(f"Invalid base type for typedef {typedef.name}: {typedef.type}") + + # parse out the correct name + scope = typedef.scope + if not scope: + scope = "" + + # do a full parse of the typedef + ghidra_typedef = TypedefDataType(CategoryPath("/"+scope), typedef.name, base_g_type) + if ghidra_typedef is None: + raise ValueError(f"Failed to create TypedefDataType for {typedef}") + + # get the old typedef if it exists, and override it + g_typename = ghidra_typedef.getPathName() + old_g_typedef = self.currentProgram.getDataTypeManager().getDataType(g_typename) + data_manager = self.currentProgram.getDataTypeManager() + + try: + if old_g_typedef: + data_manager.replaceDataType(old_g_typedef, ghidra_typedef, True) + else: + data_manager.addDataType(ghidra_typedef, DataTypeConflictHandler.DEFAULT_HANDLER) + return True + except Exception as ex: + self.error(f'Error adding typedef {typedef.name}: {ex}') + return False + + def _get_typedef(self, name) -> Optional[Typedef]: + g_typedef = self._get_gtype_by_bs_name(name, Typedef) + if g_typedef is None: + return None + + base_type = g_typedef.getDataType() + if base_type is None: + return None + + norm_name, scope = self._gscoped_type_to_bs(g_typedef.getPathName()) + return Typedef(name=norm_name, type_=str(base_type.getPathName()), scope=scope) + + def _typedefs(self) -> Dict[str, Typedef]: + typedefs = {} + typedefs_by_name = self.__gtypedefs() + for gtype_name, gtypedef in typedefs_by_name: + type_ = gtypedef.getDataType() + if type_ is None: + continue + + type_name = str(type_.getPathName()) + if not type_name or type_name == gtype_name: + continue + + name, scope = self._gscoped_type_to_bs(gtypedef.getPathName()) + bs_typedef = Typedef(name=name, type_=type_name, scope=scope) + # TODO: this could probably go wrong if typedef name and type are of different scopes + typedefs[bs_typedef.scoped_name] = bs_typedef + + return typedefs + + def _gsyms_too_large(self): + if self._gsym_size is None: + self._gsym_size = self.currentProgram.getSymbolTable().getNumSymbols() + + return self._gsym_size > self._max_gsym_size + + @ghidra_transaction + def _set_global_variable(self, gvar: GlobalVariable, **kwargs): + from .compat.imports import RenameLabelCmd, SourceType + + changes = False + if self._gsyms_too_large(): + self.warning("There are too many global symbols in your binary to accurately set. Skipping!") + + g_gvars_info = self.__g_global_variables() + + for addr, name, sym_data, sym in g_gvars_info: + if addr != gvar.addr: + continue + + # we've found the global variable + if gvar.name and gvar.name != name: + cmd = RenameLabelCmd(sym, gvar.name, SourceType.USER_DEFINED) + cmd.applyTo(self.currentProgram) + changes = True + + type_str = str(sym_data.getDataType().getPathName()) if sym_data is not None else None + if gvar.type and gvar.type != type_str: + # TODO: set type + pass + + return changes + + def _get_global_var(self, addr) -> Optional[GlobalVariable]: + gvars = self._global_vars(match_single_offset=addr) + return gvars.get(addr, None) + + def _global_vars(self, match_single_offset=None, **kwargs) -> Dict[int, GlobalVariable]: + if self._gsyms_too_large(): + self.warning("There are too many global symbols in your binary to get all global symbols!") + return {} + + g_gvars_info = self.__g_global_variables() + gvars = {} + for addr, name, sym_data, sym in g_gvars_info: + # speed optimization for single offset lookups + if match_single_offset is not None and match_single_offset != addr: + continue + + type_str = str(sym_data.getDataType().getPathName()) + size = int(self.currentProgram.getListing().getDataAt(sym.getAddress()).getLength()) \ + if type_str != "undefined" else self.default_pointer_size + + gvars[addr] = GlobalVariable(addr=addr, name=name, type_=type_str, size=size) + + return gvars + + # + # Specialized print handlers + # TODO: refactor the below for the new ghidra changes + # + + def print(self, msg, print_local=True, **kwargs): + print(msg) + + def info(self, msg: str, **kwargs): + _l.info(msg) + self.print(self._fmt_log_msg(msg, "INFO"), print_local=False) + + def debug(self, msg: str, **kwargs): + _l.debug(msg) + if _l.level >= logging.DEBUG: + self.print(self._fmt_log_msg(msg, "DEBUG"), print_local=False) + + def warning(self, msg: str, **kwargs): + _l.warning(msg) + self.print(self._fmt_log_msg(msg, "WARNING"), print_local=False) + + def error(self, msg: str, **kwargs): + _l.error(msg) + self.print(self._fmt_log_msg(msg, "ERROR"), print_local=False) + + @staticmethod + def _fmt_log_msg(msg: str, level: str): + full_filepath = Path(__file__) + log_path = str(full_filepath.with_suffix("").name) + for part in full_filepath.parts[:-1][::-1]: + log_path = f"{part}." + log_path + if part == "ghidra": + break + + return f"[{level}] | {log_path} | {msg}" + + # + # Ghidra Specific API + # + + def _gscoped_type_to_bs(self, gscoped_type: str) -> tuple[str, str | None]: + scope = None + if "/" in gscoped_type: + scope_parts = gscoped_type.split("/") + name = scope_parts.pop(-1) + scope = "/".join(scope_parts) + # remove the first slash + if scope.startswith("/"): + scope = scope[1:] + else: + name = gscoped_type + + return name, scope + + def _bs_scoped_type_to_g(self, bs_scoped_type: str) -> str: + name, scope = self.art_lifter.parse_scoped_type(bs_scoped_type) + if scope is None: + return "/" + name + + return f"/{scope}/{name}" + + def _to_gaddr(self, addr: int): + return self.flat_api.toAddr(hex(addr)) + + @property + def currentProgram(self): + from .compat.state import get_current_program + return get_current_program(self.flat_api) + + @ghidra_transaction + def _update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, Optional["DataType"]]]) -> bool: + return any([ + r is not None for r in self.__update_local_variable_symbols(symbols) + ]) + + def _get_struct_by_name(self, name: str) -> Optional["StructureDB"]: + """ + Returns None if the struct does not exist or is not a struct. + """ + from .compat.imports import StructureDB + + struct = self.currentProgram.getDataTypeManager().getDataType("/" + name) + return struct if isinstance(struct, StructureDB) else None + + def _struct_members_from_gstruct(self, gstruct: "StructDB") -> Dict[int, StructMember]: + gmemb_info = self.__gstruct_members(gstruct) + members = {} + for offset, field_name, type_, size in gmemb_info: + name = field_name if field_name else f'field_{hex(offset)[2:]}' + members[offset] = StructMember(name=name, offset=offset, type_=type_, size=size) + + return members + + def _get_nearest_function(self, addr: int) -> "GhidraFunction": + func_manager = self.currentProgram.getFunctionManager() + return func_manager.getFunctionContaining(self._to_gaddr(addr)) + + def _get_first_segment_base(self) -> int: + """ + Get the virtual address of the first segment. + """ + memory = self.currentProgram.getMemory() + + # First, try to find an executable segment (typically the code segment) + for block in memory.getBlocks(): + return int(block.getStart().getOffset()) + + # Fallback to image base if no memory blocks found + return int(self.currentProgram.getImageBase().getOffset()) + + def _gstack_var_to_bsvar(self, gstack_var: "LocalVariableDB"): + if gstack_var is None: + return None + + bs_stack_var = StackVariable( + gstack_var.getStackOffset(), + gstack_var.getName(), + str(gstack_var.getDataType().getPathName()), + gstack_var.getLength(), + gstack_var.getFunction().getEntryPoint().getOffset() # Unsure if this is what is wanted here + ) + return bs_stack_var + + def _gfunc_to_bsfunc(self, gfunc: "GhidraFunction"): + if gfunc is None: + return None + + bs_func = Function( + addr=gfunc.getEntryPoint().getOffset(), size=gfunc.getBody().getNumAddresses(), + header=FunctionHeader(name=gfunc.getName(), addr=gfunc.getEntryPoint().getOffset()), + ) + return bs_func + + def _ghidra_decompile_nearest(self, addr: int) -> Optional["DecompileResult"]: + func = self._get_nearest_function(addr) + if func is None: + raise RuntimeError(f"Failed to get nearest function for decompilation at {hex(addr)}") + + dec = self._ghidra_decompile(func) + if dec is None: + raise RuntimeError(f"Failed to decompile function at {hex(addr)}") + + return dec + + def _ghidra_decompile(self, func: "GhidraFunction") -> "DecompileResult": + """ + TODO: this needs to be cached! + @param func: + @return: + """ + from .compat.imports import DecompInterface, ConsoleTaskMonitor + + dec_interface = DecompInterface() + dec_interface.openProgram(self.currentProgram) + dec_results = dec_interface.decompileFunction(func, 0, ConsoleTaskMonitor()) + return dec_results + + def _get_gstack_var(self, func: "GhidraFunction", offset: int) -> Optional["LocalVariableDB"]: + """ + TODO: this needs to be updated that when its called we get decomilation, and pass it to + __get_gstack_vars + + @param func: + @param offset: + @return: + """ + gstack_vars = self.__get_decless_gstack_vars(func) + for var in gstack_vars: + if var.getStackOffset() == offset: + return var + + return None + + def _headless_lookup_struct(self, typestr: str) -> Optional["DataType"]: + """ + This function is mostly a hack because getDataTypeManagerService does not have up to date + datatypes in headless mode, so any structs you create dont get registerd + """ + if not typestr: + return None + + type_: CType = self.type_parser.parse_type(typestr) + if not type_: + # it was not parseable + return None + + # type is known and parseable + if not type_.is_unknown: + return None + + base_type_str = type_.base_type.type + return self.currentProgram.getDataTypeManager().getDataType("/" + base_type_str) + + def typestr_to_gtype(self, typestr: str) -> Optional["DataType"]: + """ + typestr should look something like: + `int` or if a struct `struct name`. + + @param typestr: + @return: + """ + from .compat.imports import DataTypeParser, AutoAnalysisManager + + if not typestr: + return None + + aam = AutoAnalysisManager.getAnalysisManager(self.currentProgram) + dt_service = aam.getDataTypeManagerService() + dt_parser = DataTypeParser(dt_service, DataTypeParser.AllowedDataTypes.ALL) + try: + parsed_type = dt_parser.parse(typestr) + except Exception as e: + parsed_type = None + + dtm = self.currentProgram.getDataTypeManager() + + # attempt a lookup as a custom datatype by name (e.g. "Point") + if parsed_type is None: + lookup = typestr if typestr.startswith("/") else "/" + typestr + parsed_type = dtm.getDataType(lookup) + + # attempt to resolve a pointer to a custom datatype (e.g. "Point *"): + # DataTypeParser can't resolve user structs by bare name and the path + # lookup above only matches non-pointer names, so build the pointer + # explicitly from the resolved base type. + if parsed_type is None and typestr.rstrip().endswith("*"): + base_str = typestr.strip() + ptr_levels = 0 + while base_str.endswith("*"): + base_str = base_str[:-1].strip() + ptr_levels += 1 + base_lookup = base_str if base_str.startswith("/") else "/" + base_str + base_dt = dtm.getDataType(base_lookup) + if base_dt is not None: + from ghidra.program.model.data import PointerDataType + parsed_type = base_dt + for _ in range(ptr_levels): + parsed_type = PointerDataType(parsed_type) + + if parsed_type is None: + _l.warning("Failed to parse type string: %s", typestr) + + return parsed_type + + def prototype_str_to_gtype(self, progotype_str: str) -> Optional["FunctionDefinitionDataType"]: + """ + Strings must look like: + 'void functions1(int p1, int p2)' + """ + from .compat.imports import CParserUtils + + if not progotype_str: + return None + + program = self.currentProgram + return CParserUtils.parseSignature(program, progotype_str) + + def _get_gtype_by_bs_name(self, name: str, bs_type: type[Artifact]) -> Optional["DataType"]: + """ + Returns None if the type does not exist or is not a struct. + """ + from .compat.imports import EnumDB, StructureDB, TypedefDB + + g_type = { + Typedef: TypedefDB, + Struct: StructureDB, + Enum: EnumDB, + }.get(bs_type, None) + if g_type is None: + raise ValueError(f"Invalid type for gtype lookup: {bs_type}") + + g_scoped_name = self._bs_scoped_type_to_g(name) + gtype = self.currentProgram.getDataTypeManager().getDataType(g_scoped_name) + if not gtype: + # TODO: add recovery one day: if the scope is None we should still try to search + #self.warning(f"Failed to get type by name: {g_scoped_name}") + return None + + if not isinstance(gtype, g_type): + #self.warning(f"Type {g_scoped_name} is not a {g_type.__name__}") + return None + + return gtype + + # + # Internal functions that are very dangerous + # + + def __fast_function(self, lowered_addr: int) -> List["GhidraFunction"]: + return [ + self.currentProgram.getFunctionManager().getFunctionContaining(self.flat_api.toAddr(hex(lowered_addr))) + ] + + def __functions(self) -> List[Tuple[int, str, int]]: + return [ + (int(func.getEntryPoint().getOffset()), str(func.getName()), int(func.getBody().getNumAddresses())) + for func in self.currentProgram.getFunctionManager().getFunctions(True) + ] + + def __update_local_variable_symbols(self, symbols: Dict["HighSymbol", Tuple[str, Optional["DataType"]]]) -> List: + from .compat.imports import HighFunctionDBUtil, SourceType + + return [ + HighFunctionDBUtil.updateDBVariable(sym, updates[0], updates[1], SourceType.ANALYSIS) + for sym, updates in symbols.items() + ] + + def _get_local_variable_symbols(self, func: Function) -> List[Tuple[str, "HighSymbol"]]: + return [ + (sym.name, sym) + for sym in func.dec_obj.getHighFunction().getLocalSymbolMap().getSymbols() if sym.name + ] + + + def __get_decless_gstack_vars(self, func: "GhidraFunction") -> List["LocalVariableDB"]: + return [var for var in func.getAllVariables() if var.isStackVariable()] + + + def __get_gstack_vars(self, high_func: "HighFunction") -> List["LocalVariableDB"]: + return [ + var for var in high_func.getLocalSymbolMap().getSymbols() + if var.storage and var.storage.isStackStorage() + ] + + + def __enum_names(self) -> List[Tuple[str, "EnumDB"]]: + from .compat.imports import EnumDB + + return [ + (dType.getPathName(), dType) + for dType in self.currentProgram.getDataTypeManager().getAllDataTypes() + if isinstance(dType, EnumDB) + ] + + + def __stack_variables(self, decompilation) -> List[Tuple[int, str, str, int]]: + return [ + (int(sym.getStorage().getStackOffset()), str(sym.getName()), sym.getDataType().getPathName(), int(sym.getSize())) + for sym in decompilation.getHighFunction().getLocalSymbolMap().getSymbols() + if sym.getStorage().isStackStorage() + ] + + + def __set_sym_names(self, sym_pairs, source_type): + return [ + sym.setName(new_name, source_type) for sym, new_name in sym_pairs + ] + + + def __set_sym_types(self, sym_pairs, source_type): + return [ + sym.setDataType(new_type, False, True, source_type) for sym, new_type in sym_pairs + ] + + + def __gstruct_members(self, gstruct: "StructureDB") -> List[Tuple[int, str, str, int]]: + return [ + (int(m.getOffset()), str(m.getFieldName()), str(m.getDataType().getPathName()), int(m.getLength())) + for m in gstruct.getComponents() + ] + + + def __get_enum_members(self, g_enum: "EnumDB") -> List[Tuple[str, int]]: + return [ + (name, g_enum.getValue(name)) for name in g_enum.getNames() + ] + + + def __g_global_variables(self): + # TODO: this could be optimized more both in use and in implementation + # TODO: this just does not work for bigger than 50k syms + from .compat.imports import SymbolType + + return [ + (int(sym.getAddress().getOffset()), str(sym.getName()), self.currentProgram.getListing().getDataAt(sym.getAddress()), sym) + for sym in self.currentProgram.getSymbolTable().getAllSymbols(True) + if sym.getSymbolType() == SymbolType.LABEL and + self.currentProgram.getListing().getDataAt(sym.getAddress()) and + not self.currentProgram.getListing().getDataAt(sym.getAddress()).isStructure() + ] + + + def __gstructs(self): + return [ + (struct.getPathName(), struct) + for struct in self.currentProgram.getDataTypeManager().getAllStructures() + ] + + + def __gtypedefs(self): + from .compat.imports import TypedefDB + + return [ + (typedef.getPathName(), typedef) + for typedef in self.currentProgram.getDataTypeManager().getAllDataTypes() + if isinstance(typedef, TypedefDB) + ] + + + def __function_code_units(self): + """ + Returns a list of code units for each function in the program. + """ + return [ + [code_unit for code_unit in self.currentProgram.getListing().getCodeUnits(func.getBody(), True)] + for func in self.currentProgram.getFunctionManager().getFunctions(True) + ] + diff --git a/declib/decompilers/ida/__init__.py b/declib/decompilers/ida/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/declib/decompilers/ida/artifact_lifter.py b/declib/decompilers/ida/artifact_lifter.py new file mode 100644 index 00000000..bf5aed27 --- /dev/null +++ b/declib/decompilers/ida/artifact_lifter.py @@ -0,0 +1,51 @@ +import logging + +from declib.api import ArtifactLifter +from declib.artifacts import Segment + +l = logging.getLogger(name=__name__) + +class IDAArtifactLifter(ArtifactLifter): + lift_map = { + "__int64": "long long", + "__int32": "int", + "__int16": "short", + "__int8": "char", + "_BOOL8": "bool", + "_BOOL4": "bool", + "_BOOL2": "bool", + "_BOOL1": "bool", + "_BOOL": "bool", + "_BYTE": "char", + "_WORD": "unsigned short", + "_DWORD": "unsigned int", + "_QWORD": "unsigned long long", + } + + def __init__(self, deci): + super(IDAArtifactLifter, self).__init__(deci) + + def lift_type(self, type_str: str) -> str: + return self.lift_ida_type(type_str) + + def lift_stack_offset(self, offset: int, func_addr: int) -> int: + from . import compat + return compat.ida_to_bs_stack_offset(func_addr, offset) + + def lower_type(self, type_str: str) -> str: + # TODO: this is a hack until https://github.com/binsync/declib/issues/97 is solved + if "/" in type_str: + type_str = type_str.split("/")[-1] + + return type_str + + def lower_stack_offset(self, offset: int, func_addr: int) -> int: + from . import compat + return compat.bs_to_ida_stack_offset(self.lower_addr(func_addr), offset) + + @staticmethod + def lift_ida_type(type_str: str) -> str: + for ida_t, bs_t in IDAArtifactLifter.lift_map.items(): + type_str = type_str.replace(ida_t, bs_t) + + return type_str \ No newline at end of file diff --git a/declib/decompilers/ida/compat.py b/declib/decompilers/ida/compat.py new file mode 100644 index 00000000..1ab252d5 --- /dev/null +++ b/declib/decompilers/ida/compat.py @@ -0,0 +1,2054 @@ +# ---------------------------------------------------------------------------- +# This file is more of a library for making compatibility calls to IDA for +# things such as getting decompiled function names, start addresses, and +# asking for write permission to ida. This will mostly be called in the +# deci. +# +# Note that anything that requires write permission to IDA will need to pass +# through this program if it is not running in the main thread. +# +# ---------------------------------------------------------------------------- +import datetime +import re +import threading +from functools import wraps +import typing +import logging +from packaging.version import Version +import os + +IDA_IS_INTERACTIVE = bool(os.getenv("IDA_IS_INTERACTIVE", False)) +if not IDA_IS_INTERACTIVE: + try: + import ida_kernwin + # this is to support IDA 8.4 and below + IDA_IS_INTERACTIVE |= bool(ida_kernwin.is_idaq()) + except ImportError: + pass + +if not IDA_IS_INTERACTIVE: + try: + # IDA 9+ + import idapro + except ImportError: + # IDA 9 Beta + import ida as idapro + +import idc, idaapi, ida_kernwin, ida_hexrays, ida_funcs, \ + ida_bytes, ida_idaapi, ida_typeinf, idautils, ida_kernwin, ida_segment + +import declib +from declib.artifacts import ( + Struct, FunctionHeader, FunctionArgument, StackVariable, Function, GlobalVariable, Enum, Artifact, Context, Typedef, + StructMember, Segment +) + +from .artifact_lifter import IDAArtifactLifter +if typing.TYPE_CHECKING: + from .interface import IDAInterface + +_l = logging.getLogger(__name__) +_IDA_VERSION = None + +FORM_TYPE_TO_NAME = None +FUNC_FORMS = {"decompilation", "disassembly"} + +def get_form_to_type_name(): + global FORM_TYPE_TO_NAME + if FORM_TYPE_TO_NAME is None: + mapping = { + idaapi.BWN_PSEUDOCODE: "decompilation", + idaapi.BWN_DISASM: "disassembly", + idaapi.BWN_FUNCS: "functions", + idaapi.BWN_STRINGS: "strings" + } + if get_ida_version() >= 900: + mapping.update({ + idaapi.BWN_TILIST: "types" + }) + else: + mapping.update({ + idaapi.BWN_STRINGS: "structs", + idaapi.BWN_ENUMS: "enums", + 0x3c: "types" + }) + FORM_TYPE_TO_NAME = mapping + + return FORM_TYPE_TO_NAME + +# +# Wrappers for IDA Main thread r/w operations +# a special note about these functions: +# Any operation that needs to do some type of write to the ida db (idb), needs to be in the main thread due to +# some ida constraints. Sometimes reads also need to be in the main thread. To make things efficient, most heavy +# things are done in the deci and just setters and getters are done here. +# + + +def is_mainthread(): + """ + Return a bool that indicates if this is the main application thread. + """ + return isinstance(threading.current_thread(), threading._MainThread) + + +def execute_sync(func, sync_type): + """ + Synchronize with the disassembler for safe database access. + Modified from https://github.com/vrtadmin/FIRST-plugin-ida + """ + + @wraps(func) + def wrapper(*args, **kwargs): + output = [None] + + # + # this inline function definition is technically what will execute + # in the context of the main thread. we use this thunk to capture + # any output the function may want to return to the user. + # + + def thunk(): + output[0] = func(*args, **kwargs) + return 1 + + if is_mainthread(): + thunk() + else: + idaapi.execute_sync(thunk, sync_type) + + # return the output of the synchronized execution + return output[0] + return wrapper + +# TODO: a while ago we moved away from using read, but that was wrong. We should refactor the below code at some point +# to use the correct sync type instead of always execute_write +def execute_read(func): + return execute_sync(func, idaapi.MFF_READ) + + +def execute_write(func): + return execute_sync(func, idaapi.MFF_WRITE) + + +def execute_ui(func): + return execute_sync(func, idaapi.MFF_FAST) + + + +# +# Decompilation +# + + +class DummyIDACodeView: + """ + A stand-in for an IDA pseudocode ``vdui`` used in headless mode (where no GUI + view exists). It exposes the two mutations ``set_stack_variables`` needs — + renaming and retyping a local/stack variable — implemented against the + headless Hexrays APIs (``rename_lvar`` / ``modify_user_lvar_info``) so that + edits actually persist to the database. Any other attribute access falls + back to a no-op (e.g. ``refresh_view`` has nothing to refresh headless). + """ + def __init__(self, addr): + self.cfunc = ida_hexrays.decompile(addr) + self.addr = addr + + def __getattr__(self, item): + return lambda *x,**y: None + + def rename_lvar(self, lvar, name, is_user=1) -> bool: + """Rename a local variable headlessly. Mirrors vdui.rename_lvar.""" + ok = bool(ida_hexrays.rename_lvar(self.addr, lvar.name, name)) + if ok: + self._refresh_cfunc() + return ok + + def set_lvar_type(self, lvar, new_type) -> bool: + """Set a local variable's type headlessly. Mirrors vdui.set_lvar_type. + + Uses ``modify_user_lvar_info`` with ``MLI_TYPE`` since there is no GUI + ``vdui`` to drive; the user-lvar settings persist across redecompilation. + """ + info = ida_hexrays.lvar_saved_info_t() + info.ll.location = lvar.location + info.ll.defea = lvar.defea + info.type = new_type + ok = bool(ida_hexrays.modify_user_lvar_info(self.addr, ida_hexrays.MLI_TYPE, info)) + if ok: + self._refresh_cfunc() + return ok + + def _refresh_cfunc(self): + # Re-decompile so callers re-reading ``cfunc.lvars`` see the change. + new_cfunc = ida_hexrays.decompile(self.addr) + if new_cfunc is not None: + self.cfunc = new_cfunc + + +def requires_decompilation(f): + @wraps(f) + def _requires_decompilation(*args, **kwargs): + artifact = args[0] + if isinstance(artifact, Artifact): + addr = artifact.addr + else: + addr = artifact + + has_ui = not kwargs.get('headless', False) + has_decompiler = kwargs.get('decompiler_available', True) + ida_code_view = kwargs.get('ida_code_view', None) + + if ida_code_view is None and has_decompiler: + kwargs['ida_code_view'] = acquire_pseudocode_vdui(addr) if has_ui else DummyIDACodeView(addr) + + return f(*args, **kwargs) + + return _requires_decompilation + + +@execute_write +def get_func_ret_type(ea): + tinfo = ida_typeinf.tinfo_t() + got_info = idaapi.get_tinfo(tinfo, ea) + return str(tinfo.get_rettype()) if got_info else None + + +@execute_write +def get_func(ea): + return idaapi.get_func(ea) + + +def set_func_ret_type(ea, return_type_str): + tinfo = ida_typeinf.tinfo_t() + if not idaapi.get_tinfo(tinfo, ea): + _l.warning("Failed to get tinfo for function at %s", hex(ea)) + return False + + new_type = convert_type_str_to_ida_type(return_type_str) + if new_type is None: + _l.warning("Failed to convert type string %s to ida type.", return_type_str) + return False + + func_type_data = ida_typeinf.func_type_data_t() + if not tinfo.get_func_details(func_type_data): + _l.warning("Failed to get function details for function at %s", hex(ea)) + return False + + func_type_data.rettype = new_type + new_func_type = ida_typeinf.tinfo_t() + if not new_func_type.create_func(func_type_data): + _l.warning("Failed to create new function type for function at %s", hex(ea)) + return False + + # Apply the new function type to the function + if not idaapi.apply_tinfo(ea, new_func_type, idaapi.TINFO_DEFINITE): + _l.warning("Failed to apply new function type for function at %s", hex(ea)) + return False + + return True + +# +# Types +# + + +def get_ida_version(): + global _IDA_VERSION + if _IDA_VERSION is None: + _IDA_VERSION = idaapi.IDA_SDK_VERSION + + return _IDA_VERSION + +def get_ida_gui_version() -> str: + return "PySide6" if get_ida_version() >= 920 else "PyQt5" + + +def new_ida_typing_system(): + return get_ida_version() >= 840 + + +def get_ordinal_count(): + if new_ida_typing_system(): + return ida_typeinf.get_ordinal_count(idaapi.get_idati()) + else: + return ida_typeinf.get_ordinal_qty(idaapi.get_idati()) + + +@execute_write +def get_types(structs=True, enums=True, typedefs=True) -> typing.Dict[str, Artifact]: + types = {} + idati = idaapi.get_idati() + + for ord_num in range(1, get_ordinal_count()+1): + tif = ida_typeinf.tinfo_t() + success = tif.get_numbered_type(idati, ord_num) + if not success: + continue + + is_typedef, name, type_name = typedef_info(tif, use_new_check=True) + # must check typedefs first, since they can be structs + if is_typedef: + if typedefs: + types[name] = Typedef(name, type_name) + continue + + if structs and tif.is_struct(): + bs_struct = bs_struct_from_tif(tif) + # IDA exposes nested types inside anonymous unions/structs as separate + # numbered types whose name is "$PARENT_HASH::member" — that qualified + # form can't be looked up via get_named_type_tid, so skip it. + if bs_struct.name and "::" not in bs_struct.name: + types[bs_struct.name] = bs_struct + elif enums and tif.is_enum(): + bs_enum = enum_from_tif(tif) + if bs_enum is not None: + types[bs_enum.name] = bs_enum + + return types + + +@execute_write +def get_ord_to_type_names() -> typing.Dict[int, typing.Tuple[str, typing.Type[Artifact]]]: + idati = idaapi.get_idati() + ord_to_name = {} + for ord_num in range(1, get_ordinal_count()+1): + tif = ida_typeinf.tinfo_t() + success = tif.get_numbered_type(idati, ord_num) + if not success: + continue + + type_name = tif.get_type_name() + if tif.is_typedef(): + type_type = Typedef + elif tif.is_struct(): + type_type = Struct + elif tif.is_enum(): + type_type = Enum + else: + type_type = None + + if type_name: + ord_to_name[ord_num] = (type_name, type_type) + + return ord_to_name + + +def get_ida_type(ida_ord=None, name=None): + tif = ida_typeinf.tinfo_t() + idati = idaapi.get_idati() + if ida_ord is not None: + success = tif.get_numbered_type(idati, ida_ord) + if not success: + return None + elif name is not None: + success = tif.get_named_type(idati, name) + if not success: + return None + else: + return None + + return tif + +# +# Type Converters +# + +def type_str_to_size(type_str) -> typing.Optional[int]: + ida_type = convert_type_str_to_ida_type(type_str) + if ida_type is None: + return None + + return ida_type.get_size() + +@execute_write +def convert_type_str_to_ida_type(type_str) -> typing.Optional['ida_typeinf']: + if type_str is None or not isinstance(type_str, str): + return None + + tif = ida_typeinf.tinfo_t() + if type_str.strip() == "void": + valid_parse = tif.create_simple_type(ida_typeinf.BT_VOID) + else: + ida_type_str = type_str + ";" + valid_parse = ida_typeinf.parse_decl(tif, None, ida_type_str, 1) + + return tif if valid_parse is not None else None + +@execute_write +def convert_size_to_flag(size): + """ + Converts a size to the arch specific flag. + + Inspired by: https://github.com/arizvisa/ida-minsc/blob/master/base/_interface.py + + :param size: in bytes + :return: ida flag_t + """ + + size_map = { + 1: idaapi.byte_flag(), + 2: idaapi.word_flag(), + 4: idaapi.dword_flag(), + 8: idaapi.qword_flag() + } + + try: + flag = size_map[size] + except KeyError: + # just always assign something + flag = idaapi.byte_flag() + + return flag + + +# +# IDA Function r/w +# + +@execute_write +def ida_func_addr(addr): + ida_func = ida_funcs.get_func(addr) + if ida_func is None: + return None + + func_addr = ida_func.start_ea + return func_addr + + +@execute_write +def get_func_name(ea) -> typing.Optional[str]: + return idc.get_func_name(ea) + + +@execute_write +def get_func_size(ea): + func = idaapi.get_func(ea) + if not func: + raise ValueError("Unable to find function!") + + return func.size() + + +@execute_write +def set_ida_func_name(func_addr, new_name): + idaapi.set_name(func_addr, new_name, idaapi.SN_FORCE) + ida_kernwin.request_refresh(ida_kernwin.IWID_DISASM) + # XXX: why was this here?!?!? + #ida_kernwin.request_refresh(ida_kernwin.IWID_STRUCTS) + ida_kernwin.request_refresh(ida_kernwin.IWID_STKVIEW) + +def get_segment_range(segment_name) -> typing.Tuple[bool, int, int]: + # Find the segment by name + seg = ida_segment.get_segm_by_name(segment_name) + if seg is None: + return False, None, None + + start_ea = seg.start_ea + end_ea = seg.end_ea + return True, start_ea, end_ea + +@execute_write +def fast_get_function(ea, get_rtype=True): + ida_func = ida_funcs.get_func(ea) + if ida_func is None: + return None + + ret_type = get_func_ret_type(ea) if get_rtype else None + func_name = get_func_name(ea) + func_size = ida_func.size() + header = FunctionHeader( + addr=ea, + name=func_name, + type_=ret_type, + ) + func = Function(addr=ea, size=func_size, header=header) + return func + +@execute_write +def functions(): + blacklisted_segs = ["extern", ".plt", ".plt.sec"] + seg_to_range = {} + for seg in blacklisted_segs: + success, start, end = get_segment_range(seg) + if success: + seg_to_range[seg] = (start, end) + + funcs = {} + for func_addr in idautils.Functions(): + in_bad_seg = False + for seg, (start, end) in seg_to_range.items(): + if start <= func_addr < end: + in_bad_seg = True + break + + if in_bad_seg: + continue + + ida_function = fast_get_function(func_addr) + funcs[func_addr] = ida_function + + return funcs + + +@execute_write +@requires_decompilation +def function(addr, decompiler_available=True, ida_code_view=None, **kwargs): + ida_func = ida_funcs.get_func(addr) + if ida_func is None: + _l.warning("IDA function does not exist for %s.", hex(addr)) + return None + + func_addr = ida_func.start_ea + #change_time = datetime.datetime.now(tz=datetime.timezone.utc) + change_time = None + func = Function(func_addr, get_func_size(func_addr), last_change=change_time) + + if not decompiler_available: + func.header = FunctionHeader(get_func_name(func_addr), func_addr, last_change=change_time) + return func + + def _get_func_info(code_view): + if code_view is None: + _l.warning("IDA function %s is not decompilable", hex(func_addr)) + return func + + func_header: FunctionHeader = function_header(code_view) + stack_vars = { + offset: var + for offset, var in get_func_stack_var_info(ida_func.start_ea).items() + } + func.header = func_header + func.stack_vars = stack_vars + + return func + + if ida_code_view is not None: + func = _get_func_info(ida_code_view) + else: + with IDAViewCTX(func_addr) as ida_code_view: + func = _get_func_info(ida_code_view) + + func.dec_obj = ida_code_view.cfunc if ida_code_view is not None else None + return func + + +@execute_write +def set_function(func: Function, decompiler_available=True, **kwargs): + changes = False + + # acquire decompilation if it is needed + ida_code_view = kwargs.get('ida_code_view', None) + headless = kwargs.get('headless', False) + # these changes require a decompiler + needs_decompilation = bool(func.stack_vars) or bool(func.header.args) + if needs_decompilation and ida_code_view is None and decompiler_available: + ida_code_view = acquire_pseudocode_vdui(func.addr) if not headless else DummyIDACodeView(func.addr) + + # function header, may be only name if no decompiler + if func.header and needs_decompilation and ida_code_view is not None: + changes |= set_function_header(func.header, ida_code_view=ida_code_view) + elif func.header: + if func.name: + set_ida_func_name(func.addr, func.name) + if ida_code_view is None and decompiler_available: + ida_code_view = acquire_pseudocode_vdui(func.addr) if not headless else DummyIDACodeView(func.addr) + changes |= True + if func.type: + changes |= set_func_ret_type(func.addr, func.type) + + # stack vars + if func.stack_vars and ida_code_view is not None: + changes |= set_stack_variables(func.stack_vars, ida_code_view=ida_code_view) + + if changes and ida_code_view is not None: + ida_code_view.refresh_view(changes) + ida_code_view.cfunc.refresh_func_ctext() + + return changes + +@execute_write +def function_header(ida_code_view) -> FunctionHeader: + func_addr = ida_code_view.cfunc.entry_ea + + # collect the function arguments + func_args = {} + for idx, arg in enumerate(ida_code_view.cfunc.arguments): + size = arg.width + name = arg.name + type_str = str(arg.type()) + func_args[idx] = FunctionArgument(idx, name, type_str, size) + + # collect the header ret_type and name + func_name = get_func_name(func_addr) + try: + ret_type_str = str(ida_code_view.cfunc.type.get_rettype()) + except Exception: + ret_type_str = "" + #change_time = datetime.datetime.now(tz=datetime.timezone.utc) + change_time = None + ida_function_info = FunctionHeader(func_name, func_addr, type_=ret_type_str, args=func_args, last_change=change_time) + return ida_function_info + +@execute_write +@requires_decompilation +def set_function_header(bs_header: declib.artifacts.FunctionHeader, exit_on_bad_type=False, ida_code_view=None): + data_changed = False + func_addr = ida_code_view.cfunc.entry_ea + cur_ida_func = function_header(ida_code_view) + + # + # FUNCTION NAME + # + + if bs_header.name and bs_header.name != cur_ida_func.name: + set_ida_func_name(func_addr, bs_header.name) + + # + # FUNCTION RET TYPE + # + + func_name = get_func_name(func_addr) + cur_ret_type_str = str(ida_code_view.cfunc.type.get_rettype()) + if bs_header.type and bs_header.type != cur_ret_type_str: + old_prototype = str(ida_code_view.cfunc.type).replace("(", f" {func_name}(", 1) + new_prototype = old_prototype.replace(cur_ret_type_str, bs_header.type, 1) + parsed_new_proto = convert_type_str_to_ida_type(new_prototype) + if parsed_new_proto is None and exit_on_bad_type: + _l.warning("Failed to parse new prototype %s", new_prototype) + return False + + success = False + if parsed_new_proto is not None: + success = bool( + ida_typeinf.apply_tinfo(func_addr, convert_type_str_to_ida_type(new_prototype), ida_typeinf.TINFO_DEFINITE) + ) + + # we may need to reload types + if success is None and exit_on_bad_type: + return False + + data_changed |= success is True + ida_code_view.refresh_view(data_changed) + + # + # FUNCTION ARGS + # + + types_to_change = {} + for idx, bs_arg in bs_header.args.items(): + if not bs_arg: + continue + + if idx >= len(cur_ida_func.args): + break + + cur_ida_arg = cur_ida_func.args[idx] + + # record the type to change + if bs_arg.type and bs_arg.type != cur_ida_arg.type: + types_to_change[idx] = (cur_ida_arg.type, bs_arg.type) + + # change the name + if bs_arg.name and bs_arg.name != cur_ida_arg.name: + success = bool(ida_code_view.rename_lvar(ida_code_view.cfunc.arguments[idx], bs_arg.name, 1)) + data_changed |= success + + # crazy prototype parsing + func_prototype = str(ida_code_view.cfunc.type).replace("(", f" {func_name}(", 1) + proto_split = func_prototype.split("(", maxsplit=1) + proto_head, proto_body = proto_split[0], "(" + proto_split[1] + arg_strs = proto_body.split(",") + + # update prototype body from left to right + for idx in range(len(cur_ida_func.args)): + try: + old_t, new_t = types_to_change[idx] + except KeyError: + continue + + arg_strs[idx] = arg_strs[idx].replace(old_t, new_t, 1) + + # set the change + proto_body = ",".join(arg_strs) + new_prototype = proto_head + proto_body + success = idc.SetType(func_addr, new_prototype) + + # we may need to reload types + if success is None and exit_on_bad_type: + return False + + data_changed |= success is True + return data_changed + + +def bs_header_from_tif(tif, name=None, addr=None): + """ + Takes a ida_typeinf.tinfo_t and converts it into a BinSync FunctionHeader. + You can optionally specify the name of the function, which is usually not in the tif, otherwise it will be None. + + TODO: its kinda broken, better to use vdui ptr and grab artifacts + """ + ret_type = str(tif.get_rettype()) + bs_header = FunctionHeader(name, addr, type_=ret_type, args={}) + + nargs = tif.get_nargs() + if not nargs: + return bs_header + + bs_args = {} + # construct a really wack regex which essentially finds where the args are in the prototype + proto_str_regex = "\\(" + for idx in range(nargs): + arg_ida_type = tif.get_nth_arg(idx) + bs_arg = FunctionArgument(idx, None, str(arg_ida_type), arg_ida_type.get_size()) + bs_args[bs_arg.offset] = bs_arg + + # make sure the * does not make it into the regex + arg_type_str = bs_arg.type.replace("*", "\\*").replace("(", "\\(").replace(")", "") + # every arg has some space and a name, group the name + proto_str_regex += rf"\s*{arg_type_str}\s*(.+?)" + if idx != nargs - 1: + proto_str_regex += "," + + proto_str_regex += "\\)" + matches = re.findall(proto_str_regex, str(tif)) + if not matches: + _l.warning("Failed to parse a function header with header: %s", str(tif)) + return bs_header + + match = matches[0] + for i, name in enumerate(match): + bs_args[i].name = name + + return bs_header + + +# +# Variables +# + + +def lvars_to_bs(lvars: list, vdui=None, var_names: list = None, recover_offset=False) -> list[typing.Union[FunctionArgument, StackVariable]]: + bs_vars = [] + arg_name_to_off = {} + if var_names and len(var_names) == len(lvars): + if recover_offset: + for offset, _lvar in enumerate(vdui.cfunc.lvars): + if _lvar.is_arg_var: + arg_name_to_off[_lvar.name] = offset + + for lvar_off, lvar in enumerate(lvars): + if lvar is None: + # this should really never happen + continue + + if vdui is None: + _l.warning("Cannot gather local variables from decompilation that does not exist!") + return bs_vars + + if lvar.is_arg_var: + if recover_offset: + offset = arg_name_to_off.get(lvar.name, None) + if offset is None: + continue + else: + offset = lvar_off + bs_cls = FunctionArgument + elif lvar.is_stk_var(): + offset = lvar.location.stkoff() + bs_cls = StackVariable + elif lvar.is_reg_var(): + # TODO: implement register variables + continue + else: + continue + + name = None + if var_names: + name = var_names[lvar_off] + if not name: + name = lvar.name + type_ = str(lvar.type()) + size = lvar.width + + var = bs_cls(name=name, type_=type_, size=size) + var.offset = offset + if isinstance(var, StackVariable): + var.addr = vdui.cfunc.entry_ea + + bs_vars.append(var) + + return bs_vars + + +@execute_write +@requires_decompilation +def rename_local_variables_by_names(func: Function, name_map: typing.Dict[str, str], ida_code_view=None) -> bool: + lvars = { + lvar.name: lvar for lvar in ida_code_view.cfunc.get_lvars() if lvar.name + } + update = False + for name, lvar in lvars.items(): + new_name = name_map.get(name, None) + if new_name is None: + continue + + ida_hexrays.rename_lvar(func.addr, lvar.name, new_name) + update |= True + + if update and ida_code_view is not None: + ida_code_view.cfunc.refresh_func_ctext() + ida_code_view.refresh_view(True) + + return update + + +# +# Stack Vars +# + +def _deprecated_ida_to_bs_offset(func_addr, ida_stack_off): + frame = idaapi.get_frame(func_addr) + if not frame: + return ida_stack_off + + frame_size = idc.get_struc_size(frame) + + if frame_size == 0: + return ida_stack_off + + last_member_size = idaapi.get_member_size(frame.get_member(frame.memqty - 1)) + bs_soff = ida_stack_off - frame_size + last_member_size + return bs_soff + +def _deprecated_bs_to_ida_offset(func_addr, bs_stack_off): + frame = idaapi.get_frame(func_addr) + if not frame: + return bs_stack_off + + frame_size = idc.get_struc_size(frame) + + if frame_size == 0: + return bs_stack_off + + last_member_size = idaapi.get_member_size(frame.get_member(frame.memqty - 1)) + ida_soff = bs_stack_off + frame_size - last_member_size + return ida_soff + + +def get_func_stack_tif(func): + if isinstance(func, int): + func = idaapi.get_func(func) + + if func is None: + return None + + tif = ida_typeinf.tinfo_t() + if not tif.get_func_frame(func): + return None + + return tif + +def get_frame_info(func_addr) -> typing.Tuple[int, int]: + func = idaapi.get_func(func_addr) + if not func: + raise ValueError(f"Function {hex(func_addr)} does not exist.") + + stack_tif = get_func_stack_tif(func) + if stack_tif is None: + _l.warning("Function %s does not have a stack frame.", hex(func_addr)) + return None, None + + frame_size = stack_tif.get_size() + if frame_size == 0: + _l.warning("Function %s has a stack frame size of 0.", hex(func_addr)) + return None, None + + # get the last member size + udt_data = ida_typeinf.udt_type_data_t() + stack_tif.get_udt_details(udt_data) + membs = [m for m in udt_data] + if not membs: + _l.warning("Function %s has a stack frame with no members.", hex(func_addr)) + return None, None + + last_member_type = membs[-1].type + if not last_member_type: + _l.warning("Function %s has a stack frame with a member with no type.", hex(func_addr)) + return None, None + + last_member_size = last_member_type.get_size() + return frame_size, last_member_size + +def ida_to_bs_stack_offset(func_addr: int, ida_stack_off: int): + if get_ida_version() < 900: + return _deprecated_ida_to_bs_offset(func_addr, ida_stack_off) + + frame_size, last_member_size = get_frame_info(func_addr) + if frame_size is None or last_member_size is None: + return ida_stack_off + + bs_soff = ida_stack_off - frame_size + last_member_size + return bs_soff + +def bs_to_ida_stack_offset(func_addr: int, bs_stack_off: int): + if get_ida_version() < 900: + # maintain backwards compatibility + return _deprecated_bs_to_ida_offset(func_addr, bs_stack_off) + + frame_size, last_member_size = get_frame_info(func_addr) + if frame_size is None or last_member_size is None: + return bs_stack_off + + ida_soff = bs_stack_off + frame_size - last_member_size + return ida_soff + +def set_stack_variables(svars: list[StackVariable], decompiler_available=True, **kwargs) -> bool: + """ + This function should only be called in a function that is already used in main-thread. + This should also mean decompilation is passed in. + """ + ida_code_view = kwargs.get('ida_code_view', None) + changes = False + if ida_code_view is None: + # TODO: support decompilation-less stack var setting + _l.warning("Cannot set stack variables without a decompiler.") + return changes + + lvars = {v.location.stkoff(): v for v in ida_code_view.cfunc.lvars if v.is_stk_var()} + if not lvars: + _l.warning("No stack variables found in decompilation to set. Making new ones is not supported") + return changes + + for bs_off, bs_var in svars.items(): + if bs_off not in lvars: + _l.warning("Stack variable at offset %s not found in decompilation.", bs_off) + continue + + lvar = lvars[bs_off] + + # naming: + if bs_var.name and bs_var.name != lvar.name: + ida_code_view.rename_lvar(lvar, bs_var.name, 1) + changes |= True + ida_code_view.cfunc.refresh_func_ctext() + lvars = {v.location.stkoff(): v for v in ida_code_view.cfunc.lvars if v.is_stk_var()} + + # typing + if bs_var.type: + curr_ida_type_str = str(lvar.type()) if lvar.type() else None + curr_ida_type = IDAArtifactLifter.lift_ida_type(curr_ida_type_str) if curr_ida_type_str else None + if curr_ida_type and bs_var.type != curr_ida_type: + new_type = convert_type_str_to_ida_type(bs_var.type) + if new_type is None: + _l.warning("Failed to convert type string %s to ida type.", bs_var.type) + continue + + updated_type = ida_code_view.set_lvar_type(lvar, new_type) + if updated_type: + changes |= True + ida_code_view.cfunc.refresh_func_ctext() + lvars = {v.location.stkoff(): v for v in ida_code_view.cfunc.lvars if v.is_stk_var()} + + if changes: + ida_code_view.refresh_view(True) + + return changes + + +@execute_write +def get_func_stack_var_info(func_addr) -> typing.Dict[int, StackVariable]: + try: + decompilation = ida_hexrays.decompile(func_addr) + except ida_hexrays.DecompilationFailure: + _l.debug("Decompiling too many functions too fast! Slow down and try that operation again.") + return {} + + if decompilation is None: + _l.warning("Decompiled something that gave no decompilation") + return {} + + stack_var_info = {} + + for var in decompilation.lvars: + if not var.is_stk_var(): + continue + + size = var.width + name = var.name + + ida_offset = var.location.stkoff() - decompilation.get_stkoff_delta() + bs_offset = ida_to_bs_stack_offset(func_addr, ida_offset) + type_str = str(var.type()) + stack_var_info[bs_offset] = StackVariable( + ida_offset, name, type_str, size, func_addr + ) + + return stack_var_info + + +@execute_write +def _deprecated_set_stack_vars_types(var_type_dict, ida_code_view) -> bool: + """ + Sets the type of a stack variable, which should be a local variable. + Take special note of the types of first two parameters used here: + var_type_dict is a dictionary of the offsets and the new proposed type info for each offset. + This typeinfo should be gotten either by manully making a new typeinfo object or using the + parse_decl function. code_view is a _instance of vdui_t, which should be gotten through + open_pseudocode() from ida_hexrays. + + This function also is special since it needs to iterate all of the stack variables an unknown amount + of times until a fixed point of variables types not changing is met. + + + @param var_type_dict: Dict[stack_offset, ida_typeinf_t] + @param ida_code_view: A pointer to a vdui_t screen + @param deci: The declib deci to do operations on + @return: + """ + + data_changed = False + fixed_point = False + func_addr = ida_code_view.cfunc.entry_ea + while not fixed_point: + fixed_point = True + for lvar in ida_code_view.cfunc.lvars: + if lvar.is_stk_var(): + # TODO: this algorithm may need be corrected for programs with func args on the stack + cur_off = abs(ida_to_bs_stack_offset(func_addr, lvar.location.stkoff())) + if cur_off in var_type_dict: + if str(lvar.type()) != str(var_type_dict[cur_off]): + data_changed |= ida_code_view.set_lvar_type(lvar, var_type_dict.pop(cur_off)) + fixed_point = False + # make sure to break, in case the size of lvars array has now changed + break + + return data_changed + + +# +# IDA Comment r/w +# + +@execute_write +def set_ida_comment(addr, cmt, decompiled=False): + func = ida_funcs.get_func(addr) + if not func: + _l.info("No function found at %s", addr) + return False + + rpt = 1 + ida_code_view = None + if decompiled: + try: + ida_code_view = acquire_pseudocode_vdui(func.start_ea) + except Exception: + pass + + # function comment + if addr == func.start_ea: + idc.set_func_cmt(addr, cmt, rpt) + if ida_code_view: + ida_code_view.refresh_view(True) + return True + + # a comment in decompilation + elif decompiled: + if ida_code_view is None: + ida_bytes.set_cmt(addr, cmt, rpt) + return True + + eamap = ida_code_view.cfunc.get_eamap() + decomp_obj_addr = eamap[addr][0].ea + tl = idaapi.treeloc_t() + + # try to set a comment using the cfunc obj and normal address + for a in [addr, decomp_obj_addr]: + tl.ea = a + for itp in range(idaapi.ITP_SEMI, idaapi.ITP_COLON): + tl.itp = itp + ida_code_view.cfunc.set_user_cmt(tl, cmt) + ida_code_view.cfunc.save_user_cmts() + ida_code_view.cfunc.refresh_func_ctext() + + # attempt to set until it does not fail (orphan itself) + if not ida_code_view.cfunc.has_orphan_cmts(): + ida_code_view.cfunc.save_user_cmts() + ida_code_view.refresh_view(True) + return True + ida_code_view.cfunc.del_orphan_cmts() + return False + # a comment in disassembly + else: + ida_bytes.set_cmt(addr, cmt, rpt) + return True + + +def get_ida_comment(addr, decompiled=True): + # TODO: support more than just functions + # TODO: support more than just function headers + if decompiled and not ida_hexrays.init_hexrays_plugin(): + raise ValueError("Decompiler is not available, but you are requesting a decompiled comment") + + func = idaapi.get_func(addr) + if func is None: + return None + + if func.start_ea == addr: + cmt = idc.get_func_cmt(addr, 1) + return cmt if cmt else None + + +@execute_write +def set_decomp_comments(func_addr, cmt_dict: typing.Dict[int, str]): + for addr in cmt_dict: + ida_cmts = ida_hexrays.user_cmts_new() + + comment = cmt_dict[addr] + tl = ida_hexrays.treeloc_t() + tl.ea = addr + # XXX: need a real value here at some point + tl.itp = 90 + ida_cmts.insert(tl, ida_hexrays.citem_cmt_t(comment)) + + ida_hexrays.save_user_cmts(func_addr, ida_cmts) + + +# +# IDA Struct r/w +# + +def bs_struct_from_tif(tif): + if not tif.is_struct(): + return None + + size = tif.get_size() + name = tif.get_type_name() + + # get members + members = {} + if size > 0: + udt_data = ida_typeinf.udt_type_data_t() + if tif.get_udt_details(udt_data): + for udt_memb in udt_data: + # TODO: warning if offset is not a multiple of 8 (a bit offset), we are in trouble + byte_offset = udt_memb.offset // 8 + m_name = udt_memb.name + m_type = udt_memb.type + type_name = m_type.get_type_name() or str(m_type) + m_size = m_type.get_size() + members[byte_offset] = StructMember(name=m_name, type_=type_name, size=m_size, offset=byte_offset) + + return Struct(name=name, size=size, members=members) + + +@execute_write +def structs(): + if new_ida_typing_system(): + _structs = get_types(structs=True, enums=False, typedefs=False) + else: + _l.warning("You are using an old IDA, this will be deprecated in the future!") + _structs = {} + for struct_item in idautils.Structs(): + idx, sid, name = struct_item[:] + sptr = idc.get_struc(sid) + size = idc.get_struc_size(sptr) + _structs[name] = Struct(name, size, {}) + + return _structs + +def _deprecated_get_struct(name): + + sid = idc.get_struc_id(name) + if sid == idaapi.BADADDR: + return None + + sptr = idc.get_struc(sid) + size = idc.get_struc_size(sptr) + _struct = Struct(name, size, {}, last_change=datetime.datetime.now(tz=datetime.timezone.utc)) + for mptr in sptr.members: + mid = mptr.id + m_name = idc.get_member_name(mid) + m_off = mptr.soff + m_type = ida_typeinf.idc_get_type(mptr.id) if mptr.has_ti() else "" + m_size = idc.get_member_size(mptr) + _struct.add_struct_member(m_name, m_off, m_type, m_size) + + return _struct + +@execute_write +def struct(name): + if not new_ida_typing_system(): + return _deprecated_get_struct(name) + + tid = ida_typeinf.get_named_type_tid(name) + tif = ida_typeinf.tinfo_t() + if tid != idaapi.BADADDR and tif.get_type_by_tid(tid) and tif.is_udt(): + return bs_struct_from_tif(tif) + + return None + +@execute_write +def del_ida_struct(name) -> bool: + sid = idc.get_struc_id(name) + if sid == idaapi.BADADDR: + return False + + sptr = sid if new_ida_typing_system() else idc.get_struc(sid) + return idc.del_struc(sptr) + + +def expand_ida_struct(sid, new_size): + """ + Only works in IDA 9 and up + """ + tif = ida_typeinf.tinfo_t() + if tif.get_type_by_tid(sid) and tif.is_udt(): + if tif.get_size() == new_size: + return True + + udm = ida_typeinf.udm_t() + udm.offset = 0 + idx = tif.find_udm(udm, ida_typeinf.STRMEM_LOWBND|ida_typeinf.STRMEM_SKIP_GAPS) + if idx != -1: + return tif.expand_udt(idx, new_size) + + return False + + +@execute_write +def set_ida_struct(struct: Struct) -> bool: + new_struct_system = new_ida_typing_system() + # first, delete any struct by the same name if it exists + sid = idc.get_struc_id(struct.name) + if sid != idaapi.BADADDR: + sptr = sid if new_struct_system else idc.get_struc(sid) + idc.del_struc(sptr) + + # now make a struct header + idc.add_struc(ida_idaapi.BADADDR, struct.name, False) + sid = idc.get_struc_id(struct.name) + + struct_identifier = sid if new_struct_system else idc.get_struc(sid) + + # expand the struct to the desired size + # XXX: do not increment API here, why? Not sure, but you cant do it here. + if get_ida_version() >= 900: + expand_ida_struct(sid, struct.size) + else: + idc.expand_struc(struct_identifier, 0, struct.size, False) + + # add every member of the struct + for off, member in struct.members.items(): + if member.size is None: + if member.type is None: + raise ValueError("Member size and type cannot both be None when setting a struct!") + + type_size = type_str_to_size(member.type) + if type_size is None: + _l.warning("Failed to get size for member %s of struct %s, assuming 8!", member.name, struct.name) + type_size = 8 + + member.size = type_size + + if member.offset is None: + member.offset = off + + # convert to ida's flag system + mflag = convert_size_to_flag(member.size) + + # create the new member + idc.add_struc_member( + struct_identifier, + member.name, + member.offset, + mflag, + -1, + member.size, + ) + + return True + +def _depreacated_set_ida_struct_member_types(struct: Struct) -> bool: + # find the specific struct + sid = idc.get_struc_id(struct.name) + sptr = idc.get_struc(sid) + data_changed = False + + for idx, member in enumerate(struct.members.values()): + # set the new member type if it has one + if member.type == "": + continue + + # assure its convertible + tif = convert_type_str_to_ida_type(member.type) + if tif is None: + continue + + # set the type + mptr = sptr.get_member(idx) + was_set = idc.set_member_tinfo( + sptr, + mptr, + 0, + tif, + mptr.flag + ) + data_changed |= was_set == 1 + + return data_changed + + +@execute_write +def set_ida_struct_member_types(bs_struct: Struct): + if not new_ida_typing_system(): + return _depreacated_set_ida_struct_member_types(bs_struct) + + struct_tif = get_ida_type(name=bs_struct.name) + if struct_tif is None: + return False + + udt_data = ida_typeinf.udt_type_data_t() + if not struct_tif.get_udt_details(udt_data): + return False + + data_changed = False + for member_index, udt_memb in enumerate(udt_data): + if udt_memb.offset % 8 != 0: + _l.warning( + f"Struct member %s of struct %s is not byte aligned! This is currently unsupported.", + udt_memb.name, + bs_struct.name + ) + continue + + byte_offset = udt_memb.offset // 8 + bs_member = bs_struct.members.get(byte_offset, None) + if bs_member is None: + continue + + member_tif = convert_type_str_to_ida_type(bs_member.type) + if member_tif is None: + _l.warning("Failed to convert type %s for struct member %s", bs_member.type, bs_member.name) + continue + + if member_tif != udt_memb.type: + struct_tif.set_udm_type(member_index, member_tif) + data_changed |= True + + return data_changed + +# +# Global Vars +# + + +@execute_write +def global_vars(): + gvars = {} + known_segs = [".artifacts", ".bss"] + for seg_name in known_segs: + seg = idaapi.get_segm_by_name(seg_name) + if not seg: + continue + + for seg_ea in range(seg.start_ea, seg.end_ea): + xrefs = idautils.XrefsTo(seg_ea) + try: + next(xrefs) + except StopIteration: + continue + + name = idaapi.get_name(seg_ea) + if not name: + continue + + gvars[seg_ea] = GlobalVariable(seg_ea, name) + + return gvars + + +@execute_write +def global_var(addr): + name = idaapi.get_name(addr) + if not name: + return None + + type_ = idc.get_type(addr) + size = idaapi.get_item_size(addr) + return GlobalVariable(addr, name, size=size, last_change=datetime.datetime.now(tz=datetime.timezone.utc), type_=type_) + + +@execute_write +def set_global_var_name(var_addr, name): + return idaapi.set_name(var_addr, name) + +@execute_write +def set_global_var_type(var_addr, type_str): + """ + To make sure the type is correctly displayed (especially for arrays of structs, or arrayy of chars, a.k.a. strings), + we first undefine the items where the type is going to be applied. + Parse the applied type string to infer its size, and thus the number of bytes to undefine. + """ + tif = convert_type_str_to_ida_type(type_str) + if tif is None: + idc.del_items(var_addr, flags=idc.DELIT_SIMPLE) + else: + type_size = tif.get_size() + idc.del_items(var_addr, flags=idc.DELIT_SIMPLE, nbytes=type_size) + return idc.SetType(var_addr, type_str) + + +def ida_type_from_serialized(typ: bytes, fields: bytes): + tif = ida_typeinf.tinfo_t() + if not tif.deserialize(ida_typeinf.get_idati(), typ, fields): + tif = None + + return tif + +# +# Enums +# + + +def _deprecated_get_enum_mmebers(_enum_id, max_size=100) -> typing.Dict[str, int]: + enum_members = {} + + member = idc.get_first_enum_member(_enum_id) + member_addr = idc.get_enum_member(_enum_id, member, 0, 0) + member_name = idc.get_enum_member_name(member_addr) + if member_name is None: + return enum_members + + enum_members[member_name] = member + + member = idc.get_next_enum_member(_enum_id, member, 0) + for _ in range(max_size): + if member == idaapi.BADADDR: + break + + member_addr = idc.get_enum_member(_enum_id, member, 0, 0) + member_name = idc.get_enum_member_name(member_addr) + if member_name: + enum_members[member_name] = member + + member = idc.get_next_enum_member(_enum_id, member, 0) + else: + _l.critical("IDA failed to iterate all enum members for enum %s", _enum_id) + + return enum_members + + +def get_enum_members(_enum: typing.Union["ida_typeinf.tinfo_t", int], max_size=100) -> typing.Optional[typing.Dict[str, int]]: + """ + _enum can either be an ida_typeinf.tinfo_t or an int (the old enum id system). + Returns None if the tif reports as an enum but IDA can't fetch its details + (e.g. typedef wrappers that pass tif.is_enum() but aren't real enums). + """ + if not new_ida_typing_system(): + _enum_id: int = _enum + return _deprecated_get_enum_mmebers(_enum_id, max_size=max_size) + + # this is an enum tif if we are here + enum_tif: "ida_typeinf.tinfo_t" = _enum + ei = ida_typeinf.enum_type_data_t() + if not enum_tif.get_enum_details(ei): + _l.debug("IDA could not get enum details for %s; treating as non-enum", enum_tif) + return None + + enum_members = {} + for e_memb in ei: + val = e_memb.value + if val == -1: + _l.warning("IDA failed to get enum member value for %s", e_memb) + break + + name = e_memb.name + if name is None: + _l.warning("IDA failed to get enum member name for %s", e_memb) + break + + enum_members[name] = val + + return enum_members + + +def enum_from_tif(tif): + enum_name = tif.get_type_name() + if not enum_name: + return None + + enum_members = get_enum_members(tif) + if enum_members is None: + return None + return Enum(enum_name, enum_members) + + +@execute_write +def enums() -> typing.Dict[str, Enum]: + return get_types(structs=False, enums=True, typedefs=False) + + +@execute_write +def enum(name) -> typing.Optional[Enum]: + new_enums = new_ida_typing_system() + _enum = get_ida_type(name=name) if new_enums else idc.get_enum(name) + if _enum is None or _enum == idaapi.BADADDR: + return None + + enum_name = str(_enum.get_type_name()) if new_enums else idc.get_enum_name(_enum) + enum_members = get_enum_members(_enum) + if enum_members is None: + return None + return Enum(enum_name, enum_members) + + +@execute_write +def set_enum(bs_enum: Enum): + _enum = idc.get_enum(bs_enum.name) + if not _enum: + return False + + idc.del_enum(_enum) + ords = get_ordinal_count() + enum_id = idc.add_enum(ords, bs_enum.name, 0) + + if enum_id is None: + _l.warning("IDA failed to create a new enum with %s", bs_enum.name) + return False + + for member_name, value in bs_enum.members.items(): + idc.add_enum_member(enum_id, member_name, value) + + return True + +# +# Typedefs +# + + +def use_new_typedef_check(): + return get_ida_version() >= 900 + + +def typedef_info(tif, use_new_check=False) -> typing.Tuple[bool, typing.Optional[str], typing.Optional[str]]: + invalid_typedef = False, None, None + tdef_checker = lambda t: t.is_typedef() if use_new_check else t.is_typeref() + if not tdef_checker(tif): + return invalid_typedef + + name = tif.get_type_name() + type_name = tif.get_next_type_name() + if not name: + return invalid_typedef + + # in older versions we have to parse the type directly (thanks @arizvisa) + if not type_name: + ser_info = idaapi.get_named_type(None, name, idaapi.NTF_TYPE) + ser_bytes = ser_info[1] + if ser_info is not None: + base_tif = ida_typeinf.tinfo_t() + found_base_type = base_tif.deserialize(idaapi.get_idati(), ser_bytes, None, None) + if not base_tif.is_struct(): + type_name = str(base_tif) if found_base_type else None + + if not name or not type_name or name == type_name: + return invalid_typedef + + return True, name, type_name + + +@execute_write +def typedefs() -> typing.Dict[str, Typedef]: + return get_types(structs=False, enums=False, typedefs=True) + + +@execute_write +def typedef(name) -> typing.Optional[Typedef]: + idati = idaapi.get_idati() + tif = ida_typeinf.tinfo_t() + success = tif.get_named_type(idati, name) + if not success: + return None + + is_typedef, name, type_name = typedef_info(tif, use_new_check=use_new_typedef_check()) + if not is_typedef: + return None + + return Typedef(name=name, type_=type_name) + + +def make_typedef_tif(name, type_str): + tif = ida_typeinf.tinfo_t() + ida_type_str = f"typedef {type_str} {name};" + valid_parse = ida_typeinf.parse_decl(tif, None, ida_type_str, 1) + return tif if valid_parse is not None else None + + +@execute_write +def set_typedef(bs_typedef: Typedef): + type_tif = convert_type_str_to_ida_type(bs_typedef.type) + if type_tif is None: + _l.critical("Attempted to set a typedef with an invalid type: %s (does not exist)", bs_typedef.name) + return False + + typedef_tif = make_typedef_tif(bs_typedef.name, bs_typedef.type) + if typedef_tif is None: + _l.critical("Failed to create a typedef name=%s type=%s", bs_typedef.name, bs_typedef.type) + return False + + typedef_tif.set_named_type(idaapi.get_idati(), bs_typedef.name, ida_typeinf.NTF_TYPE) + return True + +# +# IDA GUI r/w +# + +@execute_read +def get_image_base(): + return idaapi.get_imagebase() + +@execute_read +def get_first_segment_base(): + """ + Get the virtual address of the first segment. + """ + # First, try to find the code segment specifically + for seg_addr in idautils.Segments(): + return seg_addr + + # Fallback to image base if no segments found + return idaapi.get_imagebase() + + +@execute_write +def acquire_pseudocode_vdui(addr): + """ + Acquires a IDA HexRays vdui pointer, which is a pointer to a pseudocode view that contains + the cfunc which describes the code on the screen. Using this function optimizes the switching of code views + by using in-place switching if a view is already present. + + @param addr: + @return: + """ + func = ida_funcs.get_func(addr) + if not func: + return None + + names = ["Pseudocode-%c" % chr(ord("A") + i) for i in range(5)] + for name in names: + widget = ida_kernwin.find_widget(name) + if not widget: + continue + + vu = ida_hexrays.get_widget_vdui(widget) + break + else: + vu = ida_hexrays.open_pseudocode(func.start_ea, False) + + if func.start_ea != vu.cfunc.entry_ea: + target_cfunc = ida_hexrays.decompile(func.start_ea) + if target_cfunc is None: + return None + vu.switch_to(target_cfunc, False) + else: + vu.refresh_view(True) + + return vu + + +@execute_write +def refresh_pseudocode_view(ea, set_focus=True): + """Refreshes the pseudocode view in IDA.""" + names = ["Pseudocode-%c" % chr(ord("A") + i) for i in range(5)] + for name in names: + widget = ida_kernwin.find_widget(name) + if widget: + vu = ida_hexrays.get_widget_vdui(widget) + + # Check if the address is in the same function + func_ea = vu.cfunc.entry_ea + func = ida_funcs.get_func(func_ea) + if ida_funcs.func_contains(func, ea): + vu.refresh_view(True) + ida_kernwin.activate_widget(widget, set_focus) + + +class IDAViewCTX: + @execute_write + def __init__(self, func_addr): + self.view = ida_hexrays.open_pseudocode(func_addr, 0) + + def __enter__(self): + return self.view + + @execute_write + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_pseudocode_view(self.view) + + @execute_write + def close_pseudocode_view(self, ida_vdui_t): + if ida_vdui_t is None: + return + widget = ida_vdui_t.toplevel + idaapi.close_pseudocode(widget) + + +def get_screen_ea(): + return idc.get_screen_ea() + + +@execute_write +def get_function_cursor_at(): + curr_addr = get_screen_ea() + if curr_addr is None: + return None, None + + return curr_addr, ida_func_addr(curr_addr) + + +# +# Other Utils +# + +@execute_write +def get_ptr_size(): + """ + Gets the size of the ptr, which in affect tells you the bit size of the binary. + + Taken from: https://github.com/arizvisa/ida-minsc/blob/master/base/database.py + :return: int, size in bytes + """ + tif = ida_typeinf.tinfo_t() + tif.create_ptr(ida_typeinf.tinfo_t(ida_typeinf.BT_VOID)) + return tif.get_size() + + +@execute_write +def get_binary_path(): + return idaapi.get_input_file_path() + + +@execute_write +def jumpto(addr): + """ + Changes the pseudocode view to the function address provided. + + @param addr: Address of function to jump to + @return: + """ + idaapi.jumpto(addr) + + +@execute_write +def jumpto_type(type_name: str) -> None: + """ + Changes the view to the Local Types window, focusing on the specified type. + Does nothing if type is not found + + @param type_name: Name of the user-defined type to jump to + @return: + """ + tif = convert_type_str_to_ida_type(type_name) + if tif is not None: + ida_kernwin.open_loctypes_window(tif.get_ordinal()) + + +@execute_write +def xrefs_to(addr): + return list(idautils.XrefsTo(addr)) + + +@execute_write +def xrefs_from(addr): + """Return the list of code refs originating at ``addr``. + + Filters to code-flow xrefs of kind Near/Far call, so the results line + up with ``Function.getCalledFunctions()`` on Ghidra and angr's + ``kb.callgraph.successors`` — i.e. only direct callees. + """ + out = [] + for xref in idautils.XrefsFrom(addr): + if xref.iscode and xref.type in (idaapi.fl_CN, idaapi.fl_CF): + out.append(int(xref.to)) + return out + + +@execute_write +def list_strings(): + """Return ``(ea, text)`` tuples for every string IDA found. + + Mirrors the Strings window / ``idautils.Strings()``; the caller filters + on text. + """ + results = [] + for s in idautils.Strings(): + try: + text = str(s) + except Exception: + continue + if not text: + continue + results.append((int(s.ea), text)) + return results + + +@execute_write +def read_memory(addr, size): + """Read ``size`` bytes from the IDB at ``addr``. + + Uses ``ida_bytes.get_bytes`` which honors loaded segments and patched + bytes. Returns ``None`` when IDA can't satisfy the read at all. + """ + if size <= 0: + return b"" + data = ida_bytes.get_bytes(addr, size) + if data is None: + return None + return bytes(data) + + +@execute_write +def disassemble_function(addr): + """Return a single-string disassembly for the function containing ``addr``.""" + func = ida_funcs.get_func(addr) + if func is None: + return None + lines = [] + start, end = func.start_ea, func.end_ea + ea = start + while ea < end and ea != idaapi.BADADDR: + line = idc.generate_disasm_line(ea, 0) + if line is not None: + lines.append(f"{ea:016x} {line}") + ea = idc.next_head(ea, end) + return "\n".join(lines) if lines else None + + +@execute_write +def wait_for_idc_initialization(): + idc.auto_wait() + + +def initialize_decompiler(): + return bool(ida_hexrays.init_hexrays_plugin()) + + +def has_older_hexrays_version(): + wait_for_idc_initialization() + # any 8.2 versions is bad + return 820 <= get_ida_version() < 830 + + +@execute_write +def get_decompiler_version() -> typing.Optional[Version]: + wait_for_idc_initialization() + + # init_hexrays_plugin() must succeed before any other ida_hexrays.* call — + # otherwise IDA emits "Hex-Rays Decompiler got called from Python without + # being loaded" warnings (e.g. during early plugin load before Hex-Rays + # finishes wiring up). Returns False if the decompiler is genuinely + # unavailable (headless without license, etc.); the caller should treat + # None as "decompiler unavailable, skip version-gated behavior". + if not ida_hexrays.init_hexrays_plugin(): + return None + + return Version(ida_hexrays.get_hexrays_version()) + + +# +# Segment Management +# + +@execute_write +def set_segment(segment: Segment) -> bool: + """ + Creates or updates a segment in IDA Pro. + """ + if not segment.name or segment.start_addr is None or segment.end_addr is None: + return False + + # Check if segment already exists + existing_seg = ida_segment.get_segm_by_name(segment.name) + if existing_seg is not None: + # TODO: maybe we can do this more efficiently? + # delete the segment + del_seg = del_segment(segment.name) + if not del_seg: + _l.warning("Failed to delete existing segment %s before updating it.", segment.name) + return False + + # Create new segment + seg = ida_segment.segment_t() + seg.start_ea = segment.start_addr + seg.end_ea = segment.end_addr + seg.sel = idaapi.setup_selector(0) + + # Add the segment + result = ida_segment.add_segm(seg.sel, segment.start_addr, segment.end_addr, segment.name, "DATA") + if result: + # Set segment name explicitly + new_seg = ida_segment.get_segm_by_name(segment.name) + if new_seg is None: + new_seg = ida_segment.getseg(segment.start_addr) + if new_seg is not None: + ida_segment.set_segm_name(new_seg, segment.name) + return result + + +def segment(name: str) -> typing.Optional[Segment]: + """ + Gets a segment by name. + """ + seg = ida_segment.get_segm_by_name(name) + if seg is None: + return None + + # Convert IDA segment to DecLib Segment + return Segment( + name=name, + start_addr=seg.start_ea, + end_addr=seg.end_ea, + permissions=None # TODO: extract permissions if needed + ) + + +def segments() -> typing.Dict[str, Segment]: + """ + Returns all segments in the binary. + """ + segs = {} + for seg_addr in idautils.Segments(): + seg = ida_segment.getseg(seg_addr) + if seg is not None: + seg_name = ida_segment.get_segm_name(seg) + if seg_name: + segs[seg_name] = Segment( + name=seg_name, + start_addr=seg.start_ea, + end_addr=seg.end_ea, + permissions=None # TODO: extract permissions if needed + ) + return segs + + +@execute_write +def del_segment(name: str) -> bool: + """ + Deletes a segment by name. + """ + seg = ida_segment.get_segm_by_name(name) + if seg is None: + return False + + return ida_segment.del_segm(seg.start_ea, ida_segment.SEGMOD_KILL) + + +def view_to_bs_context(view, get_var=True, action: str = Context.ACT_UNKNOWN) -> typing.Optional[Context]: + form_type = idaapi.get_widget_type(view) + if form_type is None: + return None + + form_to_type_name = get_form_to_type_name() + view_name = form_to_type_name.get(form_type, "unknown") + ctx = Context(screen_name=view_name, action=action) + if view_name in FUNC_FORMS: + ctx.addr = idaapi.get_screen_ea() + func = idaapi.get_func(ctx.addr) + if func is not None: + ctx.func_addr = func.start_ea + # exit early when we are still rendering the screen (no real click info) + if action == Context.ACT_MOUSE_MOVE: + return ctx + + if view_name == "decompilation" and get_var: + # get lvar info at cursor + vu = idaapi.get_widget_vdui(view) + if vu and vu.item: + lvar = vu.item.get_lvar() + if lvar: + ctx.variable = lvar.name + if vu.cpos is not None: + ctx.line_number = vu.cpos.lnnum + ctx.col_number = vu.cpos.x + + return ctx + + +# +# IDA Classes +# + +def generate_generic_ida_plugic_cls(cls_name=None): + """ + This code is pretty complicated, but the gist is that we need to dynamically create this IDA Plugin entry point + for two main reasons: + 1. We can't import PyQt5 until load time, which means this class can't be in the import + 2. Plugins are not allowed to share the same name in IDA Pro plugin init, but we want many downstream people + to be able to import this class and modify it + + Below the class gets dynamically created and, if you provide a name, we copy the direct contents of that class + into a new Python type, essentially making a new class of the exact same contents + """ + class GenericIDAPlugin(idaapi.plugin_t): + """Plugin entry point. Does most of the skinning magic.""" + flags = idaapi.PLUGIN_FIX + + def __init__(self, *args, name=None, comment=None, interface=None, **kwargs): + idaapi.plugin_t.__init__(self) + self.wanted_name = name or "generic_declib_plugin" + self.comment = comment or "A generic DecLib plugin" + self.interface: "IDAInterface" = interface + + def init(self): + self.interface._init_gui_hooks() + return idaapi.PLUGIN_KEEP + + def run(self, arg): + pass + + def term(self): + try: + self.interface._term_gui_hooks() + except Exception: + _l.exception("Error tearing down GUI hooks") + self.interface.decompiler_closed_event() + del self.interface + + cls = GenericIDAPlugin + if cls_name is not None: + cls = type(cls_name, (idaapi.plugin_t,), dict(GenericIDAPlugin.__dict__)) + + return cls + + +class GenericAction(idaapi.action_handler_t): + def __init__(self, action_target, action_function, deci=None): + idaapi.action_handler_t.__init__(self) + self.action_target = action_target + self.action_function = action_function + self.deci: IDAInterface = deci + + def activate(self, ctx): + if ctx is None or ctx.action != self.action_target: + return + + bs_ctx = view_to_bs_context(ctx.widget) + if bs_ctx is None: + return + + bs_ctx = self.deci.art_lifter.lift(bs_ctx) + dec_view = ida_hexrays.get_widget_vdui(ctx.widget) + self.action_function(bs_ctx, deci=self.deci, context=bs_ctx) + + if dec_view is not None: + dec_view.refresh_view(False) + + return 1 + + # This action is always available. + def update(self, ctx): + return idaapi.AST_ENABLE_ALWAYS + + diff --git a/declib/decompilers/ida/hooks.py b/declib/decompilers/ida/hooks.py new file mode 100644 index 00000000..fbf59ab4 --- /dev/null +++ b/declib/decompilers/ida/hooks.py @@ -0,0 +1,700 @@ +# ---------------------------------------------------------------------------- +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# +# This program describes each hook in IDA that we want to overwrite on the +# startup of IDA. Each hook function/class describes a different scenario +# that we try to track when a user makes a change. For _instance, the function +# `cmt_changed` is activated every time a user changes a disassembly comment, +# allowing us to send the new comment to be queued in the Controller actions. +# +# ---------------------------------------------------------------------------- +import functools +import logging +from typing import TYPE_CHECKING +from packaging.version import Version +import datetime + +from .compat import IDA_IS_INTERACTIVE, get_ida_gui_version + +import ida_bytes +import ida_funcs +import ida_hexrays +import ida_idp +import ida_kernwin +import ida_typeinf +import idaapi +import idc + +from . import compat +from declib.artifacts import ( + FunctionHeader, StackVariable, + Comment, GlobalVariable, Enum, Struct, Context, Typedef, StructMember, + Decompilation +) + +if TYPE_CHECKING: + from .interface import IDAInterface + +_l = logging.getLogger(__name__) + +IDA_STACK_VAR_PREFIX = "$" +IDA_CMT_CMT = "cmt" +IDA_RANGE_CMT = "range" +IDA_EXTRA_CMT = "extra" +IDA_CMT_TYPES = {IDA_CMT_CMT, IDA_EXTRA_CMT, IDA_RANGE_CMT} + + +def while_should_watch(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if self.interface.should_watch_artifacts(): + return func(self, *args, **kwargs) + else: + return 0 + + return wrapper + + +# +# Data Change Hooks (excludes decompilation changes) +# + +class IDBHooks(ida_idp.IDB_Hooks): + def __init__(self, interface): + ida_idp.IDB_Hooks.__init__(self) + self.interface: "IDAInterface" = interface + self._seen_function_prototypes = {} + self._ver_9_or_higher = compat.get_ida_version() >= 900 + + def bs_type_deleted(self, ordinal): + old_name, old_type = self.interface.cached_ord_to_type_names[ordinal] + if old_type == Struct: + self.interface.struct_changed(Struct(old_name, -1, members={}), deleted=True) + elif old_type == Enum: + self.interface.enum_changed(Enum(old_name, members={}), deleted=True) + elif old_type == Typedef: + self.interface.typedef_changed(Typedef(name=old_name), deleted=True) + + del self.interface.cached_ord_to_type_names[ordinal] + + def local_types_changed(self, ltc, ordinal, name): + # this can't be a decorator for this function due to how IDA implements these overrides + if not self.interface.should_watch_artifacts(): + return 0 + + tif = compat.get_ida_type(ida_ord=ordinal, name=name) + # was the type deleted? + if tif is None: + if ltc == ida_idp.LTC_DELETED and ordinal in self.interface.cached_ord_to_type_names: + self.bs_type_deleted(ordinal) + + return 0 + + # was the type renamed? + if ordinal in self.interface.cached_ord_to_type_names: + old_name, _ = self.interface.cached_ord_to_type_names[ordinal] + if old_name != name: + self.bs_type_deleted(ordinal) + + # at this point, the type is either completely new or renamed from an existing type. + # in either case we need to just gather the new info and trigger an update + # + # check if it's a typedef first since this these can also trigger strucs + is_typedef, name, type_name = compat.typedef_info(tif, use_new_check=True) + new_type_type = None + if is_typedef: + self.interface.typedef_changed(Typedef(name=name, type_=type_name)) + new_type_type = Typedef + elif tif.is_struct(): + bs_struct = compat.bs_struct_from_tif(tif) + self.interface.struct_changed(bs_struct) + name = tif.get_type_name() + new_type_type = Struct + elif tif.is_enum(): + bs_enum = compat.enum_from_tif(tif) + self.interface.enum_changed(bs_enum) + name = tif.get_type_name() + new_type_type = Enum + + self.interface.cached_ord_to_type_names[ordinal] = (name, new_type_type) + return 0 + + @while_should_watch + def ti_changed(self, ea, type_, fields): + pfn = ida_funcs.get_func(ea) + # only record return type changes + if pfn and pfn.start_ea == ea: + proto_tif = compat.ida_type_from_serialized(type_, fields) + curr_ret_type = str(proto_tif.get_rettype()) + seen_ret_type = self._seen_function_prototypes.get(ea, None) + if seen_ret_type is None: + self._seen_function_prototypes[ea] = curr_ret_type + elif curr_ret_type != seen_ret_type: + self._seen_function_prototypes[ea] = curr_ret_type + self.interface.function_header_changed( + FunctionHeader(None, ea, type_=curr_ret_type, args={}) + ) + elif not pfn: + # Must be a global variable type change + self.interface.global_variable_changed( + GlobalVariable(addr=ea, name=idaapi.get_name(ea), size=idaapi.get_item_size(ea), type_=idc.get_type(ea)) + ) + + return 0 + + # + # Enum Hooks + # + + @while_should_watch + def ida_enum_changed(self, enum_id, new_name=None, deleted=False, member_deleted=False): + name = idc.get_enum_name(enum_id) + _enum = compat.enum(name) if not deleted else Enum(name, {}) + if name in self.interface._deleted_artifacts[Enum]: + if member_deleted: + _l.debug("Attempting to delete the member of an already deleted enum. Skipping...") + return 0 + else: + self.interface._deleted_artifacts[Enum].remove(name) + + if deleted: + self.interface._deleted_artifacts[Enum].add(name) + + if new_name: + _enum.name = new_name + + self.interface.enum_changed(_enum, deleted=deleted) + + @while_should_watch + def enum_created(self, enum): + self.ida_enum_changed(enum) + return 0 + + # XXX - use enum_deleted(self, id) instead? + @while_should_watch + def deleting_enum(self, id): + self.ida_enum_changed(id, deleted=True) + return 0 + + # XXX - use enum_renamed(self, id) instead? + @while_should_watch + def renaming_enum(self, id, is_enum, newname): + enum_id = id + if not is_enum: + enum_id = idc.get_enum_member_enum(id) + + # delete it + self.ida_enum_changed(enum_id, deleted=True) + # readd it with the new name + self.ida_enum_changed(enum_id, new_name=newname) + return 0 + + @while_should_watch + def enum_bf_changed(self, id): + return 0 + + @while_should_watch + def enum_cmt_changed(self, tid, repeatable_cmt): + return 0 + + @while_should_watch + def enum_member_created(self, id, cid): + self.ida_enum_changed(id) + return 0 + + # XXX - use enum_member_deleted(self, id, cid) instead? + @while_should_watch + def deleting_enum_member(self, id, cid): + self.ida_enum_changed(id, member_deleted=True) + return 0 + + # + # Stack Variable + # + + @while_should_watch + def frame_udm_renamed(self, func_ea, udm, oldname): + self._ida_stack_var_changed(func_ea, udm) + return 0 + + @while_should_watch + def frame_udm_changed(self, func_ea, udm_tid, udm_old, udm_new): + self._ida_stack_var_changed(func_ea, udm_new) + return 0 + + def _ida_stack_var_changed(self, func_ea, udm): + # TODO: implement stack var support when there is no decompiler available + return + + def _deprecated_ida_stack_var_changed(self, sptr, mptr): + # XXX: This is a deprecated function that will be removed in the future + func_addr = idaapi.get_func_by_frame(sptr.id) + try: + stack_var_info = compat.get_func_stack_var_info(func_addr)[ + compat.ida_to_bs_stack_offset(func_addr, mptr.soff) + ] + except KeyError: + _l.debug("Failed to track an internal changing stack var: %s.", mptr.id) + return 0 + + # find the properties of the changed stack var + bs_offset = compat.ida_to_bs_stack_offset(func_addr, stack_var_info.offset) + size = stack_var_info.size + type_str = stack_var_info.type + + new_name = idc.get_member_name(mptr.id) + self.interface.stack_variable_changed( + StackVariable(bs_offset, new_name, type_str, size, func_addr) + ) + + # + # Struct & Stack Var Hooks + # + + def ida_struct_changed(self, sid: int, new_name=None, deleted=False, member_deleted=False): + # parse the info of the current struct + struct_name = new_name if new_name else idc.get_struc_name(sid) + + if struct_name in self.interface._deleted_artifacts[Struct]: + if member_deleted: + # attempting to re-delete an already deleted struct + _l.debug("Attempting to delete the member of an already deleted struct. Skipping...") + return 0 + else: + # if we readded a struct that was previously deleted, remove it from the deleted list + self.interface._deleted_artifacts[Struct].remove(struct_name) + + if struct_name.startswith(IDA_STACK_VAR_PREFIX) or struct_name.startswith("__"): + _l.info("Not recording change to %s since its likely an internal IDA struct.", struct_name) + return 0 + + if deleted: + self.interface._deleted_artifacts[Struct].add(struct_name) + self.interface.struct_changed(Struct(struct_name, -1, {}), deleted=True) + return 0 + + struct_ptr = idc.get_struc(sid) + bs_struct = Struct( + struct_name, + idc.get_struc_size(struct_ptr), + {}, + ) + + for mptr in struct_ptr.members: + m_name = idc.get_member_name(mptr.id) + m_off = mptr.soff + m_type = ida_typeinf.idc_get_type(mptr.id) if mptr.has_ti() else "" + m_size = idc.get_member_size(mptr) + bs_struct.add_struct_member(m_name, m_off, m_type, m_size) + + self.interface.struct_changed(bs_struct, deleted=False) + return 0 + + @while_should_watch + def struc_created(self, tid): + sptr = idc.get_struc(tid) + if not sptr.is_frame(): + self.ida_struct_changed(tid) + + return 0 + + # XXX - use struc_deleted(self, struc_id) instead? + @while_should_watch + def deleting_struc(self, sptr): + if not sptr.is_frame(): + self.ida_struct_changed(sptr.id, deleted=True) + + return 0 + + @while_should_watch + def struc_align_changed(self, sptr): + if not sptr.is_frame(): + self.ida_struct_changed(sptr.id) + + return 0 + + # XXX - use struc_renamed(self, sptr) instead? + @while_should_watch + def renaming_struc(self, id, oldname, newname): + sptr = idc.get_struc(id) + if not sptr.is_frame(): + # delete it + self.ida_struct_changed(id, deleted=True) + # add it + self.ida_struct_changed(id, new_name=newname) + return 0 + + @while_should_watch + def struc_expanded(self, sptr): + if not sptr.is_frame(): + self.ida_struct_changed(sptr.id) + + return 0 + + @while_should_watch + def struc_member_created(self, sptr, mptr): + if not sptr.is_frame(): + self.ida_struct_changed(sptr.id) + + return 0 + + @while_should_watch + def struc_member_deleted(self, sptr, off1, off2): + if not sptr.is_frame(): + self.ida_struct_changed(sptr.id, member_deleted=True) + + return 0 + + @while_should_watch + def struc_member_renamed(self, sptr, mptr): + if sptr.is_frame() and not compat.new_ida_typing_system(): + # TODO: this will be deprecated in the future + self._deprecated_ida_stack_var_changed(sptr, mptr) + else: + self.ida_struct_changed(sptr.id) + + return 0 + + @while_should_watch + def struc_member_changed(self, sptr, mptr): + if sptr.is_frame() and not compat.new_ida_typing_system(): + # TODO: this will be deprecated in the future + self._deprecated_ida_stack_var_changed(sptr, mptr) + else: + self.ida_struct_changed(sptr.id) + + return 0 + + def _valid_rename_event(self, ea): + if not self._ver_9_or_higher: + # ignore any changes landing here for structs and stack vars + import ida_struct, ida_enum + return not (ida_struct.is_member_id(ea) or ida_struct.get_struc(ea) or ida_enum.get_enum_name(ea)) + + # in version 9 and above, this event is not triggered by structs + return True + + @while_should_watch + def renamed(self, ea, new_name, local_name): + if not self._valid_rename_event(ea): + return 0 + + ida_func = idaapi.get_func(ea) + # symbols changing without any corresponding func is assumed to be global var + if ida_func is None: + self.interface.global_variable_changed( + GlobalVariable(ea, new_name, size=idaapi.get_item_size(ea), type_=idc.get_type(ea)) + ) + # function name renaming + elif ida_func.start_ea == ea: + self.interface.function_header_changed( + FunctionHeader(idc.get_func_name(ida_func.start_ea), ida_func.start_ea) + ) + + return 0 + + # + # Comment handlers + # + + def ida_comment_changed(self, comment: str, address: int, cmt_type: str): + if cmt_type not in IDA_CMT_TYPES: + _l.debug("An unknown IDA comment type was changed, unknown how to handle!") + return 0 + + ida_func = idaapi.get_func(address) + func_addr = ida_func.start_ea if ida_func else None + bs_cmt = Comment(address, comment, func_addr=func_addr) + if cmt_type == IDA_RANGE_CMT: + bs_cmt.decompiled = True + + if cmt_type != IDA_EXTRA_CMT: + self.interface.comment_changed(bs_cmt, deleted=not comment) + + return 0 + + @while_should_watch + def cmt_changed(self, ea, repeatable_cmt): + if repeatable_cmt: + cmt = ida_bytes.get_cmt(ea, repeatable_cmt) + if cmt: + self.ida_comment_changed(cmt, ea, IDA_CMT_CMT) + return 0 + + @while_should_watch + def range_cmt_changed(self, kind, a, cmt, repeatable): + cmt = idc.get_func_cmt(a.start_ea, repeatable) + if cmt: + self.ida_comment_changed(cmt, a.start_ea, IDA_RANGE_CMT) + return 0 + + @while_should_watch + def extra_cmt_changed(self, ea, line_idx, cmt): + cmt = ida_bytes.get_cmt(ea, 0) + if cmt: + self.ida_comment_changed(cmt, ea, IDA_CMT_CMT) + return 0 + + # + # Unused handlers, to be implemented eventually + # + + @while_should_watch + def struc_cmt_changed(self, id, repeatable_cmt): + """ + fullname = idc.get_struc_name(id) + if "." in fullname: + sname, smname = fullname.split(".", 1) + else: + sname = fullname + smname = "" + cmt = idc.get_struc_cmt(id, repeatable_cmt) + """ + return 0 + + @while_should_watch + def sgr_changed(self, start_ea, end_ea, regnum, value, old_value, tag): + return 0 + + @while_should_watch + def byte_patched(self, ea, old_value): + return 0 + + +# +# Special event hooks +# + + +class IDPHooks(ida_idp.IDP_Hooks): + def __init__(self, interface): + self.interface: "IDAInterface" = interface + ida_idp.IDP_Hooks.__init__(self) + + def ev_adjust_argloc(self, *args): + return ida_idp.IDP_Hooks.ev_adjust_argloc(self, *args) + + def ev_ending_undo(self, action_name, is_undo): + """ + This is the hook called by IDA when an undo event occurs + action name is a vague String description of what changes occured + is_undo specifies if this action was an undo or a redo + """ + self.interface.gui_undo_event(action=action_name) + return 0 + + def ev_replaying_undo(self, action_name, vec, is_undo): + """ + This hook is also called by IDA during the undo + contains the same information as ev_ending_undo + vec also contains a short summary of changes incurred + """ + return 0 + + +# +# Decompilation change hooks +# + +class HexraysHooks(ida_hexrays.Hexrays_Hooks): + def __init__(self, interface, *args, **kwargs): + # this needs to be set from the ourside before hook + self.interface: "IDAInterface" = interface + ida_hexrays.Hexrays_Hooks.__init__(self) + + @while_should_watch + def lvar_name_changed(self, vdui, lvar, new_name, *args): + self.local_var_changed(vdui, lvar, reset_type=True, var_name=new_name) + self._send_decompilation_event(vdui.cfunc) + return 0 + + @while_should_watch + def lvar_type_changed(self, vu: "vdui_t", v: "lvar_t", *args) -> int: + self.local_var_changed(vu, v, reset_name=True) + self._send_decompilation_event(vu.cfunc) + return 0 + + @while_should_watch + def cmt_changed(self, cfunc, treeloc, cmt_str, *args): + self.interface.comment_changed( + Comment(treeloc.ea, cmt_str, func_addr=cfunc.entry_ea, decompiled=True), deleted=not cmt_str + ) + self._send_decompilation_event(cfunc) + return 0 + + @while_should_watch + def refresh_pseudocode(self, vu): + self._send_decompilation_event(vu.cfunc) + return 0 + + def curpos(self, vu): + # Hex-Rays cursor moved within pseudocode. View_Hooks.view_curpos doesn't + # fire reliably for pseudocode caret motion (esp. arrow-key navigation), + # so mirror it through to the same context-update path the disassembly + # view uses, so "users on current function" updates promptly. + if not (self.interface.force_click_recording or self.interface.artifact_watchers_started): + return 0 + widget = vu.ct if hasattr(vu, "ct") else None + if widget is None: + return 0 + ctx = compat.view_to_bs_context(widget, action=Context.ACT_VIEW_OPEN) + if ctx is None: + return 0 + ctx = self.interface.art_lifter.lift(ctx) + ctx.last_change = datetime.datetime.now(tz=datetime.timezone.utc) + self.interface._gui_active_context = ctx + self.interface.gui_context_changed(ctx) + return 0 + + # + # helpers + # + + def _send_decompilation_event(self, cfunc): + if cfunc is None: + return + + lifted_addr = self.interface.art_lifter.lift_addr(cfunc.entry_ea) + function = self.interface.fast_get_function(lifted_addr) + dec = Decompilation( + addr=cfunc.entry_ea, + text=str(cfunc), + decompiler="ida", + bs_func=function + ) + self.interface.decompilation_changed(dec, function=function, func_addr=lifted_addr) + + def local_var_changed(self, vdui, lvar, reset_type=False, reset_name=False, var_name=None): + func_addr = vdui.cfunc.entry_ea + is_func_arg = lvar.is_arg_var + bs_vars = compat.lvars_to_bs( + [lvar], vdui=vdui, var_names=[var_name], + recover_offset=True if is_func_arg else False + ) + if not bs_vars: + return + bs_var = next(iter(bs_vars)) + + if reset_type: + bs_var.type = None + if reset_name: + bs_var.name = None + + # proxy the change through the func header + if is_func_arg: + self.interface.function_header_changed( + FunctionHeader(None, func_addr, args={bs_var.offset: bs_var}), + fargs={bs_var.offset: bs_var}, + ) + else: + self.interface.stack_variable_changed(bs_var) + + +# +# IDA GUI-only hooks +# + +if IDA_IS_INTERACTIVE: + from declib.ui.version import set_ui_version + set_ui_version(get_ida_gui_version()) + from declib.ui.qt_objects import QKeyEvent, QEvent, Qt + + class IDAHotkeyHook(ida_kernwin.UI_Hooks): + def __init__(self, keys_to_pass, uiptr): + super().__init__() + self.keys_to_pass = keys_to_pass + self.ui = uiptr + + def preprocess_action(self, action_name): + uie = ida_kernwin.input_event_t() + ida_kernwin.get_user_input_event(uie) + key_event = uie.get_source_QEvent() + keycode = key_event.key() + if keycode[0] in self.keys_to_pass: + ke = QKeyEvent(QEvent.KeyPress, keycode[0], Qt.NoModifier) + # send new event + self.ui.event(ke) + # consume the event so ida doesn't take it + return 1 + return 0 + + + class ContextMenuHooks(idaapi.UI_Hooks): + def __init__(self, *args, menu_strs=None, **kwargs): + idaapi.UI_Hooks.__init__(self) + self.menu_strs = menu_strs or [] + + def finish_populating_widget_popup(self, form, popup): + # Add actions to the context menu of the Pseudocode view + if idaapi.get_widget_type(form) == idaapi.BWN_PSEUDOCODE or idaapi.get_widget_type( + form) == idaapi.BWN_DISASM: + for menu_str, category in self.menu_strs: + idaapi.attach_action_to_popup(form, popup, menu_str, f"{category}/") + + + class ScreenHook(ida_kernwin.View_Hooks): + def __init__(self, interface: "IDAInterface"): + self.interface = interface + super(ScreenHook, self).__init__() + + def view_click(self, view, event): + self._handle_view_event(view, action_type=Context.ACT_MOUSE_CLICK) + + def view_activated(self, view: "TWidget *"): + self._handle_view_event(view, action_type=Context.ACT_VIEW_OPEN) + + def view_curpos(self, view: "TWidget *"): + # fires when the cursor (current position) moves within the view — + # includes keyboard navigation (G/jump, arrow keys, double-clicking + # in the Functions list, etc.), which view_click/view_activated miss. + self._handle_view_event(view, action_type=Context.ACT_VIEW_OPEN) + + def view_mouse_moved(self, view: "TWidget *", event: "view_mouse_event_t"): + if self.interface.track_mouse_moves: + self._handle_view_event(view, ida_event=event, action_type=Context.ACT_MOUSE_MOVE) + + def _handle_view_event(self, view, action_type=Context.ACT_UNKNOWN, ida_event=None): + if self.interface.force_click_recording or self.interface.artifact_watchers_started: + # drop ctx for speed when the artifact watches have not been officially started, and we are not clicking + if (self.interface.force_click_recording and not self.interface.artifact_watchers_started) and \ + action_type == Context.ACT_MOUSE_MOVE: + return + + ctx = compat.view_to_bs_context(view, action=action_type) + if ctx is None: + return + + # handle special case of mouse move + if action_type == Context.ACT_MOUSE_MOVE and ida_event is not None: + ctx.line_number = ida_event.renderer_pos.cy + ctx.col_number = ida_event.renderer_pos.cx + if ctx.screen_name == "disassembly" and ida_event.renderer_pos.node != -1: + # TODO: this is not an addr, but the node number in graph view + ctx.extras['node'] = ida_event.renderer_pos.node + elif ctx.screen_name == "decompilation": + # TODO: the address is useless here! + ctx.addr = ctx.func_addr + + ctx = self.interface.art_lifter.lift(ctx) + ctx.last_change = datetime.datetime.now(tz=datetime.timezone.utc) + self.interface._gui_active_context = ctx + + self.interface.gui_context_changed(ctx) + + +else: + IDAHotkeyHook = None + ContextMenuHooks = None + ScreenHook = None diff --git a/declib/decompilers/ida/ida_ui.py b/declib/decompilers/ida/ida_ui.py new file mode 100644 index 00000000..235cd6f3 --- /dev/null +++ b/declib/decompilers/ida/ida_ui.py @@ -0,0 +1,80 @@ +import logging + +import idaapi + +from .compat import get_ida_gui_version + +from declib.ui.version import set_ui_version +set_ui_version(get_ida_gui_version()) +from declib.ui.qt_objects import QWidget, QVBoxLayout, wrapInstance + +_l = logging.getLogger(__name__) + + +def ask_choice(question, choices, title="Choose an option"): + class MyForm(idaapi.Form): + def __init__(self, options): + self.dropdown = idaapi.Form.DropdownListControl(items=options) + form_string = ("STARTITEM 0\n" + f"{title}\n" + f"{question}\n" + "") + idaapi.Form.__init__(self, form_string, {'dropdown': self.dropdown}) + + # Instantiate and display the form + form = MyForm(choices) + form.Compile() + ok = form.Execute() + if ok == 1: + selected_index = form.dropdown.value + selected_item = choices[selected_index] + else: + selected_item = "" + form.Free() + return selected_item + + +class IDAWidgetWrapper(object): + def __init__(self, qt_cls, window_name: str, *args, **kwargs): + self.twidget = idaapi.create_empty_widget(window_name) + self.widget = wrapInstance(int(self.twidget), QWidget) + self.name = window_name + self.widget.name = window_name + self.width_hint = 250 + + self._widget = qt_cls(*args, **kwargs) + layout = QVBoxLayout() + layout.addWidget(self._widget) + layout.setContentsMargins(2, 2, 2, 2) + self.widget.setLayout(layout) + + +def attach_qt_widget(qt_cls, window_name: str, target_window=None, position=None, *args, **kwargs): + wrapper = IDAWidgetWrapper(qt_cls, window_name, *args, **kwargs) + if not wrapper.twidget: + _l.error("Failed to create widget %s", window_name) + return False + + flags = idaapi.PluginForm.WOPN_TAB | idaapi.PluginForm.WOPN_RESTORE | idaapi.PluginForm.WOPN_PERSIST + idaapi.display_widget(wrapper.twidget, flags) + wrapper.widget.visible = True + + if position is None: + # make a new tab in the target window + position = idaapi.DP_RIGHT + + if target_window == "Functions": + dock_dst = "Functions" + position = idaapi.DP_INSIDE + else: + # attempt to 'dock' the widget in a reasonable location + for target in ["IDA View-A", "Pseudocode-A"]: + dwidget = idaapi.find_widget(target) + if dwidget: + dock_dst = target + break + else: + raise RuntimeError("Could not find a suitable dock position for the widget") + + idaapi.set_dock_pos(wrapper.name, dock_dst, position) + return True diff --git a/declib/decompilers/ida/interface.py b/declib/decompilers/ida/interface.py new file mode 100755 index 00000000..9787e517 --- /dev/null +++ b/declib/decompilers/ida/interface.py @@ -0,0 +1,659 @@ +import logging +from typing import Dict, Optional, List +from collections import OrderedDict, defaultdict +from packaging.version import Version +import declib +from declib.api.decompiler_interface import DecompilerInterface +from declib.artifacts import ( + StackVariable, Function, FunctionHeader, Struct, Comment, GlobalVariable, Enum, Patch, Artifact, Decompilation, + Context, Typedef, Segment +) +from declib.api.decompiler_interface import requires_decompilation +from . import compat +from .artifact_lifter import IDAArtifactLifter +from .compat import get_ida_gui_version +from .hooks import ContextMenuHooks, ScreenHook, IDBHooks, IDPHooks, HexraysHooks + +if compat.IDA_IS_INTERACTIVE: + from . import ida_ui +else: + try: + # IDA 9+ + import idapro + except ImportError: + # IDA 9 Beta + import ida as idapro + +import idc +import idaapi +import ida_hexrays +import ida_auto + + +_l = logging.getLogger(name=__name__) + + +def _qt_shortcut_to_ida(shortcut: str) -> str: + """Convert a Qt-style shortcut like "Ctrl+Shift+D" to IDA's "Ctrl-Shift-D".""" + if not shortcut: + return "" + return shortcut.replace("+", "-") + + +# +# Controller +# + +class IDAInterface(DecompilerInterface): + # idalib (IDA's headless mode) enforces main-thread-only API access and + # raises ``RuntimeError: Function can be called from the main thread only`` + # when called from a worker thread. The DecompilerServer checks this flag + # and routes backend calls through its main-thread dispatcher. + requires_main_thread_dispatch = True + + def __init__(self, project_dir=None, **kwargs): + self._ctx_menu_names = [] + self._ui_hooks = [] + self._artifact_watcher_hooks = [] + self._gui_active_context = None + self._deleted_artifacts = defaultdict(set) + self.cached_ord_to_type_names = {} + # Optional cache directory where the .id* database files should live. + self._project_dir = project_dir + + super().__init__( + name="ida", qt_version=get_ida_gui_version(), artifact_lifter=IDAArtifactLifter(self), + decompiler_available=compat.initialize_decompiler(), **kwargs + ) + + self._max_patch_size = 0xff + self._decompiler_available = None + self._dec_version = None + self._ida_analysis_finished = False + + # GUI properties + self._updated_ctx = None + + def _init_headless_components(self, *args, **kwargs): + """ + This function initializes the headless functionality of IDA through idalib. + This also means that this feature is only supported in IDA versions >= 9.0 + """ + super()._init_headless_components(*args, **kwargs) + binary_path = str(self.binary_path) + extra_args = self._ida_open_args() + # IDA <= 9.1 only accepts (path, run_auto_analysis); the extra_args + # parameter was added in 9.2. + if compat.get_ida_version() <= 910: + if extra_args: + _l.warning( + "project_dir/extra open args are only supported on IDA >= 9.2; ignoring %r.", + extra_args, + ) + failure = idapro.open_database(binary_path, True) + else: + failure = idapro.open_database(binary_path, True, extra_args) + if failure: + raise RuntimeError(f"Failed to open database {binary_path}") + + def _ida_open_args(self) -> Optional[str]: + """Build the extra args string passed to ``idapro.open_database``. + + When ``project_dir`` is configured we redirect IDA's database sidecar + files (``.id0/.id1/.id2/.nam/.til``) into that directory using IDA's + own ``-o`` command-line flag. The sidecars go into a nested + ``ida/`` subdirectory so they don't collide with anything else the + user / other backends leave in the top-level project_dir (Ghidra's + ``_ghidra/`` project, stale symlinks, etc.). + """ + from pathlib import Path as _Path + + if not self._project_dir: + return None + + project_dir = _Path(self._project_dir).expanduser().resolve() + ida_dir = project_dir / "ida" + ida_dir.mkdir(parents=True, exist_ok=True) + binary_name = _Path(str(self.binary_path)).name + # IDA's -o takes the database base path (no extension); it picks + # .idb / .i64 / .id* itself. + db_base = ida_dir / binary_name + return f"-o{db_base}" + + def _deinit_headless_components(self): + """ + This function deinitializes the headless functionality of IDA through idalib. + This also means that this feature is only supported in IDA versions >= 9.0 + """ + idapro.close_database(False) + + def _init_gui_hooks(self): + """ + This function can only be called from inside the compat.GenericIDAPlugin and is meant for IDA code which + should be run as a plugin. + """ + self._ui_hooks = [ + ScreenHook(self), + ContextMenuHooks(self, menu_strs=self._ctx_menu_names), + IDPHooks(self), + ] + for hook in self._ui_hooks: + hook.hook() + + def _term_gui_hooks(self): + """ + Symmetric teardown for _init_gui_hooks. Must run before IDAPython tears + down — otherwise a still-registered hook can fire during shutdown + events (e.g. term_database) and try to re-enter a finalized Python. + """ + for hook in self._ui_hooks: + try: + hook.unhook() + except Exception: + _l.exception("Failed to unhook %r", hook) + self._ui_hooks = [] + + def _init_gui_plugin(self, *args, **kwargs): + self.decompiler_opened_event() + plugin_cls_name = self._plugin_name + "_cls" + IDAPluginCls = compat.generate_generic_ida_plugic_cls(cls_name=plugin_cls_name) + return IDAPluginCls(*args, name=self._plugin_name, interface=self, **kwargs) + + @property + def dec_version(self): + if self._dec_version is None: + self._dec_version = compat.get_decompiler_version() + + return self._dec_version + + # + # GUI + # + + def gui_ask_for_string(self, question, title="Plugin Question", default="") -> str: + resp = idaapi.ask_str(default, 0, question) + return resp if resp else "" + + def gui_ask_for_choice(self, question: str, choices: list, title="Plugin Question") -> str: + return ida_ui.ask_choice(question, choices, title=title) + + def gui_register_ctx_menu(self, name, action_string, callback_func, category=None, shortcut=None) -> bool: + ida_shortcut = _qt_shortcut_to_ida(shortcut) if shortcut else "" + action = idaapi.action_desc_t( + name, + action_string, + compat.GenericAction(name, callback_func, deci=self), + ida_shortcut, + action_string, + 199 + ) + idaapi.register_action(action) + idaapi.attach_action_to_menu( + f"Edit/{category}/{name}" if category else f"Edit/{name}", + name, + idaapi.SETMENU_APP + ) + self._ctx_menu_names.append((name, category or "")) + return True + + def gui_attach_qt_window(self, qt_window: type["QWidgt"], title: str, target_window=None, position=None, *args, **kwargs) -> bool: + return ida_ui.attach_qt_widget(qt_window, title, target_window=None, position=None, *args, **kwargs) + + # + # Mandatory API + # + + @property + def binary_base_addr(self) -> int: + return compat.get_first_segment_base() + + @property + def binary_hash(self) -> str: + return idc.retrieve_input_file_md5().hex() + + @property + def binary_path(self) -> Optional[str]: + return self._binary_path or compat.get_binary_path() + + def get_func_size(self, func_addr) -> int: + func_addr = self.art_lifter.lower_addr(func_addr) + return compat.get_func_size(func_addr) + + @property + def decompiler_available(self) -> bool: + if self._decompiler_available is None: + self._decompiler_available = ida_hexrays.init_hexrays_plugin() + + return self._decompiler_available + + def xrefs_to(self, artifact: Artifact, decompile=False, only_code=False) -> List[Artifact]: + if not isinstance(artifact, Function): + _l.warning("xrefs_to is only implemented for functions.") + return [] + + function: Function = self.art_lifter.lower(artifact) + return self._collect_xrefs_to(function.addr, only_code=only_code) + + def xrefs_to_addr(self, addr: int, only_code: bool = False) -> List[Artifact]: + lowered = self.art_lifter.lower_addr(addr) + return self._collect_xrefs_to(lowered, only_code=only_code) + + def xrefs_from(self, func_addr: int) -> List[Function]: + """Direct callees of ``func_addr`` — just the call targets, no data.""" + lowered = self.art_lifter.lower_addr(func_addr) + func = compat.fast_get_function(lowered, get_rtype=False) + if func is None: + return [] + callees: List[Function] = [] + seen = set() + # Walk every instruction in the function body; cheap because fauxware- + # sized binaries are typical, and this is the same approach the + # ``idautils.CodeRefsFrom`` helpers use under the hood. + import ida_funcs as _ida_funcs # local to keep interface.py clean + ida_func = _ida_funcs.get_func(lowered) + if ida_func is None: + return [] + ea = ida_func.start_ea + while ea < ida_func.end_ea and ea != idaapi.BADADDR: + for callee_ea in compat.xrefs_from(ea): + callee_func_addr = compat.ida_func_addr(callee_ea) or callee_ea + if callee_func_addr in seen: + continue + seen.add(callee_func_addr) + lifted = self.art_lifter.lift_addr(callee_func_addr) + fast_func = self.fast_get_function(lifted) or Function(lifted, 0) + callees.append(fast_func) + ea = idc.next_head(ea, ida_func.end_ea) + return callees + + def list_strings(self, filter: Optional[str] = None) -> List[tuple]: + import re as _re + pattern = _re.compile(filter) if filter else None + out = [] + for ea, text in compat.list_strings(): + if pattern is not None and not pattern.search(text): + continue + out.append((self.art_lifter.lift_addr(ea), text)) + out.sort(key=lambda item: item[0]) + return out + + def disassemble(self, addr: int, **kwargs) -> Optional[str]: + lowered = self.art_lifter.lower_addr(addr) + return compat.disassemble_function(lowered) + + def read_memory(self, addr: int, size: int) -> Optional[bytes]: + if size <= 0: + return b"" + lowered = self.art_lifter.lower_addr(addr) + return compat.read_memory(lowered, size) + + def _collect_xrefs_to(self, lowered_addr: int, only_code: bool, + _max_chase: int = 2) -> List[Artifact]: + """Collect function-level xrefs to ``lowered_addr``. + + PIE binaries route string / global references through indirection + tables (GOT / _RDATA pointer arrays), so a direct + ``idautils.XrefsTo(str_addr)`` only lands on the pointer — not on + the code that dereferences it. We BFS up to ``_max_chase`` levels + of data indirection so ``xrefs_to SOSNEAKY`` can still name the + caller. + """ + visited_targets: set = set() + seen_funcs: set = set() + xrefs: List[Artifact] = [] + + frontier = [(lowered_addr, 0)] + while frontier: + target, depth = frontier.pop(0) + if target in visited_targets: + continue + visited_targets.add(target) + + for ida_xref in compat.xrefs_to(target): + if only_code and not ida_xref.iscode: + continue + from_ea = int(ida_xref.frm) + from_func_addr = compat.ida_func_addr(from_ea) + if from_func_addr is not None: + if from_func_addr in seen_funcs: + continue + seen_funcs.add(from_func_addr) + lifted = self.art_lifter.lift_addr(from_func_addr) + fast_func = self.fast_get_function(lifted) or Function(lifted, 0) + xrefs.append(fast_func) + elif depth < _max_chase and not only_code: + # data-to-data indirection: chase one hop further. + frontier.append((from_ea, depth + 1)) + return xrefs + + def get_decompilation_object(self, function: Function, do_lower=True, **kwargs) -> Optional[object]: + function = self.art_lifter.lower(function) if do_lower else function + dec = ida_hexrays.decompile(function.addr) + if dec is None: + return None + + return dec + + def _decompile(self, function: Function, map_lines=False, **kwargs) -> Optional[Decompilation]: + try: + cfunc = ida_hexrays.decompile(function.addr) + if cfunc is None: + return None + except Exception: + return None + + decompilation = Decompilation(addr=function.addr, text=str(cfunc), decompiler=self.name) + if map_lines: + linenum_to_addr = defaultdict(set) + # always add the start as line 1 + linenum_to_addr[1].add(cfunc.entry_ea) + + # find all lines 2 - N + for addr, lines in cfunc.get_eamap().items(): + for line in lines: + y_holder = idaapi.int_pointer() + if not cfunc.find_item_coords(line, None, y_holder): + continue + + linenum = y_holder.value() + linenum_to_addr[linenum].add(addr) + + decompilation.line_map = {k: v for k, v in linenum_to_addr.items()} + + return decompilation + + def fast_get_function(self, func_addr) -> Optional[Function]: + lowered_addr = self.art_lifter.lower_addr(func_addr) + lowered_func = compat.fast_get_function(lowered_addr, get_rtype=False) + if lowered_func is None: + #_l.error(f"Function does not exist at {lowered_addr}") + return None + + return self.art_lifter.lift(lowered_func) + + # + # GUI API + # + + def start_artifact_watchers(self): + super().start_artifact_watchers() + # TODO: this is a hack for backwards compatibility and should be removed in IDA 9 + idb_hook = IDBHooks(self) + if self.decompiler_available and self.dec_version < Version("8.4"): + idb_hook.local_types_changed = lambda: 0 + else: + # this code in this block must exist in 9.0, so don't delete it! + self.cached_ord_to_type_names = compat.get_ord_to_type_names() + + self._artifact_watcher_hooks = [ + idb_hook, + # this hook is special because it relies on the decompiler being present, which can only be checked + # after the plugin loading phase. this means the user will need to manually init this hook in the UI + # either through scripting or a UI. + HexraysHooks(self), + ] + for hook in self._artifact_watcher_hooks: + hook.hook() + + def stop_artifact_watchers(self): + super().stop_artifact_watchers() + for hook in self._artifact_watcher_hooks: + hook.unhook() + + def gui_active_context(self) -> Context: + if self._gui_active_context is None: + # in cases that we end up here, we are likely in UI mode, without artifact watchers started, + # so we should not cache this result (as it will be stale) + low_addr, low_func_addr = compat.get_function_cursor_at() + return self.art_lifter.lift(Context(addr=low_addr, func_addr=low_func_addr)) + + return self._gui_active_context + + def gui_goto(self, func_addr) -> None: + func_addr = self.art_lifter.lower_addr(func_addr) + compat.jumpto(func_addr) + + def gui_show_type(self, type_name: str) -> None: + compat.jumpto_type(type_name) + + def should_watch_artifacts(self) -> bool: + # never do hooks while IDA is in initial startup phase + if not self._ida_analysis_finished: + self._ida_analysis_finished = ida_auto.auto_is_ok() + + return self._ida_analysis_finished and self.artifact_watchers_started + + # + # Optional API + # + + @requires_decompilation + def local_variable_names(self, func: Function) -> List[str]: + dec = func.dec_obj + if dec is None: + return [] + + return [lvar.name for lvar in dec.get_lvars() if lvar.name] + + @requires_decompilation + def rename_local_variables_by_names(self, func: Function, name_map: Dict[str, str], **kwargs) -> bool: + func = self.art_lifter.lower(func) + return compat.rename_local_variables_by_names(func, name_map) + + # + # Artifact API + # + + # functions + def _set_function(self, func: Function, **kwargs) -> bool: + """ + Overrides the normal _set_function for speed optimizations + """ + return compat.set_function(func, headless=self.headless, decompiler_available=self.decompiler_available, **kwargs) + + def _get_function(self, addr, **kwargs) -> Optional[Function]: + return compat.function(addr, headless=self.headless, decompiler_available=self.decompiler_available, **kwargs) + + def _functions(self) -> Dict[int, Function]: + return compat.functions() + + # stack vars + def _set_stack_variable(self, svar: StackVariable, **kwargs) -> bool: + _l.warning("Setting stack vars using this API is deprecared. Use _set_function instead.") + return False + + # global variables + def _set_global_variable(self, gvar: GlobalVariable, **kwargs) -> bool: + modified = False + if gvar.name: + modified |= compat.set_global_var_name(gvar.addr, gvar.name) + if gvar.type: + modified |= compat.set_global_var_type(gvar.addr, gvar.type) + + return modified + + def _get_global_var(self, addr) -> Optional[GlobalVariable]: + return compat.global_var(addr) + + def _global_vars(self, **kwargs) -> Dict[int, GlobalVariable]: + """ + Returns a dict of declib.GlobalVariable that contain the addr and size of each global var. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return compat.global_vars() + + # structs + def _set_struct(self, struct: Struct, header=True, members=True, **kwargs) -> bool: + data_changed = False + if self.decompiler_available and self.dec_version < Version("8.3") and "gcc_va_list" in struct.name: + _l.critical("Syncing the struct %s in IDA Pro 8.2 <= will cause a crash. Skipping...", struct.name) + return False + + if header: + data_changed |= compat.set_ida_struct(struct) + + if members: + data_changed |= compat.set_ida_struct_member_types(struct) + + return data_changed + + def _get_struct(self, name) -> Optional[Struct]: + return compat.struct(name) + + def _del_struct(self, name) -> bool: + return compat.del_ida_struct(name) + + def _structs(self) -> Dict[str, Struct]: + """ + Returns a dict of declib.Structs that contain the name and size of each struct in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return compat.structs() + + # enums + def _set_enum(self, enum: Enum, **kwargs) -> bool: + return compat.set_enum(enum) + + def _get_enum(self, name) -> Optional[Enum]: + return compat.enum(name) + + def _enums(self) -> Dict[str, Enum]: + """ + Returns a dict of declib.Enum that contain the name of the enums in the decompiler. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return compat.enums() + + # typedefs + def _set_typedef(self, typedef: Typedef, **kwargs) -> bool: + return compat.set_typedef(typedef) + + def _get_typedef(self, name) -> Optional[Typedef]: + return compat.typedef(name) + + def _typedefs(self) -> Dict[str, Typedef]: + return compat.typedefs() + + # patches + def _set_patch(self, patch: Patch, **kwargs) -> bool: + idaapi.patch_bytes(patch.addr, patch.bytes) + return True + + def _get_patch(self, addr) -> Optional[Patch]: + patches = self._collect_continuous_patches(min_addr=addr-1, max_addr=addr+self._max_patch_size, stop_after_first=True) + return patches.get(addr, None) + + def _patches(self) -> Dict[int, Patch]: + """ + Returns a dict of declib.Patch that contain the addr of each Patch and the bytes. + Note: this does not contain the live artifacts of the Artifact, only the minimum knowledge to that the Artifact + exists. To get live artifacts, use the singleton function of the same name. + + @return: + """ + return self._collect_continuous_patches() + + # comments + def _set_comment(self, comment: Comment, **kwargs) -> bool: + return compat.set_ida_comment(comment.addr, comment.comment, decompiled=comment.decompiled) + + def _get_comment(self, addr) -> Optional[Comment]: + ida_cmt = compat.get_ida_comment(addr) + if ida_cmt is None: + return None + + # TODO: need to be better implemented! + return Comment(addr=addr, comment=str(ida_cmt), decompiled=True) + + def _comments(self) -> Dict[int, Comment]: + # TODO: implement me! + return {} + + # segments + def _set_segment(self, segment: Segment, **kwargs) -> bool: + return compat.set_segment(segment) + + def _get_segment(self, name) -> Optional[Segment]: + return compat.segment(name) + + def _del_segment(self, name) -> bool: + return compat.del_segment(name) + + def _segments(self) -> Dict[str, Segment]: + return compat.segments() + + # others... + def _set_function_header(self, fheader: FunctionHeader, **kwargs) -> bool: + return compat.set_function_header(fheader) + + # + # utils + # + + @staticmethod + def _ea_to_func(addr): + if not addr or addr == idaapi.BADADDR: + return None + + func_addr = compat.ida_func_addr(addr) + if func_addr is None: + return None + + func = declib.artifacts.Function( + func_addr, 0, header=FunctionHeader(compat.get_func_name(func_addr), func_addr) + ) + return func + + @staticmethod + def _collect_continuous_patches(min_addr=None, max_addr=None, stop_after_first=False) -> Dict[int, Patch]: + patches = {} + + def _patch_collector(ea, fpos, org_val, patch_val): + patches[ea] = bytes([patch_val]) + + if min_addr is None: + min_addr = idaapi.inf_get_min_ea() + if max_addr is None: + max_addr = idaapi.inf_get_max_ea() + + if min_addr is None or max_addr is None: + return patches + + idaapi.visit_patched_bytes(min_addr, max_addr, _patch_collector) + + # now convert into continuous patches + continuous_patches = defaultdict(bytes) + patch_start = None + last_pos = None + for pos, patch in patches.items(): + should_break = False + if last_pos is None or pos != last_pos + 1: + patch_start = pos + + if last_pos is not None and stop_after_first: + should_break = True + + continuous_patches[patch_start] += patch + if should_break: + break + + last_pos = pos + + # convert the patches + continuous_patches = dict(continuous_patches) + normalized_patches = { + offset: Patch(offset, _bytes) + for offset, _bytes in continuous_patches.items() + } + + return normalized_patches + diff --git a/declib/logger.py b/declib/logger.py new file mode 100644 index 00000000..e640504f --- /dev/null +++ b/declib/logger.py @@ -0,0 +1,101 @@ +import logging +import logging.config +import tempfile +from datetime import datetime + +timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") +_, tempfilename = tempfile.mkstemp(prefix=timestamp + '.declib.', suffix='.log') +string_format = "%(levelname)s | %(asctime)s | %(name)-8s | %(message)s" + + +default_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "console": { + "format": string_format + }, + "logfile": { + "format": string_format + }, + }, + + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "console", + "stream": "ext://sys.stdout" + }, + + "local_file_handler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "logfile", + "filename": tempfilename, + "maxBytes": 1000000, + "backupCount": 20, + "encoding": "utf8", + "delay": True + } + }, + 'loggers': { + 'declib': { + 'handlers': ["console", "local_file_handler"], + 'level': 'INFO', + 'propagate': False + }, + } +} + + +class Loggers: + """ + Logger Manager. + """ + IN_SCOPE_LOGGERS = ('declib', ) + + def __init__(self): + self._loggers = {} + self.load_all_loggers() + self.profiling_enabled = False + + # disable filelock info logs + logging.getLogger("filelock").setLevel(logging.WARNING) + + self.config_dict = None + if default_config is not None: + self.config_dict = default_config + if self.config_dict is not None: + logging.config.dictConfig(self.config_dict) + self.handler = logging.StreamHandler() + self.handler.setFormatter(logging.Formatter('%(levelname)-7s | %(asctime)-23s | %(name)-8s | %(message)s')) + + def load_all_loggers(self): + for name, logger in logging.Logger.manager.loggerDict.items(): + if any(name.startswith(x + '.') or name == x for x in self.IN_SCOPE_LOGGERS): + self._loggers[name] = logger + + def __getattr__(self, k): + real_k = k.replace('_', '.') + if real_k in self._loggers: + return self._loggers[real_k] + else: + raise AttributeError(k) + + def __dir__(self): + return list(super(Loggers, self).__dir__()) + list(self._loggers.keys()) + + +def is_enabled_for(logger, level): + if level == 1: + from .. import loggers + return loggers.profiling_enabled + return originalIsEnabledFor(logger, level) + + +originalIsEnabledFor = logging.Logger.isEnabledFor + +# Override isEnabledFor() for Logger class +logging.Logger.isEnabledFor = is_enabled_for + diff --git a/declib/plugin_installer.py b/declib/plugin_installer.py new file mode 100644 index 00000000..def31ddd --- /dev/null +++ b/declib/plugin_installer.py @@ -0,0 +1,259 @@ +import os +import platform +from pathlib import Path +import textwrap +import sys +import shutil +from typing import Optional, Union, Tuple + +from prompt_toolkit import prompt +from prompt_toolkit.completion.filesystem import PathCompleter + + +class Color: + """ + Used to colorify terminal output. + Taken from: https://github.com/hugsy/gef/blob/dev/tests/utils.py + """ + NORMAL = "\x1b[0m" + GRAY = "\x1b[1;38;5;240m" + LIGHT_GRAY = "\x1b[0;37m" + RED = "\x1b[31m" + GREEN = "\x1b[32m" + YELLOW = "\x1b[33m" + BLUE = "\x1b[34m" + PINK = "\x1b[35m" + CYAN = "\x1b[36m" + BOLD = "\x1b[1m" + UNDERLINE = "\x1b[4m" + UNDERLINE_OFF = "\x1b[24m" + HIGHLIGHT = "\x1b[3m" + HIGHLIGHT_OFF = "\x1b[23m" + BLINK = "\x1b[5m" + BLINK_OFF = "\x1b[25m" + + +class PluginInstaller: + DECOMPILERS = ( + "ida", + "binja", + "ghidra", + "angr" + ) + + DEBUGGERS = ( + "gdb", + ) + + def __init__(self, targets=None, target_install_paths=None): + self.targets = targets if targets is not None else self.DECOMPILERS+self.DEBUGGERS + self._home = Path(os.getenv("HOME") or "~/").expanduser().absolute() + self.target_install_paths = target_install_paths or {} #or self._populate_installs_from_config() + self._successful_installs = {} + + @staticmethod + def find_pkg_files(pkg_name): + if sys.version_info >= (3, 9): + import importlib.resources + path = str(importlib.resources.files(pkg_name)) + else: + import pkg_resources + path = pkg_resources.resource_filename(pkg_name, "") + + return Path(path).absolute() + + def install(self, interactive=True, paths_by_target=None): + self.target_install_paths.update(paths_by_target or {}) + self.display_prologue() + + if interactive: + self.display_install_instructions() + + try: + self.install_targets(interactive=interactive) + except Exception as e: + print(f"Stopping Install... because: {e}") + except KeyboardInterrupt: + print("Goodbye...") + + self.display_epilogue() + + def display_prologue(self): + pass + + def display_install_instructions(self): + print(textwrap.dedent(""" + Each decompiler/debugger will be prompted for install below. If you would like to skip install for something + you can enter 'n' or just hit enter. Each path prompt has tab path completion. + Enter nothing in each path prompt to get the default listed. + """)) + + def display_epilogue(self): + self.good("Plugin install completed! If anything was skipped by mistake, please manually install it.") + + @staticmethod + def info(msg): + print(f"{Color.BLUE}{msg}{Color.NORMAL}") + + @staticmethod + def good(msg): + print(f"{Color.GREEN}[+] {msg}{Color.NORMAL}") + + @staticmethod + def warn(msg): + print(f"{Color.YELLOW}[!] {msg}{Color.NORMAL}") + + @staticmethod + def ask_path(target, location, default=None) -> Optional[Union[bool, Path]]: + """ + Possible return values: + - None: install failed or skipped + - Path: install succeeded + """ + + PluginInstaller.info(f"Install for {target}? [y/n]") + res = prompt("") + if res.lower() != "y": + return None + + PluginInstaller.info(location + f" [default = {default}] (enter nothing to use default): ") + filepath = prompt("", completer=PathCompleter(expanduser=True)) + if not filepath and default: + return default + + filepath = Path(filepath).expanduser().absolute() + if not filepath.exists(): + PluginInstaller.warn(f"Provided filepath {filepath} does not exist. {'Using default.' if default else 'Skipping.'}") + return default if default else None + + return filepath + + @staticmethod + def link_or_copy(src, dst, is_dir=False, symlink=False): + if platform.platform().startswith("Windows"): + # you can't symlink on windows, so just copy + symlink = False + + # clean the install location + shutil.rmtree(dst, ignore_errors=True) + try: + os.unlink(dst) + except: + pass + + if not symlink: + # copy if symlinking is not available on target system + if is_dir: + shutil.copytree(src, dst) + else: + shutil.copy(src, dst) + else: + # first attempt a symlink, if it works, exit early + try: + os.symlink(src, dst, target_is_directory=is_dir) + return + except: + pass + + @staticmethod + def _get_path_without_ask(path, default_path=None, interactive=True) -> Tuple[Path, bool]: + path = Path(path) if path else None + if not interactive and path.exists(): + return path, True + + if path and path.exists(): + default_path = path + else: + default_path = Path(default_path) if default_path else None + if not default_path or not default_path.exists(): + default_path = None + + return default_path, (not interactive and default_path and default_path.exists()) + + def install_targets(self, interactive=True): + for target in self.targets: + try: + target_installer = getattr(self, f"install_{target}") + except AttributeError: + continue + + path = self.target_install_paths.get(f"{target}", None) + if path: + path = Path(path).expanduser().absolute() + + if not path and not interactive: + continue + + res = target_installer(path=path, interactive=interactive) + if res is None: + self.warn(f"Skipping or failed install for {target}... {Color.NORMAL}\n") + else: + self.good(f"Installed {target} to {res}\n") + self._successful_installs[target] = res + #GlobalConfig.update_or_make(self._home, **{f"{target}_path": res.parent}) + + def install_ida(self, path=None, interactive=True): + default_path, skip_ask = self._get_path_without_ask( + path, default_path=self._home.joinpath(".idapro").joinpath("plugins").expanduser(), interactive=interactive + ) + return self.ask_path("IDA Pro", "Plugins Path", default=default_path) if not skip_ask \ + else default_path + + def install_ghidra(self, path=None, interactive=True): + potential_path = self._home.joinpath('ghidra_scripts').expanduser() + if self._home.exists() and not potential_path.exists(): + self.info(f"Creating Ghidra Scripts directory at {potential_path}...") + potential_path.mkdir() + + default_path, skip_ask = self._get_path_without_ask( + path, default_path=potential_path, interactive=interactive + ) + return self.ask_path("Ghidra", "Ghidra Scripts Path", default=default_path) if not skip_ask \ + else default_path + + def install_binja(self, path=None, interactive=True): + os_name = platform.system() + if os_name == "Windows": + default_path = Path(os.environ.get("APPDATA", str(self._home))) / "Binary Ninja" / "plugins" + elif os_name == "Darwin": + default_path = (self._home / "Library" / "Application Support" / "Binary Ninja" / "plugins").expanduser() + else: + default_path = (self._home / ".binaryninja" / "plugins").expanduser() + default_path, skip_ask = self._get_path_without_ask( + path, default_path=default_path, + interactive=interactive + ) + return self.ask_path("Binary Ninja", "Plugins Path", default=default_path) if not skip_ask \ + else default_path + + def install_angr(self, path=None, interactive=True): + # attempt to find the plugins folder for angr-management which is installed via pip + angr_resolved = True + try: + import angrmanagement + except ImportError: + angr_resolved = False + default_path = Path(angrmanagement.__file__).parent / "plugins" if angr_resolved else None + + default_path, skip_ask = self._get_path_without_ask(path, default_path=default_path, interactive=interactive) + return self.ask_path("angr-management", "angr-management Plugins Path", default=default_path) if not skip_ask \ + else default_path + + def install_gdb(self, path=None, interactive=True): + default_path, skip_ask = self._get_path_without_ask( + path, default_path=self._home.joinpath(".gdbinit").expanduser(), + interactive=interactive + ) + return self.ask_path("GDB", "gdbinit Path", default=default_path) if not skip_ask \ + else default_path + + +class DecLibPluginInstaller(PluginInstaller): + def __init__(self, targets=None, target_install_paths=None): + targets = targets or PluginInstaller.DECOMPILERS + super().__init__(targets=targets, target_install_paths=target_install_paths) + self._declib_plugins_path = self.find_pkg_files("declib").joinpath("decompiler_stubs") + + def display_prologue(self): + print(textwrap.dedent(""" + Now installing DecLib plugins for all supported decompilers...""")) diff --git a/declib/skills/__init__.py b/declib/skills/__init__.py new file mode 100644 index 00000000..0722acae --- /dev/null +++ b/declib/skills/__init__.py @@ -0,0 +1,24 @@ +"""Bundled Agent Skills for declib. + +Each subdirectory holds a SKILL.md (and any optional resources) that an LLM can +load to learn how to drive declib via the `decompiler` CLI. Use +`decompiler install-skill` to copy a skill into Claude Code or Codex. +""" +from pathlib import Path + +SKILLS_DIR = Path(__file__).parent + + +def available_skills() -> list[str]: + return sorted( + p.name + for p in SKILLS_DIR.iterdir() + if p.is_dir() and (p / "SKILL.md").is_file() + ) + + +def skill_path(name: str) -> Path: + path = SKILLS_DIR / name + if not (path / "SKILL.md").is_file(): + raise FileNotFoundError(f"Unknown bundled skill: {name!r}") + return path diff --git a/declib/skills/decompiler/SKILL.md b/declib/skills/decompiler/SKILL.md new file mode 100644 index 00000000..25455f47 --- /dev/null +++ b/declib/skills/decompiler/SKILL.md @@ -0,0 +1,316 @@ +--- +name: decompiler +description: Reverse-engineer and modify binaries with a single `decompiler` CLI that drives IDA Pro, Ghidra, Binary Ninja, or angr via DecLib. Use whenever the user asks to decompile, disassemble, look up cross references, rename functions or variables, define or change types, sync work between decompilers, search strings or functions, or otherwise inspect a binary file. Also use for multi-binary workflows (load several binaries at once and switch between them with --id). +--- + +# `decompiler` — DecLib CLI for LLMs + +The `decompiler` command is a thin client that talks to a long-running +`DecompilerServer` (IDA / Ghidra / Binary Ninja / angr). The first `load` of a +binary spawns a server in the background; every subsequent call reuses that +server, so repeated `decompile`/`disassemble`/`xref_*` calls are fast. + +## Setup (once per environment) + +```bash +pip install declib # installs the `decompiler` and `declib` entry points +``` + +That's it — the `decompiler` CLI drives every backend headlessly via DecLib +and does **not** need any plugins installed inside IDA/Ghidra/Binary Ninja +to run. `angr` needs no host tool at all (it's a pure Python dependency) +and is the fastest way to verify the pipeline end-to-end. + +## Mental model + +| Concept | Description | +|---|---| +| **Server** | A headless `declib --server` process holding a single binary open. Identified by a short ID. | +| **Client** | Every `decompiler ` call is a short-lived client that picks a server, does one thing, and exits. | +| **Registry** | `decompiler list` / the shared registry under the declib state dir. Each record has `id`, `backend`, `binary_path`, `socket_path`, `pid`. Use `decompiler list --show-registry` to print just the path. | +| **Address form** | Servers expose **lifted** addresses (relative to the binary base). The CLI accepts either lifted (`0x71d`) or absolute (`0x40071d`) and does the conversion. JSON output always includes both `addr` (int) and `addr_hex` (hex string). | + +## First moves on a new binary + +**Always prefer IDA Pro when it's available** (`--backend ida`) — it +generally produces the cleanest decompilation and the most accurate type +recovery. If IDA fails to load the binary (missing license, unsupported +file type, decompiler error), fall back to `--backend ghidra`, then +`--backend angr` as a last resort. + +**Always start with `list_functions` and `list_strings`** — the same binary +can have the entry named `main` (angr), `FUN_00101c5c` (Ghidra), or +`sub_101c5c` (IDA). Don't assume `main` exists. + +```bash +decompiler load ./target --backend ida # prefer IDA; fall back to ghidra if it fails +decompiler list_functions # enumerate every function — pick a real entry +decompiler list_functions --filter 'main|auth' # or narrow by regex +decompiler list_strings --filter 'flag|pass' # find interesting string constants +``` + +Typical first-hour workflow on a stripped binary: + +1. `decompiler load ./bin --backend ida` (fall back to `--backend ghidra`, + then `--backend angr`, if IDA can't open the binary) +2. `decompiler list_functions` → note non-stub function names + sizes +3. `decompiler list_strings` → look for error messages, user prompts, + format strings — they often point at the interesting code +4. `decompiler xref_to "Welcome"` → jump from a string to its users +5. `decompiler decompile ` on whichever function came out of steps 3–4 + +## Core workflow + +```bash +decompiler load ./fauxware --backend ida # start a server (prefer IDA) +decompiler list_functions # enumerate functions (do this first) +decompiler list_strings --filter 'pass|key' # strings the decompiler identified +decompiler xref_to SOSNEAKY # who references this string? +decompiler decompile authenticate # by name (from list_functions) +decompiler disassemble 0x40071d # by absolute address +decompiler xref_to authenticate # every code+data reference +decompiler get_callers authenticate # call-sites only (subset of xref_to) +decompiler xref_from main # what does main call? +decompiler rename func sub_400662 trampoline # rename a function +decompiler rename var v2 auth_result --function main # rename a local +decompiler create-type "struct Point { int x; int y; }" # define a new type +decompiler retype main buf "Point *" # set a variable's type +decompiler stop --all +``` + +## Running multiple binaries concurrently + +Each binary gets its own server ID: + +```bash +decompiler load ./my-binary # id=abc1234 +decompiler load ./my-binary-2 # id=def5678 +decompiler list +# ID BACKEND PID BINARY +# abc1234... angr 4213 .../my-binary +# def5678... angr 4217 .../my-binary-2 +decompiler decompile main --id abc1234 +decompiler decompile main --binary ./my-binary-2 # or target by path +``` + +When more than one server matches, the CLI refuses and prints a +disambiguation list. Narrow with `--id`, `--binary`, or `--backend`. If you +want to restart the server for a binary cleanly, use `load ... --replace` +which stops the old server and starts a new one (vs `--force` which adds a +second server alongside the existing one). + +## Choosing a backend + +**Default: IDA Pro.** Use `--backend ida` whenever IDA is installed and +licensed — its decompilation is the most reliable across architectures. +Only switch backends if IDA fails to load the binary (the `load` call +errors, or analysis stalls); fall through in this order: `ida → ghidra +→ angr`. Use `binja` only when explicitly requested. + +```bash +decompiler load ./my-binary --backend ida # PREFERRED: IDA Pro (needs install + license) +decompiler load ./my-binary --backend ghidra # FALLBACK: needs GHIDRA_INSTALL_DIR +decompiler load ./my-binary --backend angr # LAST RESORT: pure-Python, always available +decompiler load ./my-binary --backend binja # Binary Ninja, needs license +``` + +If the IDA `load` fails (e.g. unsupported file format, decompiler error), +re-issue `load` with `--backend ghidra` — `load` is idempotent per +backend, so this leaves any other server alone and just brings up a +Ghidra one alongside. + +`--backend` is also accepted on the inspection/mutation subcommands to +narrow which server to target when multiple backends are loaded for the +same binary. + +## Full subcommand reference + +| Subcommand | Purpose | Key flags | +|---|---|---| +| `load ` | Start a server on the binary. Idempotent: returns existing unless `--force`/`--replace`. | `--backend`, `--id`, `--force`, `--replace`, `--project-dir`, `--json` | +| `list` | Show all running servers and the registry path. | `--show-registry`, `--json` | +| `stop` | Shut down one or all servers. | `--id`, `--binary`, `--all`, `--json` | +| `list_functions` | Enumerate every function (ADDR, SIZE, NAME). | `--filter REGEX`, `--json` | +| `decompile ` | Pseudocode for a function (name or address). | `--raw`, `--id`, `--binary`, `--backend`, `--json` | +| `disassemble ` | Assembly for a function. | `--raw`, same | +| `xref_to ` | Every reference (code + data) to the target. | `--decompile`, same | +| `xref_from ` | Functions that `target` calls. | same | +| `rename func ` | Rename a function. | same + `--json` | +| `rename var --function ` | Rename a local variable inside a function. | same | +| `create-type ""` | Define a new `struct`/`enum`/`typedef` from a C string and add it to the type database. | same + `--json` | +| `retype ` | Set the type of a function's local variable or argument. | same | +| `sync --from-id ` | Copy a function's work (names, return/arg types, stack-var names+types, referenced user types) from one running server into another for the same binary. | dest: `--id`/`--binary`/`--backend`; `--json` | +| `list_strings` | Strings the decompiler found (may be incomplete — see below). | `--filter`, `--min-length N`, same | +| `get_callers ` | Call-sites only — subset of `xref_to`. | same | +| `read_memory ` | Read raw bytes from the binary at ``. Default output is a hexdump. | `--format {hexdump,hex,raw}`, same + `--json` (base64-encoded bytes) | +| `install-skill` | Install this file for Claude Code or Codex. | `--agent`, `--dest`, `--force`, `--json` | + +### `xref_to` vs `get_callers` + +- `xref_to` asks the backend for **every reference** — code *and* data. On + Ghidra with `--decompile` this includes global variables and string + references. Rows include a `kind` field (`Function`, `GlobalVariable`, + ...). `xref_to` also accepts **strings and raw addresses**: if the + target isn't a function, it's looked up in `list_strings` first, then + queried as a raw-address xref — so you can go straight from + `list_strings --filter "admin"` to `xref_to admin` to find who reads + that constant. +- `get_callers` is the narrower call-sites-only view: only functions that + contain a `call` to the target. When you want "who calls this?" reach + for `get_callers`; when you want "who touches this in any way?" reach + for `xref_to`. + +### `read_memory` — raw bytes at an address + +`read_memory ` reads `` bytes from the loaded binary's +mapped memory starting at ``. It goes through the backend's own +memory accessor, so it returns whatever the decompiler currently has +loaded for that address (post-relocation, post-mapping) — not the raw +bytes from the on-disk ELF/PE/Mach-O. Use it when you need to: + +- Inspect a constant table, jump table, or vtable that the decompiler + rendered as `dword_` / `unk_`. +- Read a string the backend's string detector missed (cross-check + against `list_strings` first; if absent, dump bytes manually). +- Verify the actual bytes behind a global the decompiler shows as an + opaque symbol. +- Pull a magic header / signature out of `.rodata` to confirm a file + format or library version. + +```bash +decompiler read_memory 0x4008e0 64 # default: hexdump +decompiler read_memory 0x4008e0 64 --format hex # one-line hex blob +decompiler read_memory 0x4008e0 64 --format raw > bytes # raw bytes to a file +decompiler read_memory 0x4008e0 64 --json # base64-encoded payload +``` + +JSON output includes both `size` (actual bytes returned) and +`requested_size` — backends may produce **short reads** when the request +straddles the end of a mapped segment. In text mode the CLI prints a +`# short read: ...` notice on stderr in that case. If the address is +unmapped or uninitialized, the CLI exits non-zero with a message saying +the backend couldn't satisfy the read; try a smaller `size` or confirm +the address with `list_functions` / `xref_to`. + +Address formats follow the same rules as everywhere else: hex (`0x4008e0`), +decimal (`4197088`), or lifted (`0x8e0`) all work. + +### Editing types and syncing across decompilers + +`create-type` parses a C type *definition* and adds it to the binary's type +database. `retype` then points a variable at it (or at any built-in type). +Both work on every backend; refer to the struct by name, with `*`/`[]` for +pointers and arrays: + +```bash +decompiler create-type "struct Point { int x; int y; }" +decompiler create-type "enum Color { RED, GREEN=5, BLUE }" +decompiler retype main buf "Point *" # stack var or argument, by name +``` + +`sync` copies one function's work from a **source** server into a +**destination** server for the *same* binary — handy when you reverse a +function in one tool and want it mirrored in another. It transfers the +function name, return/argument types, stack-variable names and types, and +any user-defined types those reference. The source is chosen with +`--from-id`; the destination with the usual `--id`/`--binary`/`--backend`: + +```bash +decompiler load ./fauxware --backend ida # id=ida123 (do your work here) +decompiler load ./fauxware --backend ghidra # id=ghi456 +decompiler rename func 0x71d auth_check --id ida123 +decompiler retype 0x71d buf "Point *" --id ida123 +decompiler sync 0x71d --from-id ida123 --id ghi456 # push it into Ghidra +``` + +Addresses and stack-variable offsets are normalized, so the function and +its variables re-key correctly even when the two backends name them +differently. Pass a function **address** (most robust) or a name. + +### `list_strings` may be incomplete + +`list_strings` returns exactly what the backend's own string detector +surfaced — the CLI does not second-guess the decompiler. Fidelity varies +(`angr < ghidra < ida`); angr in particular misses most of `.rodata`. If +the output looks thin, check the binary file directly with an external +scanner: + +```bash +strings -a -n 4 ./target # classic strings(1) +rabin2 -z ./target # radare2: ASCII data-section scan +readelf -p .rodata ./target # ELF-specific, per section +``` + +Use those to confirm a specific constant exists, then come back and +`decompile` / `xref_to` its address inside the CLI. `--min-length` +defaults to 4. + +## Machine-readable output + +Pass `--json` on any subcommand to get a structured payload suitable for +downstream parsing — ideal when an LLM wants to chain commands. Every +JSON payload that mentions an address provides both `addr` (int, lifted) +and `addr_hex` (hex string, also lifted): + +```bash +decompiler list_functions --filter '^main$' --json +# [{"addr": 1821, "size": 184, "name": "main", "addr_hex": "0x71d"}] + +decompiler list_strings --filter 'flag' --json +# [{"addr": 4197168, "string": "flag{...}", "addr_hex": "0x4008e0"}] + +decompiler decompile main --json +# {"addr": 1821, "decompiler": "angr", "text": "void main(...){...}", "addr_hex": "0x71d"} + +# Terminal-friendly form of decompile: skip JSON wrapping entirely. +decompiler decompile main --raw +``` + +## Gotchas and tips + +- **First `load` is slow** (backend analysis pass). Subsequent calls on the + same server are fast. +- **`rename` exit codes**: every CLI command exits `0` on success and `1` + on any failure (including "rename didn't find the old name"). Use + `&&` safely. +- **Stripped binaries**: use `list_functions` before `decompile` to find + the real entry. `main` may not exist; look for non-default names + (`sub_XXXX`, `FUN_...`, `entry`, etc.) with plausible sizes and xrefs. +- **Backend main-naming varies**: angr promotes the entry to `main`, + Ghidra leaves `FUN_00101c5c`, IDA emits `sub_101c5c`. Always resolve via + `list_functions` or a known entry address, not by assuming `main`. +- **Invalid addresses** fail with a clear message distinguishing "no + function starts here" from "decompiler engine failed". The CLI does not + auto-round-trip invalid addresses. +- **Address formats**: `0x71d`, `0x40071d`, and `1821` all resolve the + same function in fauxware. Names are also accepted wherever an address + is. +- **Servers persist** until explicitly stopped (`decompiler stop --all`) + or the host reboots; `decompiler list` always reflects live processes. +- **Registry path**: `decompiler list --show-registry` prints just the + directory so you can clean up manually if you ever need to (e.g. after + a `kill -9`). +- **Project/database files**: by default they live in + `/declib/projects/-/`, not next to the binary. + Pass `--project-dir ` to `load` to override, or `--project-dir ""` + to restore the legacy "write next to the binary" behavior. + +## Library-level API (for Python scripts) + +Everything the CLI does is also available as a library: + +```python +from declib.api.decompiler_client import DecompilerClient + +client = DecompilerClient.discover_from_registry(binary_path="./fauxware") +for addr, func in client.functions.items(): + if func.name == "main": + print(client.decompile(addr).text) +``` + +The new core APIs (`list_strings(filter=...)`, `get_callers(target)`, +`disassemble(addr)`, `read_memory(addr, size)`) are on both the local +`DecompilerInterface` and the `DecompilerClient` proxy. `read_memory` +returns `bytes` (or `None` if the backend can't satisfy the read), so +you can hexdump, decode, or feed the result straight into struct +parsers without going through the CLI. diff --git a/declib/ui/__init__.py b/declib/ui/__init__.py new file mode 100644 index 00000000..7a1a93e3 --- /dev/null +++ b/declib/ui/__init__.py @@ -0,0 +1,33 @@ +import math + +import tqdm + + +def progress_bar(items, gui=True, desc="Progressing..."): + """ + This displays either a text or GUI progress bar using the DecLib GUI backend. + This assumes that the GUI is already initialized and running if in GUI mode. + """ + if not gui: + for item in tqdm.tqdm(items, desc=desc): + yield item + else: + from declib.ui.utils import QProgressBarDialog + pbar = QProgressBarDialog(label_text=desc) + pbar.show() + callback_stub = pbar.update_progress + bucket_size = len(items) / 100.0 + if bucket_size < 1: + callback_amt = int(1 / bucket_size) + bucket_size = 1 + else: + callback_amt = 1 + bucket_size = math.ceil(bucket_size) + + for i, item in enumerate(items): + yield item + if i % bucket_size == 0: + callback_stub(callback_amt) + + # close the progress bar since it may not hit 100% + pbar.close() diff --git a/declib/ui/qt_objects.py b/declib/ui/qt_objects.py new file mode 100644 index 00000000..5222e9ea --- /dev/null +++ b/declib/ui/qt_objects.py @@ -0,0 +1,146 @@ +from declib.ui.version import ui_version + +if ui_version == "PySide6": + from PySide6.QtCore import ( + QDir, Qt, Signal, QAbstractTableModel, QModelIndex, QSortFilterProxyModel, QPersistentModelIndex, + QEvent, QThread, Slot, QObject, QPropertyAnimation, QAbstractAnimation, QParallelAnimationGroup, + QLineF, QTimer, QRect, QDateTime, + ) + from PySide6.QtWidgets import ( + QAbstractItemView, + QCheckBox, + QComboBox, + QDialog, + QFileDialog, + QFormLayout, + QGridLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QLineEdit, + QMenu, + QMessageBox, + QPushButton, + QStatusBar, + QTableWidget, + QTableWidgetItem, + QTabWidget, + QTextBrowser, + QVBoxLayout, + QWidget, + QDialogButtonBox, + QTableView, + QFontDialog, + QCheckBox, + QMainWindow, + QApplication, + QFrame, + QWidget, + QSizePolicy, + QScrollArea, + QToolButton, + QProgressBar, + QGraphicsScene, + QGraphicsView, + QGraphicsEllipseItem, + QGraphicsTextItem, + QGraphicsLineItem, + QGraphicsItem, + QToolTip, + QStackedLayout, + QDateTimeEdit, + QSplitter, + ) + from PySide6.QtGui import ( + QFontDatabase, + QColor, + QKeyEvent, + QFocusEvent, + QIntValidator, + QAction, + QImage, + QFontMetrics, + QFont, + QPainter, + QBrush, + QPen, + QCursor, + ) + from shiboken6 import wrapInstance +else: + from PyQt5.QtCore import ( + QDir, Qt, QAbstractTableModel, QModelIndex, QSortFilterProxyModel, QPersistentModelIndex, + QEvent, QThread, QObject, QPropertyAnimation, QAbstractAnimation, QParallelAnimationGroup, + QLineF, QTimer, QRect, QDateTime, + ) + from PyQt5.QtCore import pyqtSignal as Signal + from PyQt5.QtCore import pyqtSlot as Slot + from PyQt5.QtWidgets import ( + QAbstractItemView, + QCheckBox, + QComboBox, + QDialog, + QFileDialog, + QFormLayout, + QGridLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QLineEdit, + QMenu, + QMessageBox, + QPushButton, + QStatusBar, + QTableWidget, + QTableWidgetItem, + QTabWidget, + QTextBrowser, + QVBoxLayout, + QWidget, + QDialogButtonBox, + QTableView, + QAction, + QFontDialog, + QCheckBox, + QMainWindow, + QApplication, + QFrame, + QWidget, + QSizePolicy, + QScrollArea, + QToolButton, + QProgressBar, + QGraphicsScene, + QGraphicsView, + QGraphicsEllipseItem, + QGraphicsTextItem, + QGraphicsLineItem, + QGraphicsItem, + QToolTip, + QStackedLayout, + QDateTimeEdit, + QSplitter, + ) + from PyQt5.QtGui import ( + QFontDatabase, + QColor, + QKeyEvent, + QFocusEvent, + QIntValidator, + QImage, + QFontMetrics, + QFont, + QPainter, + QBrush, + QPen, + QCursor, + ) + try: + # new location for sip + # https://www.riverbankcomputing.com/static/Docs/PyQt5/incompatibilities.html#pyqt-v5-11 + from PyQt5.sip import wrapinstance as wrapInstance + except ImportError: + from sip import wrapinstance as wrapInstance + diff --git a/declib/ui/utils.py b/declib/ui/utils.py new file mode 100644 index 00000000..c80be8de --- /dev/null +++ b/declib/ui/utils.py @@ -0,0 +1,115 @@ +import sys + +from .qt_objects import ( + QDialog, QVBoxLayout, QProgressBar, QLabel, QPushButton, Qt, QThread, QApplication, Signal, QLineEdit, + QComboBox, QFontMetrics, QMessageBox +) + + +def gui_popup_text(text, title="Plugin Info") -> bool: + message_box = QMessageBox() + message_box.setIcon(QMessageBox.Information) + message_box.setWindowTitle(title) + message_box.setText(text) + message_box.setStandardButtons(QMessageBox.Ok) + + if message_box.exec() == QMessageBox.Ok: + return True + else: + return False + + +def gui_ask_for_string(question, title="Plugin Question", default="") -> str: + dialog = QDialog() + dialog.setWindowTitle(title) + + layout = QVBoxLayout(dialog) + + # Question label + label = QLabel(question) + layout.addWidget(label) + + # Text input field + text_input = QLineEdit() + if default: + text_input.setText(default) + layout.addWidget(text_input) + + # Submit button + submit_button = QPushButton("Submit") + layout.addWidget(submit_button) + submit_button.clicked.connect(dialog.accept) + + dialog.setLayout(layout) + + # Show the dialog and wait for user to submit + if dialog.exec(): + return text_input.text() + else: + return "" + + +def gui_ask_for_choice(question: str, choices: list, title="Plugin Question") -> str: + dialog = QDialog() + dialog.setWindowTitle(title) + + layout = QVBoxLayout() + label = QLabel(question) + layout.addWidget(label) + + combo_box = QComboBox() + combo_box.addItems(choices) + layout.addWidget(combo_box) + + button = QPushButton('Confirm') + button.clicked.connect(dialog.accept) + layout.addWidget(button) + + dialog.setLayout(layout) + dialog.exec() + + return combo_box.currentText() + + +class QProgressBarDialog(QDialog): + def __init__(self, label_text="Loading...", on_cancel_callback=None, parent=None): + super().__init__(parent) + self.on_cancel_callback = on_cancel_callback + + self.setWindowTitle("DecLib Loading...") + self.setWindowModality(Qt.ApplicationModal) + self.layout = QVBoxLayout() + + # Add the label + self.layout.addWidget(QLabel(label_text)) + + # Add the progress bar + self.progressBar = QProgressBar(self) + self.progressBar.setValue(0) + self.layout.addWidget(self.progressBar) + + # Add cancel button on the bottom + self.button = QPushButton("Cancel", self) + self.button.clicked.connect(self.on_cancel_clicked) + self.layout.addWidget(self.button) + + self.setLayout(self.layout) + + # Initialize progress value + self.progress = 0 + + def on_cancel_clicked(self): + if self.on_cancel_callback is not None: + self.on_cancel_callback() + + self.close() + + def on_finished(self): + self.close() + + def update_progress(self, value): + self.progress += value + if self.progress >= 100: + self.on_finished() + + self.progressBar.setValue(self.progress) diff --git a/declib/ui/version.py b/declib/ui/version.py new file mode 100644 index 00000000..b4bb5f92 --- /dev/null +++ b/declib/ui/version.py @@ -0,0 +1,14 @@ +ui_version = "PySide6" + + +def set_ui_version(version): + global ui_version + valid_version = [ + "PyQt5", + "PySide6" + ] + + if version in valid_version: + ui_version = version + else: + raise Exception("Failed to set BinSync UI version") diff --git a/docs/decompiler_cli.md b/docs/decompiler_cli.md new file mode 100644 index 00000000..631178ad --- /dev/null +++ b/docs/decompiler_cli.md @@ -0,0 +1,566 @@ +# `decompiler` CLI + +The `decompiler` command is a thin, LLM-friendly client over DecLib. You load a +binary once (which spawns a headless decompiler server in the background) and +then run quick inspection or mutation commands against it. Multiple binaries +and backends can be loaded at the same time; each server is identified by a +short ID. + +This document is for humans; the short reference version used by LLM agents +lives at [`declib/skills/decompiler/SKILL.md`](../declib/skills/decompiler/SKILL.md) +and can be installed with `decompiler install-skill`. + +--- + +## Table of contents + +- [Install & setup](#install--setup) +- [Quick start](#quick-start) +- [How it works](#how-it-works) +- [Subcommand reference](#subcommand-reference) + - [`load`](#load) + - [`list`](#list) + - [`stop`](#stop) + - [`list_functions`](#list_functions) + - [`decompile`](#decompile) + - [`disassemble`](#disassemble) + - [`xref_to`](#xref_to) + - [`xref_from`](#xref_from) + - [`rename`](#rename) + - [`list_strings`](#list_strings) + - [`get_callers`](#get_callers) + - [`install-skill`](#install-skill) +- [Server selection (`--id`, `--binary`, `--backend`)](#server-selection) +- [JSON output (`--json`, `--raw`)](#json-output) +- [Exit codes](#exit-codes) +- [Running multiple binaries at once](#running-multiple-binaries-at-once) +- [Address formats](#address-formats) +- [Library-level API](#library-level-api) +- [Troubleshooting](#troubleshooting) + +--- + +## Install & setup + +```bash +pip install declib +# Register DecLib plugins into every detected decompiler. +declib --install +# Or point the installer at one specific decompiler: +declib --single-decompiler-install binja "/Applications/Binary Ninja.app" +``` + +After `pip install declib`, two entry points are available: + +- `declib` — the existing management CLI (install plugins, run the server, + etc.) +- `decompiler` — the new LLM-facing CLI documented here. + +Pick a backend you have available: + +- **angr** — pure Python, always available. Good for end-to-end testing and + small/medium binaries. +- **ghidra** — requires `GHIDRA_INSTALL_DIR` and uses PyGhidra. +- **binja** — requires a Binary Ninja license. +- **ida** — requires IDA Pro. + +--- + +## Quick start + +```bash +# 1. Load a binary. The first call spawns a detached headless server. +decompiler load ./fauxware --backend angr +# id: 3308b81cf8 … + +# 2. Poke around. +decompiler list_functions # enumerate every function first +decompiler decompile main # by name +decompiler disassemble 0x40071d # by absolute address +decompiler xref_to authenticate # every code+data reference +decompiler get_callers authenticate # call-sites only (subset of xref_to) +decompiler xref_from main # what main calls +decompiler list_strings --filter 'pass|key' # regex-filtered strings + +# 3. Mutate the database. +decompiler rename func sub_400662 trampoline +decompiler rename var v2 auth_result --function main + +# 4. Tear it down when you're done. +decompiler stop --all +``` + +--- + +## How it works + +``` +┌─────────────┐ spawns ┌─────────────────────────┐ +│ decompiler │ ────────────────▶ │ declib --server (headless│ +│ CLI │ (first load) │ decompiler + AF_UNIX │ +│ │ │ socket) │ +│ │ ◀─────────────────│ │ +└─────────────┘ every command └─────────────────────────┘ + │ + ▼ +~/.local/state/declib/servers/.json ← the shared registry +``` + +Each running server writes a small JSON descriptor (`id`, `socket_path`, +`binary_path`, `binary_hash`, `backend`, `pid`, `started_at`) into a shared +registry directory. The CLI reads the registry to figure out which server to +talk to. Stale records (server exited, socket missing) are pruned on read. +Run `decompiler list --show-registry` to print just the path. + +Every subcommand except `load`, `list`, and `install-skill` accepts +`--id`, `--binary`, and `--backend` to pick which server to target when you +have more than one running. + +--- + +## Subcommand reference + +### `load` + +Load a binary, starting a headless server if one isn't already running for +it. + +```bash +decompiler load [--backend {angr,ghidra,binja,ida}] + [--id SERVER_ID] + [--force | --replace] + [--project-dir PATH] + [--json] +``` + +- **`--backend`** (default: `angr`) — which decompiler to use. +- **`--id`** — explicit server ID; otherwise one is auto-generated. +- **`--force`** — start an additional server even if one already matches + this `(binary, backend)`. Keeps the old server alive. +- **`--replace`** — stop any existing server for this `(binary, backend)` + first, then start a fresh one. Use this when you want to re-analyze from + scratch. +- **`--project-dir PATH`** — where to keep the backend's + project/database files (Ghidra project, IDA `.id*`/`.til`, etc.). + Default: a per-binary directory under the user cache + (`/declib/projects/-/`), so analysis + artifacts don't pollute the binary's directory. Pass `--project-dir ""` + to disable the cache dir and let the backend drop files alongside the + binary (legacy behavior). + +Outputs `id`, `socket_path`, `binary_path`, `backend`, `project_dir`, and +`status` (either `started` or `already_loaded`). + +### `list` + +Show all running decompiler servers. + +```bash +decompiler list [--show-registry] [--json] +``` + +Text output: + +``` +ID BACKEND PID BINARY +3308b81cf8 angr 57613 /…/fauxware +9d77ab8fd4 angr 57786 /…/posix_syscall + +(registry: /Users/me/Library/Application Support/declib/servers) +``` + +- **`--show-registry`** — print the registry directory and exit (useful for + scripting manual cleanup). +- **`--json`** emits `{"registry_dir": "...", "servers": [...]}`. + +### `stop` + +Stop one or all servers. + +```bash +decompiler stop [--id SERVER_ID] [--binary PATH] [--all] [--json] +``` + +You must pass one of `--id`, `--binary`, or `--all`. + +### `list_functions` + +Enumerate every function in the loaded binary. This is usually the first +thing you want on a new (possibly stripped) binary. + +```bash +decompiler list_functions [--filter REGEX] [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +Text output: + +``` +ADDR SIZE NAME +0x540 6 __libc_start_main +0x71d 184 main +0x664 184 authenticate +... +``` + +JSON output is a list of `{"addr": int, "size": int, "name": str, "addr_hex": str}`. + +### `decompile` + +Decompile a function to pseudocode. + +```bash +decompiler decompile [--raw] [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +`` is a function name or address (hex/decimal, lifted or absolute — +see [Address formats](#address-formats)). + +- **`--raw`** — print the decompilation text directly, skipping all + wrapping. Useful at a terminal when `--json`'s escaped `\n`s are + unreadable. + +Default text output is the decompilation. JSON output includes `addr`, +`addr_hex`, `decompiler`, and `text`. + +Error messages distinguish three failure modes: + +- **target not found** — function name/address doesn't resolve. +- **not a function start** — address resolves, but isn't a function + boundary. Exit 1. +- **decompiler engine failed** — address is a known function start, but + the backend gave up. Exit 1. + +### `disassemble` + +Disassemble a function to text assembly. + +```bash +decompiler disassemble [--raw] [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +Same error semantics and `--raw` flag as `decompile`. + +### `xref_to` + +**Every reference** to `target` — code AND data. + +```bash +decompiler xref_to [--decompile] [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +`` can be: + +- a **function name or address** — resolves to function xrefs (who calls + this function), +- a **raw address** that isn't a function start — resolves via the + backend's address-level reference table (useful for globals, jump + table entries, etc.), +- a **string literal** — looked up via `list_strings` first, then queried + as a raw-address xref. Great for "who reads this constant?". + +Each row has a `kind` field (`Function`, `GlobalVariable`, …) so you can +tell code refs from data refs. The JSON payload also carries +`target_kind` (`function`, `address`, or `string`) so callers can tell +which resolution path fired. + +- **`--decompile`** — ask the backend to decompile first. On Ghidra this + surfaces additional references (e.g. globals pulled in through the + HighFunction's global symbol map). + +When you want only call-sites, reach for `get_callers` instead. + +### `xref_from` + +Functions that `target` calls (its callees). + +```bash +decompiler xref_from [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +Implementation note: this prefers the backend's call-graph. If the +call-graph is unavailable it falls back to scanning the function's +disassembly for `call 0x…` instructions. + +### `rename` + +Rename a function or a local variable. + +```bash +# Rename a function. +decompiler rename func [--id ID] [--json] + +# Rename a local variable inside a function. +decompiler rename var --function [--id ID] [--json] +``` + +The CLI exits `1` if the rename didn't actually change anything (the +response's `success` field is authoritative). + +### `list_strings` + +List strings the decompiler's own string detector has identified in the +binary. + +```bash +decompiler list_strings [--filter REGEX] + [--min-length N] + [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +- **`--filter REGEX`** — only return strings matching the regex. +- **`--min-length N`** — drop strings shorter than N characters (default 4). + +Text output is `0x\t` per line. JSON output is a list of +`{"addr", "addr_hex", "string"}` entries. + +**Fidelity caveat.** This command only returns what the decompiler +itself surfaced — it does not second-guess the backend or supplement with +a file-level scan. Backend string detection quality varies +(`angr < ghidra < ida`); angr in particular misses most of `.rodata`. +If the output looks thin, cross-check with an external tool before +concluding a string isn't in the binary: + +```bash +strings -a -n 4 ./target # classic strings(1) +rabin2 -z ./target # radare2, structured output +readelf -p .rodata ./target # ELF-specific, per section +``` + +Once you've located a string that way you can feed its address back into +the CLI via `decompiler xref_to 0x...` or `decompiler decompile 0x...`. + +### `get_callers` + +Functions that contain a call to `target` — a strict subset of `xref_to`. + +```bash +decompiler get_callers [--id ID] [--binary PATH] [--backend BACKEND] [--json] +``` + +Unlike `xref_to`, this never returns globals or other data refs. Rows are +always of kind `Function`. + +### `install-skill` + +Copy the bundled Agent Skill into a supported agent skill directory so Claude +Code or Codex learns how to drive the CLI. + +```bash +decompiler install-skill [names ...] [--agent claude|codex|all] [--dest DIR] [--force] [--json] +``` + +With no `names`, every bundled skill is installed. By default the installer +uses Codex when `CODEX_*` environment variables are present, otherwise Claude. +Use `--agent codex`, `--agent claude`, repeated `--agent` flags, or +`--agent all` to choose explicitly. Claude installs under `~/.claude/skills`; +Codex installs under `$CODEX_HOME/skills` when `CODEX_HOME` is set, otherwise +`~/.codex/skills`. + +Use `--dest` to copy the skill somewhere else, and `--force` to overwrite an +existing directory. `--json` emits a well-formed JSON payload suitable for +piping through `jq`. + +--- + +## Server selection + +When more than one server is running, the inspection/mutation commands need +to know which one to talk to. Narrow with any combination of: + +- **`--id `** — exact match. +- **`--binary `** — match by binary path (resolved to an absolute + path). +- **`--backend `** — match by backend. + +If zero servers match, the CLI errors out and tells you to run +`decompiler load`. If multiple match, it prints a disambiguation list: + +``` +Multiple servers match. Specify --id to disambiguate: + 3308b81cf8 backend=angr binary=/…/fauxware + 9d77ab8fd4 backend=angr binary=/…/posix_syscall +``` + +--- + +## JSON output + +Pass `--json` on any subcommand to get a structured payload suitable for +downstream parsing. This is the recommended mode for scripts and LLM +callers. Every JSON payload that mentions an address provides both +`addr` (integer, lifted) and `addr_hex` (hex string, also lifted), so you +can copy either form without re-formatting: + +```bash +decompiler list_functions --filter '^main$' --json +# [{"addr": 1821, "size": 184, "name": "main", "addr_hex": "0x71d"}] + +decompiler xref_to authenticate --json +# {"addr": 1636, "direction": "to", +# "xrefs": [{"kind": "Function", "addr": 1821, "name": "main", "addr_hex": "0x71d"}, ...], +# "addr_hex": "0x664"} +``` + +For decompile/disassemble output, JSON wraps the text in a `text` field +with escaped newlines. At a terminal this is awkward; pass `--raw` +instead: + +```bash +decompiler decompile main --raw # prints the pseudocode directly +decompiler disassemble 0x71d --raw # prints assembly directly +``` + +--- + +## Exit codes + +Every CLI command uses these exit codes: + +| Code | Meaning | +|---|---| +| `0` | Success. | +| `1` | User-visible error — target not found, rename didn't apply, decompile failed, etc. All failure modes unify to `1` so that shell `&&` chaining works cleanly. | + +Argparse-level errors (unknown subcommand, missing required argument) exit +with Python's standard argparse code `2`. + +--- + +## Running multiple binaries at once + +```bash +decompiler load ./my-binary # id=abc1234 +decompiler load ./my-binary-2 # id=def5678 + +decompiler list +# ID BACKEND PID BINARY +# abc1234... angr 4213 .../my-binary +# def5678... angr 4217 .../my-binary-2 +# +# (registry: /…/declib/servers) + +# Target by ID … +decompiler decompile main --id abc1234 + +# … or by binary path. +decompiler decompile main --binary ./my-binary-2 + +# Restart a server cleanly (stop existing, spawn fresh): +decompiler load ./my-binary --replace + +# Run an additional server alongside the existing one: +decompiler load ./my-binary --force + +# Tear them all down. +decompiler stop --all +``` + +You can even mix backends on the same binary — add `--force` to `load` to +launch a second server for the same file: + +```bash +decompiler load ./bin --backend ghidra +decompiler load ./bin --backend angr --force +decompiler decompile main --binary ./bin --backend ghidra +decompiler decompile main --binary ./bin --backend angr +``` + +--- + +## Address formats + +DecLib normalizes addresses to a **lifted** form (relative to the binary's +base address), so artifacts stay stable across decompilers. The CLI, though, +accepts whatever is natural for the user: + +- `0x71d`, `1821` — lifted +- `0x40071d` — absolute (base + lifted) +- `main` — symbol name + +The CLI converts on the fly. The returned `addr` fields in JSON output are +**always lifted**, which matches what the server's artifact dictionaries +use. `addr_hex` is the same value as a hex string for convenience. + +--- + +## Library-level API + +Everything the CLI does is also available as a library — useful when you +want to chain operations or integrate DecLib into a larger tool: + +```python +from declib.api.decompiler_client import DecompilerClient + +# Pick a running server out of the shared registry. +client = DecompilerClient.discover_from_registry(binary_path="./fauxware") + +for addr, func in client.functions.items(): + if func.name == "main": + print(client.decompile(addr).text) + print(client.disassemble(addr)) + for caller in client.get_callers(addr): + print(caller.addr, caller.name) +``` + +The three APIs added to power the CLI are also usable directly through +`DecompilerInterface` (headless/embedded) and `DecompilerClient` (remote): + +- `list_strings(filter: str | None = None) -> list[tuple[int, str]]` +- `get_callers(target: Function | int | str) -> list[Function]` +- `disassemble(addr: int) -> str | None` + +Backends currently implementing them: angr and Ghidra. IDA and Binary Ninja +fall back to the default implementations. + +--- + +## Troubleshooting + +**`No running decompiler server matches …`** +You haven't loaded the binary yet. Run +`decompiler load --backend ` first, or use +`decompiler list` to see what's already running. + +**`Multiple servers match. Specify --id to disambiguate`** +Two servers match your filters. Either pass `--id` with one of the printed +IDs, or narrow with `--binary`/`--backend`. + +**`Timed out waiting … for server … to start.`** +The detached server process didn't come up in time (default 5 minutes). +Check backend prerequisites: +- Ghidra: `GHIDRA_INSTALL_DIR` must be set. +- IDA/Binary Ninja: their Python bindings must be importable. +- angr: should just work. + +**`No function starts at 0x…`** +The address is valid in the binary but doesn't correspond to the first +byte of any known function. Use `decompiler list_functions` to find a +valid start. (Prior to v2 this was reported with the same error as +"decompiler engine failed"; they're now distinct.) + +**Rename reports `success: False` (exit 1)** +The old name was not found in the function (e.g. it was already renamed, +or you targeted the wrong function). + +**`list_strings` looks thin** +This is expected on angr (and can happen on Ghidra for stripped binaries) — +`list_strings` returns only what the decompiler itself identified. Use an +external scanner to see every ASCII constant in the file, then feed the +address back into `xref_to` / `decompile`: + +```bash +strings -a -n 4 ./target +rabin2 -z ./target +readelf -p .rodata ./target +``` + +**Server-side logs** +Spawned servers have their stdout/stderr sent to `/dev/null`. If you're +debugging server startup, start one by hand in a foreground terminal: + +```bash +declib --server --headless --decompiler angr --binary-path ./bin --server-id my-srv +``` + +That will print log output to the terminal, and the CLI in another terminal +can still drive it via `decompiler decompile main --id my-srv`. diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..13a6399e --- /dev/null +++ b/examples/README.md @@ -0,0 +1,11 @@ +# DecLib Examples +This directory contains a series of example uses of DecLib in both plugins and as scripting library utilities. +When used as a plugin, DecLib requires a bit more setup to both init the UI components and start the artifact +watching backend. + +## Plugins +### change_watcher_plugins +This plugin shows off a few things: +1. Passing a generic function to be called on Artifact changes +2. Initing a context menu in any decompiler +3. Generally setting up a plugin as a package with its own installer \ No newline at end of file diff --git a/examples/change_watcher_plugin/README.md b/examples/change_watcher_plugin/README.md new file mode 100644 index 00000000..ec35ca4c --- /dev/null +++ b/examples/change_watcher_plugin/README.md @@ -0,0 +1,13 @@ +# Example BS Change Watcher Plugin +The example plugin to show of DecLib for watching artifact changes. + +## Install +``` +pip3 install -e . && python3 -m dl_change_watcher --install +``` + +## Usage +Open the decompiler: +1. If you are in Ghidra, use the menu to start the BS backend first +2. Right click on any function and select the `ArtifactChangeWatcher` and start the change watcher backend +3. Change any stack variable (as an example), you should see a printout that it was changed \ No newline at end of file diff --git a/examples/change_watcher_plugin/dl_change_watcher/__init__.py b/examples/change_watcher_plugin/dl_change_watcher/__init__.py new file mode 100644 index 00000000..91bca598 --- /dev/null +++ b/examples/change_watcher_plugin/dl_change_watcher/__init__.py @@ -0,0 +1,109 @@ +from pathlib import Path + +from declib.artifacts import Typedef +from declib.plugin_installer import DecLibPluginInstaller + +__version__ = "0.0.1" + +def create_plugin(*args, **kwargs): + """ + This is the entry point that all decompilers will call in various ways. To remain agnostic, + always pass the args and kwargs to the gui_init_args and gui_init_kwargs of DecompilerInterface, inited + through the discover api. + """ + + from declib.api import DecompilerInterface + from declib.artifacts import ( + FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment, Context + ) + + decompiler_opened_callbacks = [lambda *x, **y: print(f"[DLChangeWatcher] Started with plugin version {__version__}")] + decompiler_closed_callbacks = [lambda *x, **y: print(f"[DLChangeWatcher] Goodbye!")] + deci = DecompilerInterface.discover( + plugin_name="ArtifactChangeWatcher", + init_plugin=True, + decompiler_opened_callbacks=decompiler_opened_callbacks, + decompiler_closed_callbacks=decompiler_closed_callbacks, + # passing the flag below forces click recording to start on decompiler startup + # force_click_recording = True, + gui_init_args=args, + gui_init_kwargs=kwargs, + ) + # create a function to print a string in the decompiler console + decompiler_printer = lambda *x, **y: deci.print(f"Changed {x}") + # register the callback for all the types we want to print + deci.artifact_change_callbacks = { + typ: [decompiler_printer] for typ in ( + FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment, Typedef, Context + ) + } + + def _start_watchers(*x, **y): + deci.start_artifact_watchers() + deci.info("Artifact watchers started!") + + # register a menu to open when you right click on the psuedocode view + deci.gui_register_ctx_menu( + "StartArtifactChangeWatcher", + "Start watching artifact changes", + _start_watchers, + category="ArtifactChangeWatcher" + ) + + # return a plugin so the decompiler sets up the UI + return deci.gui_plugin + + +class DLChangeWatcherInstaller(DecLibPluginInstaller): + """ + This acts as a simple installer for the plugin + """ + + def __init__(self): + super().__init__() + self.pkg_path = self.find_pkg_files("dl_change_watcher") + + def _copy_plugin_to_path(self, path): + src = self.pkg_path / "dl_change_watcher_plugin.py" + dst = Path(path) / "dl_change_watcher_plugin.py" + self.link_or_copy(src, dst, symlink=True) + + def display_prologue(self): + print("Now installing DLChangeWatcher plugin...") + + def install_ida(self, path=None, interactive=True): + path = super().install_ida(path=path, interactive=interactive) + if not path: + return + + self._copy_plugin_to_path(path) + return path + + def install_ghidra(self, path=None, interactive=True): + path = super().install_ghidra(path=path, interactive=interactive) + if not path: + return + + self._copy_plugin_to_path(path) + return path + + def install_binja(self, path=None, interactive=True): + path = super().install_binja(path=path, interactive=interactive) + if not path: + return + + self._copy_plugin_to_path(path) + return path + + def install_angr(self, path=None, interactive=True): + path = super().install_angr(path=path, interactive=interactive) + if not path: + return + + path = path / "dl_change_watcher" + path.mkdir(parents=True, exist_ok=True) + src = self.pkg_path / "plugin.toml" + dst = Path(path) / "plugin.toml" + self.link_or_copy(src, dst, symlink=True) + self._copy_plugin_to_path(path) + return path diff --git a/examples/change_watcher_plugin/dl_change_watcher/__main__.py b/examples/change_watcher_plugin/dl_change_watcher/__main__.py new file mode 100644 index 00000000..420ca91e --- /dev/null +++ b/examples/change_watcher_plugin/dl_change_watcher/__main__.py @@ -0,0 +1,28 @@ +import argparse + +from . import DLChangeWatcherInstaller, create_plugin +import dl_change_watcher + + +def main(): + parser = argparse.ArgumentParser(description="An example CLI for the example change watcher plugin") + parser.add_argument( + "-i", "--install", action="store_true", help="Install plugin into your decompilers" + ) + parser.add_argument( + "-s", "--server", help="Run a a headless server for the watcher plugin", choices=["ghidra"] + ) + parser.add_argument("-v", "--version", action="version", version=dl_change_watcher.__version__) + args = parser.parse_args() + + if args.install: + DLChangeWatcherInstaller().install() + elif args.server: + if args.server != "ghidra": + raise NotImplementedError("Only Ghidra is supported for now") + + create_plugin(force_decompiler="ghidra") + + +if __name__ == "__main__": + main() diff --git a/examples/change_watcher_plugin/dl_change_watcher/dl_change_watcher_plugin.py b/examples/change_watcher_plugin/dl_change_watcher/dl_change_watcher_plugin.py new file mode 100644 index 00000000..61fb17d9 --- /dev/null +++ b/examples/change_watcher_plugin/dl_change_watcher/dl_change_watcher_plugin.py @@ -0,0 +1,61 @@ +# An example DecLib plugin that will print when every artifact is changed inside the decompiler +# @author BinSync +# @category BinSync +# @menupath Tools.ArtifactChangeWatcher.Start the BS backed for watcher + +# Note: this requires that your plugin, which is a package, exposes a function called `create_plugin` AND it +# exposes a command line interface that can be run (for Ghidra). +plugin_command = "dl_change_watcher -s ghidra" +def create_plugin(*args, **kwargs): + from dl_change_watcher import create_plugin as _create_plugin + return _create_plugin(*args, **kwargs) + + +# ============================================================================= +# DecLib generic plugin loader (don't touch) +# ============================================================================= + +import sys +# Python 2 has special requirements for Ghidra, which forces us to use a different entry point +# and scope for defining plugin entry points +if sys.version[0] == "2": + # Do Ghidra Py2 entry point + import subprocess + from declib_vendored.ghidra_bridge_server import GhidraBridgeServer + + GhidraBridgeServer.run_server(background=True) + process = subprocess.Popen(plugin_command.split(" ")) + if process.poll() is not None: + raise RuntimeError("Failed to run the Python3 backed. It's likely Python3 is not in your Path inside Ghidra.") +else: + # Try plugin discovery for other decompilers + try: + import idaapi + has_ida = True + except ImportError: + has_ida = False + try: + import angrmanagement + has_angr = True + except ImportError: + has_angr = False + + if not has_ida and not has_angr: + create_plugin() + elif has_angr: + from angrmanagement.plugins import BasePlugin + class AngrDLPluginThunk(BasePlugin): + def __init__(self, workspace): + super().__init__(workspace) + globals()["workspace"] = workspace + self.plugin = create_plugin() + + def teardown(self): + pass + + +def PLUGIN_ENTRY(*args, **kwargs): + """ + This is the entry point for IDA to load the plugin. + """ + return create_plugin(*args, **kwargs) diff --git a/examples/change_watcher_plugin/dl_change_watcher/plugin.toml b/examples/change_watcher_plugin/dl_change_watcher/plugin.toml new file mode 100644 index 00000000..6784a105 --- /dev/null +++ b/examples/change_watcher_plugin/dl_change_watcher/plugin.toml @@ -0,0 +1,13 @@ +[meta] +plugin_metadata_version = 0 + +[plugin] +name = "dl_change_watcher" +shortname = "dl_change_watcher" +version = "0.0.0" +description = "" +long_description = "" +platforms = ["windows", "linux", "macos"] +min_angr_version = "9.0.0.0" +author = "The BinSync Team" +entrypoints = ["dl_change_watcher_plugin.py"] \ No newline at end of file diff --git a/examples/change_watcher_plugin/setup.cfg b/examples/change_watcher_plugin/setup.cfg new file mode 100644 index 00000000..b4f9b8c4 --- /dev/null +++ b/examples/change_watcher_plugin/setup.cfg @@ -0,0 +1,24 @@ +[metadata] +name = dl_change_watcher +version = 0.0.0 +url = https://github.com/binsync/declib/tree/main/examples/change_watcher_plugin +classifiers = + License :: OSI Approved :: BSD License + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 +license = BSD 2 Clause +description = An example plugin using DecLib to watch and report changes to artifacts +long_description = file: README.md +long_description_content_type = text/markdown + +[options] +install_requires = + declib + +python_requires = >= 3.7 +include_package_data = True +packages = find: + +[options.entry_points] +console_scripts = + dl_change_watcher = dl_change_watcher.__main__:main diff --git a/examples/change_watcher_plugin/setup.py b/examples/change_watcher_plugin/setup.py new file mode 100644 index 00000000..8bf1ba93 --- /dev/null +++ b/examples/change_watcher_plugin/setup.py @@ -0,0 +1,2 @@ +from setuptools import setup +setup() diff --git a/examples/decompiler_client_example.py b/examples/decompiler_client_example.py new file mode 100644 index 00000000..2fc09ba8 --- /dev/null +++ b/examples/decompiler_client_example.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the RPyC-based DecompilerClient. + +This script shows how to use the new RPyC DecompilerClient which provides +identical API to DecompilerInterface but connects to a remote server. +""" + +import logging +import time +import sys +from typing import Optional + +# Set up logging +logging.basicConfig(level=logging.INFO) + +def example_with_local_decompiler(): + """Example using local DecompilerInterface""" + try: + from declib.api import DecompilerInterface + + print("=== Using Local DecompilerInterface ===") + deci = DecompilerInterface.discover() + if deci is None: + print("No local decompiler found") + return + + demo_decompiler_operations(deci) + + except Exception as e: + print(f"Local decompiler error: {e}") + + +def example_with_remote_decompiler(server_url: str = "rpyc://localhost:18861"): + """Example using remote DecompilerClient""" + try: + from declib.api.decompiler_client import DecompilerClient + + print(f"\n=== Using Remote DecompilerClient ({server_url}) ===") + with DecompilerClient.discover(server_url=server_url) as deci: + demo_decompiler_operations(deci) + + except Exception as e: + print(f"Remote decompiler error: {e}") + print("Make sure to start the server first with: declib --server") + + +def demo_decompiler_operations(deci): + """ + Demo function that works identically with both DecompilerInterface and DecompilerClient. + + This shows the power of the unified API - the same code works regardless of whether + the decompiler is local or remote. + """ + print(f"Decompiler: {deci.name}") + print(f"Binary path: {deci.binary_path}") + print(f"Binary hash: {deci.binary_hash}") + print(f"Base address: 0x{deci.binary_base_addr:x}" if deci.binary_base_addr else "None") + print(f"Decompiler available: {deci.decompiler_available}") + + # Test fast collection operations (this is where RPyC shines) + print(f"\n=== Testing Fast Collection Operations ===") + + # This should be fast - single bulk request for all light artifacts + start_time = time.time() + functions = list(deci.functions.items()) + end_time = time.time() + print(f"Retrieved {len(functions)} functions in {end_time - start_time:.3f}s") + + # Test other collections + collections = [ + ("comments", deci.comments), + ("patches", deci.patches), + ("global_vars", deci.global_vars), + ("structs", deci.structs), + ("enums", deci.enums), + ("typedefs", deci.typedefs) + ] + + for name, collection in collections: + start_time = time.time() + items = list(collection.keys()) + end_time = time.time() + print(f" {name}: {len(items)} items in {end_time - start_time:.3f}s") + + # Test function access (if any functions exist) + if len(deci.functions) > 0: + print(f"\n=== Testing Individual Access ===") + first_addr = functions[0][0] + + # Test full artifact access via __getitem__ (standard behavior) + start_time = time.time() + full_func = deci.functions[first_addr] # This gets the full artifact + end_time = time.time() + print(f"Full artifact access via []: {end_time - start_time:.3f}s") + print(f"Function: {full_func.name} at 0x{full_func.addr:x} (size: {full_func.size})") + + # Test light artifact access (fast, cached) + if hasattr(deci.functions, 'get_light'): + start_time = time.time() + light_func = deci.functions.get_light(first_addr) + end_time = time.time() + print(f"Light artifact access via get_light(): {end_time - start_time:.6f}s") + + # Show first few functions + print("\nFirst 5 functions:") + for addr, func in functions[:5]: + print(f" 0x{addr:x}: {func.name} (size: {func.size})") + + # Test method calls + try: + print(f"\n=== Testing Method Calls ===") + if len(deci.functions) > 0: + first_addr = list(deci.functions.keys())[0] + light_func = deci.fast_get_function(first_addr) + if light_func: + print(f" fast_get_function(0x{first_addr:x}): {light_func.name}") + + func_size = deci.get_func_size(first_addr) + print(f" get_func_size(0x{first_addr:x}): {func_size}") + + # Test decompilation if available + if deci.decompiler_available: + start_time = time.time() + decomp = deci.decompile(first_addr) + end_time = time.time() + if decomp: + lines = decomp.text.split('\n') + print(f" decompile(0x{first_addr:x}): {len(lines)} lines in {end_time - start_time:.3f}s") + print(f" First line: {lines[0][:80]}...") + else: + print(" No functions available for testing") + + except Exception as e: + print(f" Method call error: {e}") + + +def discover_decompiler(prefer_remote: bool = False, server_url: str = "rpyc://localhost:18861"): + """ + Smart discovery function that tries remote first if preferred, then falls back to local. + + This demonstrates how you can write code that seamlessly works with either + local or remote decompilers based on availability. + """ + if prefer_remote: + # Try remote first + try: + from declib.api.decompiler_client import DecompilerClient + return DecompilerClient.discover(server_url=server_url) + except Exception: + pass + + # Fall back to local + try: + from declib.api import DecompilerInterface + return DecompilerInterface.discover() + except Exception: + return None + else: + # Try local first + try: + from declib.api import DecompilerInterface + return DecompilerInterface.discover() + except Exception: + pass + + # Fall back to remote + try: + from declib.api.decompiler_client import DecompilerClient + return DecompilerClient.discover(server_url=server_url) + except Exception: + return None + + +def main(): + if len(sys.argv) > 1: + server_url = sys.argv[1] + else: + server_url = "rpyc://localhost:18861" + + print("DecLib DecompilerClient Example") + print("==============================") + + # Demo 1: Try local decompiler + example_with_local_decompiler() + + # Demo 2: Try remote decompiler + example_with_remote_decompiler(server_url) + + # Demo 3: Smart discovery + print(f"\n=== Smart Discovery (prefer remote) ===") + deci = discover_decompiler(prefer_remote=True, server_url=server_url) + if deci: + print(f"Discovered: {type(deci).__name__}") + demo_decompiler_operations(deci) + if hasattr(deci, 'shutdown'): + deci.shutdown() + else: + print("No decompiler available (local or remote)") + + print(f"\n=== Smart Discovery (prefer local) ===") + deci = discover_decompiler(prefer_remote=False, server_url=server_url) + if deci: + print(f"Discovered: {type(deci).__name__}") + demo_decompiler_operations(deci) + if hasattr(deci, 'shutdown'): + deci.shutdown() + else: + print("No decompiler available (local or remote)") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/retype_functions.py b/examples/retype_functions.py new file mode 100644 index 00000000..7ae2f0e7 --- /dev/null +++ b/examples/retype_functions.py @@ -0,0 +1,13 @@ +from declib.api import DecompilerInterface + +deci = DecompilerInterface.discover() +for addr, func in deci.functions.items(): + if func.size > 0x30: + # decompile the function + func = deci.functions[addr] + if func.header.type == "void": + deci.print(f"Updating {func}") + func.header.type = "int" + func.name = f"up_{addr}" + # reassign to affect the decompiler + deci.functions[addr] = func diff --git a/examples/struct_and_variable_use.py b/examples/struct_and_variable_use.py new file mode 100644 index 00000000..90c2decf --- /dev/null +++ b/examples/struct_and_variable_use.py @@ -0,0 +1,28 @@ +# This example works with the binary found in ../tests/binaries/fauxware +# To use this script, open that binary in a decompiler, than run this script. +from declib.api import DecompilerInterface +from declib.artifacts import Struct, StructMember + +deci = DecompilerInterface.discover() +# access a function and stack variable using their offsets, which get unified across decompilers +func = deci.functions[0x71D] +stack_var = func.stack_vars[-0x18] +print("Stack variable:", stack_var) + +# make a struct that is the same size as the stack variable (16) +members = { + 0: StructMember(name="field1", type_="int", size=4, offset=0), + 4: StructMember(name="field2", type_="int", size=4, offset=4), + 8: StructMember(name="field3", type_="int", size=4, offset=8), + 12: StructMember(name="field4", type_="int", size=4, offset=12), +} +struct = Struct(name="my_struct", size=16, members=members) +print("Struct:", struct) +deci.structs["my_struct"] = struct + +# modify the stack variable to use the struct +stack_var.type = "my_struct" +print("Updated stack variable:", stack_var) + +# reassign to affect the decompiler +deci.functions[0x71D] = func diff --git a/examples/template_plugin_entry.py b/examples/template_plugin_entry.py new file mode 100644 index 00000000..2d118c97 --- /dev/null +++ b/examples/template_plugin_entry.py @@ -0,0 +1,63 @@ +# REPLACE_ME: with the description of the plugin you want displayed in Ghidra, and update below items +# @author YourNameHere +# @category YourCategoryHere +# @menupath Tools.MyPlugin.Replace me with short desc shown in Tools>MyPlugin menu + +# REPLACE_ME: replace the command to run your plugin from Ghidra Python2 side +plugin_command = "my_library_name --run" + +def create_plugin(*args, **kwargs): + # REPLACE_ME this import with an import of your plugin's create_plugin function + from my_library_name import create_plugin as _create_plugin + return _create_plugin(*args, **kwargs) + +# ============================================================================= +# DecLib generic plugin loader (don't touch things below this) +# ============================================================================= + +import sys +# Python 2 has special requirements for Ghidra, which forces us to use a different entry point +# and scope for defining plugin entry points +if sys.version[0] == "2": + # Do Ghidra Py2 entry point + import subprocess + from declib_vendored.ghidra_bridge_server import GhidraBridgeServer + + GhidraBridgeServer.run_server(background=True) + process = subprocess.Popen(plugin_command.split(" ")) + if process.poll() is not None: + raise RuntimeError( + "Failed to run the Python3 backed. It's likely Python3 (and its scripts) is not in your Path inside Ghidra." + ) +else: + # Try plugin discovery for other decompilers + try: + import idaapi + has_ida = True + except ImportError: + has_ida = False + try: + import angrmanagement + has_angr = True + except ImportError: + has_angr = False + + if not has_ida and not has_angr: + create_plugin() + elif has_angr: + from angrmanagement.plugins import BasePlugin + class AngrDLPluginThunk(BasePlugin): + def __init__(self, workspace): + super().__init__(workspace) + globals()["workspace"] = workspace + self.plugin = create_plugin() + + def teardown(self): + pass + + +def PLUGIN_ENTRY(*args, **kwargs): + """ + This is the entry point for IDA to load the plugin. + """ + return create_plugin(*args, **kwargs) diff --git a/libbs/__init__.py b/libbs/__init__.py deleted file mode 100644 index 0b1a820d..00000000 --- a/libbs/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -The 'libbs' package has been renamed to 'declib'. - -This module is a thin backwards-compatibility shim that: -- emits a DeprecationWarning on import, -- forwards `libbs.X` imports to `declib.X` so existing code keeps working, -- will not receive further updates. - -Please install `declib` and update your imports: - - pip install declib - import declib # was: import libbs -""" -import importlib -import sys -import warnings -from importlib.abc import Loader, MetaPathFinder -from importlib.machinery import ModuleSpec - -__version__ = "3.8.1" - -warnings.warn( - "'libbs' has been renamed to 'declib'. Install 'declib' (pip install declib) " - "and replace 'libbs' with 'declib' in your imports. This shim forwards to " - "'declib' but will be removed in a future release.", - DeprecationWarning, - stacklevel=2, -) - - -class _DecLibAliasLoader(Loader): - def create_module(self, spec): - target = "declib" + spec.name[len("libbs"):] - module = importlib.import_module(target) - sys.modules[spec.name] = module - return module - - def exec_module(self, module): - return None - - -class _DecLibAliasFinder(MetaPathFinder): - """Resolve `libbs.X` imports to the matching `declib.X` module.""" - - _loader = _DecLibAliasLoader() - - def find_spec(self, fullname, path=None, target=None): - if fullname != "libbs" and not fullname.startswith("libbs."): - return None - return ModuleSpec(fullname, self._loader, is_package=True) - - -sys.meta_path.insert(0, _DecLibAliasFinder()) - -# Mirror declib's top-level attributes onto libbs so `libbs.foo` works without -# routing through the finder. -import declib as _declib # noqa: E402 - -for _attr in dir(_declib): - if not _attr.startswith("_"): - globals()[_attr] = getattr(_declib, _attr) - -del _declib, _attr diff --git a/pyproject.toml b/pyproject.toml index e9e9b140..1c39ac85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,20 +3,29 @@ requires = ["setuptools>=61.2"] build-backend = "setuptools.build_meta" [project] -name = "libbs" +name = "declib" classifiers = [ - "Development Status :: 7 - Inactive", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.12", ] license = {text = "BSD 2 Clause"} -description = "Deprecated: 'libbs' has been renamed to 'declib'. Install 'declib' instead." +description = "Your Only Decompiler API Lib - A generic API to script in and out of decompilers" urls = {Homepage = "https://github.com/binsync/declib"} requires-python = ">= 3.10" dependencies = [ - "declib", + "toml", + "ply", + "pycparser~=3.0", + "setuptools", + "prompt_toolkit", + "tqdm", + "psutil", + "pyghidra", + "platformdirs", + "filelock", + "networkx" ] dynamic = ["version"] @@ -24,12 +33,30 @@ dynamic = ["version"] file = "README.md" content-type = "text/markdown" +[project.optional-dependencies] +test = [ + "pytest", + "angr", + "requests", + "ipdb" +] +ghidra = [ + "PySide6-Essentials>=6.4.2,!=6.7.0" +] + +[project.scripts] +declib = "declib.__main__:main" +decompiler = "declib.cli:main" + [tool.setuptools] include-package-data = true license-files = ["LICENSE"] +[tool.setuptools.package-data] +"declib.skills" = ["**/SKILL.md", "**/*.md"] + [tool.setuptools.packages] find = {namespaces = false} [tool.setuptools.dynamic] -version = {attr = "libbs.__version__"} +version = {attr = "declib.__version__"} diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py new file mode 100644 index 00000000..237227ce --- /dev/null +++ b/tests/test_artifacts.py @@ -0,0 +1,199 @@ +import sys +import json + +import unittest + +import toml +from declib.artifacts import ( + FunctionHeader, StackVariable, FunctionArgument, Function, ArtifactFormat, Struct, StructMember, + load_many_artifacts, Artifact +) + + +def generate_test_funcs(func_addr): + fh1 = FunctionHeader(name="main", addr=func_addr, type_="int *", args={ + 0: FunctionArgument(offset=0, name="a1", type_="int", size=4), + 1: FunctionArgument(offset=1, name="a2", type_="long", size=8) + }) + fh2 = FunctionHeader("binsync_main", func_addr, type_="long *", args={ + 0: FunctionArgument(offset=0, name="a1", type_="int", size=4), + 1: FunctionArgument(offset=1, name="a2", type_="int", size=4) + }) + + stack_vars1 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="int", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="v4", type_="int", size=4, addr=func_addr) + } + stack_vars2 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="int", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="v4", type_="long", size=8, addr=func_addr), + 0x8: StackVariable(stack_offset=8, name="v8", type_="long", size=8, addr=func_addr) + } + + func1 = Function(addr=func_addr, size=0x100, header=fh1, stack_vars=stack_vars1) + func2 = Function(addr=func_addr, size=0x150, header=fh2, stack_vars=stack_vars2) + return func1, func2 + + +class TestArtifacts(unittest.TestCase): + def test_func_diffing(self): + # setup top + func_addr = 0x400000 + func1, func2 = generate_test_funcs(func_addr) + + diff_dict = func1.diff(func2) + header_diff = diff_dict["header"] + vars_diff = diff_dict["stack_vars"] + + # size should not match + assert func1.size != func2.size + assert diff_dict["size"]["before"] == func1.size + assert diff_dict["size"]["after"] == func2.size + + # names should not match + assert header_diff["name"]["before"] == func1.name + assert header_diff["name"]["after"] == func2.name + + # arg1 should match + assert not header_diff["args"][0] + + # arg2 should not match + assert header_diff["args"][1]["type"]["before"] != header_diff["args"][1]["type"]["after"] + + # v4 and v8 should differ + offsets = [0, 4, 8] + for off in offsets: + var_diff = vars_diff[off] + if off == 0: + assert not var_diff + if off == 0x4: + assert var_diff["size"]["before"] != var_diff["size"]["after"] + elif off == 0x8: + assert var_diff["addr"]["before"] is None + assert var_diff["addr"]["after"] == func1.addr + + def test_func_nonconflict_merge(self): + # setup top + func_addr = 0x400000 + fh1 = FunctionHeader(name="user1_func", addr=func_addr, type_="int *", args={}) + fh2 = FunctionHeader(name="main", addr=func_addr, type_="long *", args={}) + + stack_vars1 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="int", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="my_var", type_="int", size=4, addr=func_addr) + } + stack_vars2 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="int", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="v4", type_="long", size=8, addr=func_addr), + 0x8: StackVariable(stack_offset=8, name="v8", type_="long", size=8, addr=func_addr) + } + + func1 = Function(addr=func_addr, size=0x100, header=fh1, stack_vars=stack_vars1) + func2 = Function(addr=func_addr, size=0x100, header=fh2, stack_vars=stack_vars2) + merge_func = func1.nonconflict_merge(func2) + + assert merge_func.name == "user1_func" + assert merge_func.header.type == "int *" + assert merge_func.stack_vars[0].name == "v0" + assert merge_func.stack_vars[4].name == "my_var" + assert merge_func.stack_vars[4].type == "int" + assert merge_func.stack_vars[8].name == "v8" + + def test_func_overwrite_merge(self): + func_addr = 0x400000 + func_size = 0x100 + fh1 = FunctionHeader(name="main", addr=func_addr, type_="int *", args={ + 0: FunctionArgument(offset=0, name="a1", type_="int", size=4) + }) + fh2 = FunctionHeader(name="binsync_main", addr=func_addr, type_="long *", args={ + 1: FunctionArgument(offset=1, name="bs_2", type_="int", size=4) + }) + + stack_vars1 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="int", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="v4", type_="long", size=8, addr=func_addr), + 0x8: StackVariable(stack_offset=8, name="v8", type_="long", size=8, addr=func_addr) + } + stack_vars2 = { + 0x0: StackVariable(stack_offset=0, name="v0", type_="long", size=4, addr=func_addr), + 0x4: StackVariable(stack_offset=4, name="my_var", type_="int", size=4, addr=func_addr) + } + + func1 = Function(addr=func_addr, size=func_size, header=fh1, stack_vars=stack_vars1) + func2 = Function(addr=func_addr, size=func_size, header=fh2, stack_vars=stack_vars2) + + merge_func = func1.overwrite_merge(func2) + + assert merge_func.size == func1.size == func2.size + assert merge_func.name == func2.name + assert merge_func.header.args[0].name == func1.header.args[0].name + assert merge_func.stack_vars[0].name == stack_vars1[0].name + assert merge_func.stack_vars[0].type == stack_vars2[0].type + assert merge_func.stack_vars[0x4] == stack_vars2[0x4] + assert merge_func.stack_vars[0x8] == stack_vars1[0x8] + + def test_serialization(self): + native_load_funcs = { + ArtifactFormat.JSON: json.loads, + ArtifactFormat.TOML: toml.loads + } + + func, _ = generate_test_funcs(0x400000) + struct = Struct(name="some_struct", size=8, members={ + 0: StructMember(offset=0, name="m0", type_="int", size=4), + 4: StructMember(offset=4, name="m4", type_="long", size=8) + }) + # TODO: add comments, enums, patches, and global vars to the test + for fmt, load_func in native_load_funcs.items(): + serialized_func = func.dumps(fmt=fmt) + loaded_func_dict = load_func(serialized_func) + + assert loaded_func_dict["addr"] == func.addr + assert loaded_func_dict["size"] == func.size + assert loaded_func_dict["name"] == func.name + assert loaded_func_dict["type"] == func.type + assert loaded_func_dict["header"]["name"] == func.header.name + assert loaded_func_dict["header"]["type"] == func.header.type + assert loaded_func_dict["header"]["args"]["0x0"]["name"] == func.header.args[0].name + # XXX: critical point: keys are strings, not integers + assert loaded_func_dict["stack_vars"]["0x0"]["name"] == func.stack_vars[0].name + + loaded_func = Function.loads(serialized_func, fmt=fmt) + assert loaded_func == func + + serialized_struct = struct.dumps(fmt=fmt) + loaded_struct_dict = load_func(serialized_struct) + assert loaded_struct_dict["name"] == struct.name + assert loaded_struct_dict["size"] == struct.size + assert loaded_struct_dict["members"]["0x0"]["name"] == struct.members[0].name + assert loaded_struct_dict["members"]["0x4"]["type"] == struct.members[4].type + + loaded_struct = Struct.loads(serialized_struct, fmt=fmt) + assert loaded_struct == struct + + def test_many_deserialization(self): + func, _ = generate_test_funcs(0x400000) + struct = Struct(name="some_struct", size=8, members={ + 0: StructMember(offset=0, name="m0", type_="int", size=4), + 4: StructMember(offset=4, name="m4", type_="long", size=8) + }) + + # test loading many in a list of strings + data_strs = [func.dumps(fmt=ArtifactFormat.JSON), struct.dumps(fmt=ArtifactFormat.JSON)] + for data_str in data_strs: + data_dict = json.loads(data_str) + # the ART_TYPE_STR should be in the data to tell you what type of artifact it is + assert Artifact.ART_TYPE_STR in data_dict + + loaded_arts = load_many_artifacts(data_strs, fmt=ArtifactFormat.JSON) + assert len(loaded_arts) == 2 + + loaded_func = loaded_arts[0] + assert loaded_func == func + + loaded_struct = loaded_arts[1] + assert loaded_struct == struct + + +if __name__ == "__main__": + unittest.main(argv=sys.argv) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..4ce1b0e0 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,89 @@ +import sys +import subprocess +import tempfile +from pathlib import Path + +import unittest + +from declib.plugin_installer import DecLibPluginInstaller + + +class TestCommandline(unittest.TestCase): + def test_change_watcher_plugin_cli(self): + # assumes you've pip installed ./examples/change_watcher_plugin + import dl_change_watcher + + # run the CLI version check + output = subprocess.run(["dl_change_watcher", "--version"], capture_output=True) + version = output.stdout.decode().strip() + assert version == dl_change_watcher.__version__ + + +class TestInstaller(unittest.TestCase): + """Tests for the plugin installer.""" + + def test_install_ida_to_custom_path(self): + """Test installing IDA plugin to a custom path.""" + with tempfile.TemporaryDirectory() as tmpdir: + installer = DecLibPluginInstaller(targets=["ida"]) + result = installer.install(interactive=False, paths_by_target={"ida": tmpdir}) + # Verify the installer ran without error and returned the path + assert "ida" in installer._successful_installs + assert installer._successful_installs["ida"] == Path(tmpdir) + + def test_install_binja_to_custom_path(self): + """Test installing Binary Ninja plugin to a custom path.""" + with tempfile.TemporaryDirectory() as tmpdir: + installer = DecLibPluginInstaller(targets=["binja"]) + result = installer.install(interactive=False, paths_by_target={"binja": tmpdir}) + # Verify the installer ran without error and returned the path + assert "binja" in installer._successful_installs + assert installer._successful_installs["binja"] == Path(tmpdir) + + def test_install_ghidra_to_custom_path(self): + """Test installing Ghidra plugin to a custom path.""" + with tempfile.TemporaryDirectory() as tmpdir: + installer = DecLibPluginInstaller(targets=["ghidra"]) + result = installer.install(interactive=False, paths_by_target={"ghidra": tmpdir}) + # Verify the installer ran without error and returned the path + assert "ghidra" in installer._successful_installs + assert installer._successful_installs["ghidra"] == Path(tmpdir) + + def test_install_angr_skipped_without_angrmanagement(self): + """Test that angr install is skipped when angr-management is not installed.""" + with tempfile.TemporaryDirectory() as tmpdir: + installer = DecLibPluginInstaller(targets=["angr"]) + # angr install requires angr-management to be installed, so it should be skipped + # in test environments where angr-management is not available + result = installer.install(interactive=False, paths_by_target={"angr": tmpdir}) + # The install may or may not succeed depending on whether angr-management is installed + # Just verify it doesn't raise an exception + + def test_install_all_decompilers_to_custom_paths(self): + """Test installing all decompilers to custom paths.""" + with tempfile.TemporaryDirectory() as tmpdir: + ida_path = Path(tmpdir) / "ida" + binja_path = Path(tmpdir) / "binja" + ghidra_path = Path(tmpdir) / "ghidra" + + ida_path.mkdir() + binja_path.mkdir() + ghidra_path.mkdir() + + installer = DecLibPluginInstaller(targets=["ida", "binja", "ghidra"]) + result = installer.install( + interactive=False, + paths_by_target={ + "ida": str(ida_path), + "binja": str(binja_path), + "ghidra": str(ghidra_path), + } + ) + # Verify all installers ran without error + assert "ida" in installer._successful_installs + assert "binja" in installer._successful_installs + assert "ghidra" in installer._successful_installs + + +if __name__ == "__main__": + unittest.main(argv=sys.argv) diff --git a/tests/test_client_server.py b/tests/test_client_server.py new file mode 100644 index 00000000..f6849d26 --- /dev/null +++ b/tests/test_client_server.py @@ -0,0 +1,703 @@ +import os +import contextlib +import socket +import tempfile +import threading +import time +import unittest +from pathlib import Path + +from declib.api.decompiler_server import DecompilerServer +from declib.api.decompiler_client import DecompilerClient +from declib.api.decompiler_interface import DecompilerInterface +from declib.decompilers import GHIDRA_DECOMPILER + +# Test binary path - use the same path as other tests +TEST_BINARIES_DIR = Path(os.getenv("TEST_BINARIES_DIR", Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries")) +if not TEST_BINARIES_DIR.exists(): + # fallback to relative path + TEST_BINARIES_DIR = Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries" + +FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" + + +@contextlib.contextmanager +def simulate_no_af_unix(): + if hasattr(socket, "AF_UNIX"): + af_unix_val = socket.AF_UNIX + delattr(socket, "AF_UNIX") + try: + yield + finally: + setattr(socket, "AF_UNIX", af_unix_val) + else: + yield + +class TestClientServer(unittest.TestCase): + """Test the new AF_UNIX socket-based DecompilerClient and DecompilerServer""" + + def setUp(self): + """Set up test environment""" + self.server = None + self.client = None + self.temp_dir = None + + def tearDown(self): + """Clean up test environment""" + if self.client: + self.client.shutdown() + if self.server: + self.server.stop() + if self.temp_dir and os.path.exists(self.temp_dir): + try: + os.rmdir(self.temp_dir) + except: + pass + + def test_server_startup_and_client_connection(self): + """Test that server starts and client can connect""" + # Start server with Ghidra headless and fauxware binary + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware" + ) + self.server.start() + + # Give server time to start + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Verify connection works + self.assertTrue(self.client.is_connected()) + self.assertTrue(self.client.ping()) + + # Test basic properties + self.assertEqual(self.client.name, "ghidra") + self.assertIsNotNone(self.client.binary_path) + self.assertIsNotNone(self.client.binary_hash) + self.assertTrue(self.client.decompiler_available) + + def test_inet_fallback(self): + """Test the AF_INET fallback mechanism when AF_UNIX is missing""" + with simulate_no_af_unix(): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_inet" + ) + self.server.start() + + # Give server time to start + time.sleep(2) + + # Verify that it binds to a port and writes to the socket path + self.assertTrue(os.path.exists(self.server.socket_path)) + with open(self.server.socket_path, 'r') as f: + port_str = f.read().strip() + self.assertTrue(port_str.isdigit()) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Verify connection works + self.assertTrue(self.client.is_connected()) + self.assertTrue(self.client.ping()) + + self.client.shutdown() + self.server.stop() + + def test_artifact_collections_match_local(self): + """Test that client artifact collections behave like local interface""" + with tempfile.TemporaryDirectory() as proj_dir: + # Create server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_remote" + ) + self.server.start() + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test that we get functions + remote_func_keys = list(self.client.functions.keys()) + self.assertGreater(len(remote_func_keys), 0, "Should have found functions") + + # Test that we can get light functions + remote_light_funcs = list(self.client.functions.items()) + self.assertGreater(len(remote_light_funcs), 0, "Should have light functions") + + # Verify functions are actual Function objects + if remote_light_funcs: + addr, func = remote_light_funcs[0] + self.assertIsNotNone(func, "Function should not be None") + self.assertEqual(func.addr, addr, "Function address should match key") + self.assertIsInstance(func.name, str, "Function should have a name") + + def test_client_server_method_calls(self): + """Test that client method calls work correctly""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_methods" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test function size method + func_keys = list(self.client.functions.keys()) + self.assertGreater(len(func_keys), 0, "Should have functions") + + func_addr = func_keys[0] + func_size = self.client.get_func_size(func_addr) + self.assertGreater(func_size, 0, "Function size should be positive") + + # Test fast_get_function + fast_func = self.client.fast_get_function(func_addr) + self.assertIsNotNone(fast_func, "Fast function should not be None") + self.assertEqual(fast_func.addr, func_addr, "Fast function address should match") + + def test_client_discover_auto_detection(self): + """Test client auto-discovery functionality""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_autodiscovery" + ) + self.server.start() + time.sleep(1) + + # Test auto-discovery (should find the server we just started) + try: + self.client = DecompilerClient.discover() + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.name, "ghidra") + except ConnectionError: + # Auto-discovery might fail if multiple temp directories exist + # This is acceptable, we can still test manual connection + self.client = DecompilerClient(socket_path=self.server.socket_path) + self.assertTrue(self.client.is_connected()) + + def test_error_handling(self): + """Test error handling in client-server communication""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_errors" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Test KeyError handling for non-existent function + with self.assertRaises(KeyError, msg="Should raise KeyError for non-existent function"): + self.client.functions[0xDEADBEEF] # Non-existent function + + def test_client_context_manager(self): + """Test client context manager functionality""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_context" + ) + self.server.start() + time.sleep(1) + + # Test context manager + with DecompilerClient(socket_path=self.server.socket_path) as client: + self.assertTrue(client.is_connected()) + self.assertEqual(client.name, "ghidra") + + # Client should be disconnected after context manager + # (Note: we can't test this easily since the client object is out of scope) + + def test_server_restart_discovery(self): + """Test that client can discover server after restart""" + with tempfile.TemporaryDirectory() as proj_dir: + # Start first server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_restart" + ) + self.server.start() + time.sleep(1) + + # Get the binary hash from the server + self.client = DecompilerClient(socket_path=self.server.socket_path) + binary_hash = self.client.binary_hash + self.assertIsNotNone(binary_hash, "Binary hash should not be None") + socket_path_1 = self.server.socket_path + self.client.shutdown() + + # Stop the server + self.server.stop() + time.sleep(0.5) + + # Start a new server (will have different socket path) + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_restart2" + ) + self.server.start() + time.sleep(1) + socket_path_2 = self.server.socket_path + + # Socket paths should be different (different temp directories) + self.assertNotEqual(socket_path_1, socket_path_2, + "New server should have different socket path") + + # Client should discover the new server using binary_hash + self.client = DecompilerClient.discover(binary_hash=binary_hash) + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.binary_hash, binary_hash) + self.assertEqual(self.client.socket_path, socket_path_2, + "Client should connect to new server, not old socket") + + def test_multiple_servers_binary_hash_matching(self): + """Test client can select correct server when multiple are running""" + # We'll use different binaries to get different hashes + # For this test, we'll create two servers with the same binary + # but simulate different binary_hash by using different project names + + with tempfile.TemporaryDirectory() as proj_dir1: + with tempfile.TemporaryDirectory() as proj_dir2: + # Start first server + server1 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir1, + project_name="test_server1" + ) + server1.start() + time.sleep(1) + + # Get hash from first server + client1 = DecompilerClient(socket_path=server1.socket_path) + hash1 = client1.binary_hash + socket1 = server1.socket_path + client1.shutdown() + + # Start second server with same binary (will have same hash) + server2 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir2, + project_name="test_server2" + ) + server2.start() + time.sleep(1) + + socket2 = server2.socket_path + self.assertNotEqual(socket1, socket2, "Servers should have different sockets") + + try: + # Discover with binary hash - should connect to one of the servers + # (since they have the same binary, they'll have the same hash) + discovered_client = DecompilerClient.discover(binary_hash=hash1) + self.assertTrue(discovered_client.is_connected()) + self.assertEqual(discovered_client.binary_hash, hash1) + + # Should connect to one of the two servers + self.assertIn(discovered_client.socket_path, [socket1, socket2], + "Should connect to one of the running servers") + discovered_client.shutdown() + + # Discover without binary hash - should connect to most recent + discovered_client2 = DecompilerClient.discover() + self.assertTrue(discovered_client2.is_connected()) + discovered_client2.shutdown() + + finally: + server1.stop() + server2.stop() + + def test_defunct_socket_handling(self): + """Test that client skips defunct socket files from stopped servers""" + with tempfile.TemporaryDirectory() as proj_dir: + # Start and stop a server to create a defunct socket + server1 = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_defunct" + ) + server1.start() + time.sleep(1) + defunct_socket = server1.socket_path + server1.stop() + time.sleep(0.5) + + # Manually recreate the socket file to simulate a stale socket + # (normally stop() removes it, but crashes might leave it) + import tempfile as tf + temp_dir = tf.mkdtemp(prefix="declib_server_") + defunct_socket = os.path.join(temp_dir, "decompiler.sock") + # Create an empty file to simulate stale socket + open(defunct_socket, 'w').close() + + # Start a new server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_working" + ) + self.server.start() + time.sleep(1) + + # Discovery should skip the defunct socket and find the working server + self.client = DecompilerClient.discover() + self.assertTrue(self.client.is_connected()) + self.assertEqual(self.client.socket_path, self.server.socket_path, + "Should connect to working server, not defunct socket") + + # Clean up the fake defunct socket + try: + os.unlink(defunct_socket) + os.rmdir(temp_dir) + except: + pass + + def test_discover_with_binary_hash_no_match(self): + """Test that discovery fails when binary_hash doesn't match any server""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_no_match" + ) + self.server.start() + time.sleep(1) + + # Try to discover with a non-matching binary hash + fake_hash = "this_hash_does_not_exist_12345" + with self.assertRaises(ConnectionError) as context: + DecompilerClient.discover(binary_hash=fake_hash) + + # Error message should mention the hash + self.assertIn(fake_hash, str(context.exception)) + self.assertIn("none matched", str(context.exception).lower()) + + def test_server_info_includes_binary_hash(self): + """Test that server_info response includes binary_hash""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_server_info" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Server info is fetched during connection and stored + server_info = self.client._server_info + self.assertIsNotNone(server_info, "Server info should be available") + self.assertIn("binary_hash", server_info, "Server info should include binary_hash") + + # Verify binary_hash matches what we get from the property + self.assertEqual(server_info["binary_hash"], self.client.binary_hash, + "Server info binary_hash should match client property") + + def test_callback_events(self): + """Test that client receives callback events when artifacts change on server""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_callbacks" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callback invocations + callback_events = [] + + def test_callback(artifact, **kwargs): + callback_events.append({ + "artifact_type": type(artifact).__name__, + "artifact": artifact, + "kwargs": kwargs + }) + + # Register callback for Comment artifacts + from declib.artifacts import Comment + self.client.artifact_change_callbacks[Comment].append(test_callback) + + # Start artifact watchers (which starts event listener) + self.client.start_artifact_watchers() + time.sleep(0.5) # Give listener time to start + + # Verify event listener is running + self.assertTrue(self.client._event_listener_running, + "Event listener should be running") + self.assertTrue(self.client._subscribed_to_events, + "Client should be subscribed to events") + + # Trigger a callback on the server by creating a comment + # TODO: update this to just sent a comment so we can see the callback trigger naturally + test_comment = Comment(0x1234, "Test comment from callback test") + # Note: comment_changed will lift the artifact, which changes the address + lifted_comment = self.server.deci.comment_changed(test_comment) + + # Wait for event to be received and processed + time.sleep(0.5) + + # Verify callback was triggered + self.assertGreater(len(callback_events), 0, + "Callback should have been triggered") + + # Verify event contents + event = callback_events[0] + self.assertEqual(event["artifact_type"], "Comment", + "Event should be for Comment artifact") + # The address should match the lifted address, not the original + self.assertEqual(event["artifact"].addr, lifted_comment.addr, + "Comment address should match the lifted address") + self.assertIn("Test comment", event["artifact"].comment, + "Comment text should match") + + # Clean up + self.client.stop_artifact_watchers() + self.assertFalse(self.client._event_listener_running, + "Event listener should be stopped") + + def test_multiple_callbacks(self): + """Test that multiple callbacks can be registered and all are triggered""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_multiple_callbacks" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callbacks + callback1_called = [] + callback2_called = [] + + def callback1(artifact, **kwargs): + callback1_called.append(artifact) + + def callback2(artifact, **kwargs): + callback2_called.append(artifact) + + # Register multiple callbacks + from declib.artifacts import Struct + self.client.artifact_change_callbacks[Struct].append(callback1) + self.client.artifact_change_callbacks[Struct].append(callback2) + + # Start watchers + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Trigger event + test_struct = Struct("TestStruct", 0x10, members={}) + self.server.deci.struct_changed(test_struct) + + # Wait for processing + time.sleep(0.5) + + # Both callbacks should have been called + self.assertEqual(len(callback1_called), 1, "Callback 1 should be called once") + self.assertEqual(len(callback2_called), 1, "Callback 2 should be called once") + self.assertEqual(callback1_called[0].name, "TestStruct") + self.assertEqual(callback2_called[0].name, "TestStruct") + + def test_callback_with_metadata(self): + """Test that callback metadata (like deleted flag) is passed correctly""" + with tempfile.TemporaryDirectory() as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_callback_metadata" + ) + self.server.start() + time.sleep(1) + + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track metadata + received_metadata = [] + + def metadata_callback(artifact, **kwargs): + received_metadata.append(kwargs) + + # Register callback + from declib.artifacts import Enum + self.client.artifact_change_callbacks[Enum].append(metadata_callback) + + # Start watchers + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Trigger event with metadata + test_enum = Enum("TestEnum", members={}) + self.server.deci.enum_changed(test_enum, deleted=True) + + # Wait for processing + time.sleep(0.5) + + # Verify metadata was passed + self.assertEqual(len(received_metadata), 1, "Callback should be called once") + self.assertIn("deleted", received_metadata[0], "Metadata should include deleted flag") + self.assertTrue(received_metadata[0]["deleted"], "deleted flag should be True") + + def test_artifact_watchers_integration(self): + """ + Test artifact callbacks with client-server architecture (adapted from test_remote_ghidra). + + Note: This test manually triggers callbacks on the server to test the event broadcast system, + since Ghidra's artifact watchers don't function in headless mode. + """ + from declib.artifacts import FunctionHeader, StackVariable, Struct, GlobalVariable, Enum, Comment + from collections import defaultdict + + with tempfile.TemporaryDirectory() as proj_dir: + # Start server + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_artifact_watchers" + ) + self.server.start() + time.sleep(1) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Track callback hits + hits = defaultdict(list) + def func_hit(artifact, **kwargs): + hits[artifact.__class__].append(artifact) + + # Register callbacks for different artifact types + for typ in (FunctionHeader, StackVariable, Enum, Struct, GlobalVariable, Comment): + self.client.artifact_change_callbacks[typ].append(func_hit) + + # Start event listener + self.client.start_artifact_watchers() + time.sleep(0.5) + + # Test FunctionHeader callback by manually triggering on server + # (Ghidra headless watchers don't work, so we manually trigger) + func_addr = self.client.art_lifter.lift_addr(0x400664) + main = self.client.functions[func_addr] + + # Trigger callback on server side directly + test_header = FunctionHeader("test_func", func_addr, type_="int") + self.server.deci.function_header_changed(test_header) + time.sleep(0.5) + + # Verify callback was received on client + self.assertGreaterEqual(len(hits[FunctionHeader]), 1, + "FunctionHeader callback should be triggered") + + # Test Comment callback + test_comment = Comment(func_addr, "Test comment for integration test") + self.server.deci.comment_changed(test_comment) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Comment]), 1, + "Comment callback should be triggered") + + # Test Struct callback + test_struct = Struct("TestStruct", 0x10, members={}) + self.server.deci.struct_changed(test_struct) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Struct]), 1, + "Struct callback should be triggered") + + # Test Enum callback + test_enum = Enum("TestEnum", members={"VALUE1": 1, "VALUE2": 2}) + self.server.deci.enum_changed(test_enum) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[Enum]), 1, + "Enum callback should be triggered") + + # Test GlobalVariable callback + g_addr = self.client.art_lifter.lift_addr(0x4008e0) + test_gvar = GlobalVariable(g_addr, "test_global", "int", 4) + self.server.deci.global_variable_changed(test_gvar) + time.sleep(0.5) + + self.assertGreaterEqual(len(hits[GlobalVariable]), 1, + "GlobalVariable callback should be triggered") + + # Test that client can also modify artifacts through the server + # and they persist correctly + main.name = "modified_main" + self.client.functions[func_addr] = main + time.sleep(0.5) + + # Retrieve and verify the change persisted + modified_main = self.client.functions[func_addr] + self.assertEqual(modified_main.name, "modified_main", + "Function name modification should persist") + + # Clean up + self.client.stop_artifact_watchers() + self.assertFalse(self.client._event_listener_running, + "Event listener should be stopped") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_decompiler_cli.py b/tests/test_decompiler_cli.py new file mode 100644 index 00000000..687f58f8 --- /dev/null +++ b/tests/test_decompiler_cli.py @@ -0,0 +1,1078 @@ +""" +Tests for the `decompiler` CLI and the new declib core features it exposes +(list_strings, get_callers, disassemble, xref_to_addr, xref_from). + +The CLI tests are backend-parametrized: each test method lives on a single +base class, and one subclass per supported decompiler re-runs them with a +different ``backend`` class attribute. Backends whose dependencies aren't +available are skipped. + +Subprocesses are used on purpose so the real entry point + server-registry +flow is exercised end-to-end. +""" +import json +import os +import shutil +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path + +from declib.api import server_registry +from declib.api.decompiler_client import DecompilerClient +from declib.api.decompiler_interface import DecompilerInterface + + +TEST_BINARIES_DIR = Path( + os.getenv("TEST_BINARIES_DIR", Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries") +) +FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" +POSIX_SYSCALL_PATH = TEST_BINARIES_DIR / "posix_syscall" + + +# --------------------------------------------------------------------------- +# Backend availability detection: skip subclasses cleanly when a decompiler +# isn't installed. Keep these tight and cheap — don't actually load a binary. +# --------------------------------------------------------------------------- + +def _backend_available(backend: str) -> bool: + try: + if backend == "angr": + import angr # noqa: F401 + elif backend == "ghidra": + import pyghidra # noqa: F401 + if not os.environ.get("GHIDRA_INSTALL_DIR"): + return False + elif backend == "binja": + import binaryninja # noqa: F401 + elif backend == "ida": + import idapro # noqa: F401 + else: + return False + except Exception: + return False + return True + + +def _cli_env(): + env = os.environ.copy() + # Isolate registry per-test so concurrent test runs don't collide and stale + # servers from previous runs don't leak in. + env["DECLIB_SERVER_REGISTRY"] = _REGISTRY_DIR + return env + + +def _run_cli(*args, check=True, timeout=600, env_overrides=None) -> subprocess.CompletedProcess: + """Run the `decompiler` CLI and return the result.""" + cmd = [sys.executable, "-m", "declib.cli.decompiler_cli", *args] + env = _cli_env() + for key, value in (env_overrides or {}).items(): + if value is None: + env.pop(key, None) + else: + env[key] = value + return subprocess.run(cmd, capture_output=True, text=True, check=check, timeout=timeout, env=env) + + +def _format_hex(value: int) -> str: + """Tiny helper: render an int as ``0x...`` for CLI args.""" + return f"0x{value:x}" + + +# Shared registry directory for this module's tests +_REGISTRY_DIR = tempfile.mkdtemp(prefix="declib_cli_registry_") + + +def _stop_all_servers(): + """Best-effort teardown: kill every server present in the registry.""" + os.environ["DECLIB_SERVER_REGISTRY"] = _REGISTRY_DIR + try: + records = server_registry.list_servers(prune_stale=False) + except Exception: + records = [] + for record in records: + try: + client = DecompilerClient(socket_path=record["socket_path"]) + try: + client._send_request({"type": "shutdown_deci"}) + except Exception: + pass + client.shutdown() + except Exception: + pass + finally: + server_registry.unregister_server(record.get("id")) + # Also try to SIGKILL the PID as a fallback + pid = record.get("pid") + if pid: + try: + os.kill(int(pid), 9) + except Exception: + pass + + +class _CLIBackendTestBase(unittest.TestCase): + """Base class for backend-parametrized CLI tests. + + Subclasses set ``backend`` to one of ``angr``, ``ghidra``, ``binja``, + ``ida``. Tests that rely on angr-specific quirks are gated inside the + method body rather than being split into separate subclasses, so a + single test method describes "what the CLI should do against any + backend" and the backend-specific allowances live near the asserts. + """ + + backend: str = "angr" + + @classmethod + def setUpClass(cls): + # `_CLIBackendTestBase` itself is abstract; skip it so unittest doesn't + # try to run its inherited methods with the default angr backend. + if cls is _CLIBackendTestBase: + raise unittest.SkipTest("abstract base class") + if not FAUXWARE_PATH.exists(): + raise unittest.SkipTest(f"Missing test binary: {FAUXWARE_PATH}") + if not _backend_available(cls.backend): + raise unittest.SkipTest(f"{cls.backend} backend not available") + os.environ["DECLIB_SERVER_REGISTRY"] = _REGISTRY_DIR + _stop_all_servers() + + @classmethod + def tearDownClass(cls): + _stop_all_servers() + + def tearDown(self): + _stop_all_servers() + + # ------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------- + + def _load_fauxware(self, *extra_args, project_dir=None): + args = ["load", str(FAUXWARE_PATH), "--backend", self.backend, "--json", *extra_args] + if project_dir is not None: + args.extend(["--project-dir", str(project_dir)]) + result = _run_cli(*args) + payload = json.loads(result.stdout) + self.assertIn(payload["status"], ("started", "already_loaded")) + self.assertEqual(payload["backend"], self.backend) + return payload + + def _resolve_main_name(self): + """Return whatever the current backend calls the fauxware entry. + + angr promotes the entry to ``main``; Ghidra leaves ``main`` when the + symbol is present (fauxware is not stripped). We scan + ``list_functions`` so the tests don't depend on any particular + backend's naming convention. + """ + result = _run_cli("list_functions", "--json") + entries = json.loads(result.stdout) + preferred = {"main", "_main"} + for e in entries: + if e.get("name") in preferred: + return e["name"] + # Fauxware's `main` entry starts at offset 0x71d (lifted). + for e in entries: + if e.get("addr") == 0x71d: + return e["name"] or f"0x{e['addr']:x}" + self.fail("Couldn't locate main in list_functions output") + + # ------------------------------------------------------------------- + # Shared backend-agnostic tests + # ------------------------------------------------------------------- + + def test_load_and_list(self): + loaded = self._load_fauxware() + server_id = loaded["id"] + + list_result = _run_cli("list", "--json") + payload = json.loads(list_result.stdout) + self.assertIn("registry_dir", payload) + ids = {s["id"] for s in payload["servers"]} + self.assertIn(server_id, ids) + + def test_list_functions_and_decompile(self): + self._load_fauxware() + lf = _run_cli("list_functions", "--json").stdout + entries = json.loads(lf) + self.assertTrue(entries, "list_functions returned no entries") + for e in entries: + self.assertIn("addr", e) + self.assertIn("addr_hex", e) + self.assertIn("size", e) + self.assertIn("name", e) + + name = self._resolve_main_name() + dec_result = _run_cli("decompile", name, "--json") + payload = json.loads(dec_result.stdout) + self.assertIn("text", payload) + self.assertTrue(payload["text"], "empty decompilation") + self.assertIn("addr_hex", payload) + self.assertTrue(payload["addr_hex"].startswith("0x")) + + def test_disassemble(self): + self._load_fauxware() + name = self._resolve_main_name() + result = _run_cli("disassemble", name, "--json") + payload = json.loads(result.stdout) + self.assertIn("text", payload) + self.assertIn("addr_hex", payload) + # Any reasonable disassembler emits at least one of these opcodes for + # main. Compare case-insensitively so Ghidra's uppercase "PUSH" and + # angr/capstone's lowercase "push" both pass. + text = payload["text"].lower() + self.assertTrue(any(op in text for op in ("push", "mov", "call", "sub"))) + + def test_decompile_raw(self): + """--raw should print text directly, not JSON-wrapped.""" + self._load_fauxware() + name = self._resolve_main_name() + raw = _run_cli("decompile", name, "--raw") + self.assertNotIn('\\n', raw.stdout) + self.assertNotIn('{"addr"', raw.stdout) + + def test_list_strings(self): + self._load_fauxware() + # Every supported backend sees this string in fauxware. + result = _run_cli("list_strings", "--filter", "Welcome", "--json") + payload = json.loads(result.stdout) + self.assertTrue(any("Welcome" in s["string"] for s in payload), + f"{self.backend} list_strings missed 'Welcome': {payload!r}") + for entry in payload: + # Regression for negative-address / `0x-100000` formatting — the + # lifted hex rendering must always be a well-formed positive hex. + self.assertTrue(entry["addr_hex"].startswith("0x")) + self.assertNotIn("-", entry["addr_hex"][2:]) + + def test_xref_to_function(self): + self._load_fauxware() + # `authenticate` exists in fauxware and is called from main across + # all backends we support. + result = _run_cli("xref_to", "authenticate", "--json") + payload = json.loads(result.stdout) + self.assertEqual(payload.get("target_kind"), "function") + names = {x.get("name") for x in payload["xrefs"]} + self.assertIn("main", names, f"{self.backend}: 'main' not in xrefs_to(authenticate): {names!r}") + for x in payload["xrefs"]: + self.assertIn("addr_hex", x) + + def test_xref_to_string(self): + """Regression: xref_to should accept a string literal as target.""" + self._load_fauxware() + # SOSNEAKY is the magic password constant in fauxware; it's + # referenced from `authenticate`. + result = _run_cli("xref_to", "SOSNEAKY", "--json", check=False) + if result.returncode != 0: + self.skipTest(f"{self.backend} doesn't surface SOSNEAKY: {result.stdout}") + payload = json.loads(result.stdout) + self.assertEqual(payload.get("target_kind"), "string") + xref_names = {x.get("name") for x in payload["xrefs"]} + self.assertIn("authenticate", xref_names, + f"{self.backend}: expected 'authenticate' in xref_to(SOSNEAKY): {xref_names}") + + def test_xref_from(self): + """Regression: xref_from must return non-empty callees on each backend.""" + self._load_fauxware() + name = self._resolve_main_name() + result = _run_cli("xref_from", name, "--json") + payload = json.loads(result.stdout) + addrs = {x.get("addr") for x in payload["xrefs"]} + self.assertGreater(len(addrs), 0, f"{self.backend}: xref_from({name}) empty") + # Backends with debug symbols recognize at least one of these names. + names = {x.get("name") for x in payload["xrefs"] if x.get("name")} + self.assertTrue(names & {"authenticate", "puts", "read", "accepted", "rejected"}, + f"{self.backend}: expected a known callee in {names}") + + def test_get_callers(self): + self._load_fauxware() + result = _run_cli("get_callers", "authenticate", "--json") + payload = json.loads(result.stdout) + names = {c.get("name") for c in payload["callers"]} + self.assertIn("main", names) + for c in payload["callers"]: + self.assertIn("addr_hex", c) + + def test_read_memory(self): + """read_memory should return the bytes at a known location. + + Fauxware's ``Welcome to the admin console, trusted user!`` string + lives at lifted address ``0x8e0`` and the ELF header lives at the + binary's base. Both are stable across every backend we support, so + this is a clean cross-decompiler smoke test. + """ + import base64 + + self._load_fauxware() + + # 1. ELF magic at the binary's base. Lifted address 0x0. + result = _run_cli("read_memory", "0x0", "0x4", "--json") + payload = json.loads(result.stdout) + self.assertEqual(payload["size"], 4) + decoded = base64.b64decode(payload["bytes_b64"]) + self.assertEqual(decoded, b"\x7fELF", + f"{self.backend} read_memory(0x0, 4) returned {decoded!r}") + self.assertEqual(payload["hex"], "7f454c46") + + # 2. The "Welcome" string. Walk list_strings to find it so this + # isn't tied to a specific backend's address representation. + strings = json.loads(_run_cli("list_strings", "--filter", "Welcome", + "--json").stdout) + self.assertTrue(strings, f"{self.backend}: 'Welcome' string not surfaced") + welcome_addr = strings[0]["addr"] + + result = _run_cli("read_memory", _format_hex(welcome_addr), "7", "--json") + payload = json.loads(result.stdout) + self.assertEqual(base64.b64decode(payload["bytes_b64"]), b"Welcome", + f"{self.backend} read_memory at Welcome addr returned wrong bytes") + + def test_read_memory_hexdump_default(self): + """Default text output is a hexdump of the bytes.""" + self._load_fauxware() + result = _run_cli("read_memory", "0x0", "16") + # Hexdump of the ELF header starts with the magic + class + data. + self.assertIn("7f 45 4c 46", result.stdout) + # ASCII column should also be present. + self.assertIn("|.ELF", result.stdout) + + def test_read_memory_hex_format(self): + self._load_fauxware() + result = _run_cli("read_memory", "0x0", "4", "--format", "hex") + self.assertEqual(result.stdout.strip(), "7f454c46") + + def test_read_memory_invalid_address(self): + """An address far outside any segment should error cleanly.""" + self._load_fauxware() + result = _run_cli("read_memory", "0xdeadbeef00", "16", check=False) + self.assertNotEqual(result.returncode, 0) + # Either the backend rejects it, or it raises before responding. + # We just assert the CLI didn't print bytes. + combined = result.stdout + result.stderr + self.assertNotIn("|.ELF", combined) + + #: Subclasses set this to True if their backend actually persists files + #: (Ghidra project, IDA database, etc). For in-memory backends like angr + #: it stays False and we only assert "nothing wound up next to the binary". + _persists_project_files: bool = False + + def test_project_dir_keeps_binary_dir_clean(self): + """`--project-dir` should make the backend write its DB outside the binary's dir.""" + with tempfile.TemporaryDirectory() as project_dir, tempfile.TemporaryDirectory() as bin_dir: + # Copy fauxware into an isolated directory so we can verify + # nothing gets written beside it. + local_bin = Path(bin_dir) / "fauxware" + shutil.copyfile(FAUXWARE_PATH, local_bin) + local_bin.chmod(0o755) + before = set(os.listdir(bin_dir)) + + _run_cli("load", str(local_bin), "--backend", self.backend, + "--project-dir", project_dir, "--json") + # Give the backend a beat to finish writing. + _run_cli("list_functions", "--json") + + after = set(os.listdir(bin_dir)) + new_files = after - before + self.assertFalse(new_files, + f"{self.backend} wrote unexpected files beside the binary: {new_files}") + # Backends that actually persist project state (Ghidra, IDA) should + # have written *something* to the override dir; in-memory backends + # (angr) correctly produce no files and that's the whole point — + # there's nothing to place anywhere. + if self._persists_project_files: + project_contents = list(Path(project_dir).rglob("*")) + self.assertTrue(project_contents, + f"{self.backend} wrote nothing to the project_dir") + + # ------------------------------------------------------------------- + # create-type / retype (run against every backend) + # ------------------------------------------------------------------- + + def _direct_client(self): + """Connect a DecompilerClient straight to this binary's server.""" + record = server_registry.find_servers(binary_path=str(FAUXWARE_PATH))[0] + return DecompilerClient(socket_path=record["socket_path"]) + + def _load_fauxware_isolated(self): + """Load fauxware into a fresh, non-hidden project dir. + + Ghidra rejects project *locations* containing a dot-prefixed path + element (e.g. the default ``~/.cache/declib/...``), so hand it a temp + dir. This also keeps the test hermetic — no shared-cache state leaks + in from prior (possibly interrupted) runs. + """ + proj = tempfile.mkdtemp(prefix="declib_cli_proj_") + self.addCleanup(shutil.rmtree, proj, ignore_errors=True) + return self._load_fauxware(project_dir=proj) + + def test_create_type(self): + self._load_fauxware_isolated() + result = _run_cli("create-type", "struct Point { int x; int y; }", "--json") + payload = json.loads(result.stdout) + self.assertEqual(payload["kind"], "Struct") + self.assertEqual(payload["name"], "Point") + self.assertTrue(payload["success"], + f"{self.backend}: create-type failed: {payload}") + + # Verify the struct actually landed, with both named members. + client = self._direct_client() + try: + struct = client.structs["Point"] + finally: + client.shutdown() + self.assertIsNotNone(struct, f"{self.backend}: Point not found after create") + member_names = {m.name for m in struct.members.values()} + self.assertEqual(member_names, {"x", "y"}, + f"{self.backend}: unexpected members {member_names}") + + def test_retype(self): + self._load_fauxware_isolated() + # Pick a 4-byte scalar stack var (an int) and retype it to `float`. + # Same size + scalar->scalar keeps this clean across backends: no + # overlap with the adjacent slot and no array->scalar reshaping (which + # Ghidra handles poorly). + client = self._direct_client() + try: + addrs = [a for a, f in client.functions.items() if f.name == "main"] + main_addr = addrs[0] + main_func = client.functions[main_addr] + svars = list(main_func.stack_vars.values()) + scalars = [v for v in svars + if (v.size or 0) == 4 and "[" not in str(v.type or "")] + if not scalars: + self.skipTest(f"{self.backend}: no 4-byte scalar var in main to retype") + target = scalars[0].name + had_float_before = any("float" in str(v.type or "").lower() for v in svars) + finally: + client.shutdown() + self.assertFalse(had_float_before, + f"{self.backend}: main already has a float var; bad fixture") + + result = _run_cli("retype", "main", target, "float", "--json", check=False) + if result.returncode != 0: + self.skipTest( + f"{self.backend}: retype of {target!r} unsupported: " + f"{result.stdout + result.stderr}" + ) + self.assertTrue(json.loads(result.stdout)["success"]) + + # Verify a float-typed variable now exists. Match on the type appearing + # in the set rather than by name/offset: backends rename a variable by + # its type when retyped (Ghidra local_2c -> fStack_2c). + client = self._direct_client() + try: + refreshed = client.functions[main_addr] + after_types = [str(v.type or "").lower() for v in refreshed.stack_vars.values()] + finally: + client.shutdown() + self.assertTrue(any("float" in t for t in after_types), + f"{self.backend}: no float-typed var after retype; types={after_types}") + + def test_retype_missing_var_exits_1(self): + self._load_fauxware_isolated() + result = _run_cli("retype", "main", "no_such_var_xyz", "int", check=False) + self.assertEqual(result.returncode, 1) + self.assertIn("not found", (result.stdout + result.stderr).lower()) + + +class TestDecompilerCLIAngr(_CLIBackendTestBase): + """angr backend: always available (pure-Python dependency).""" + backend = "angr" + + # angr-specific sanity checks that don't map cleanly to the other + # backends live here. + def test_load_idempotent(self): + first = self._load_fauxware() + second = self._load_fauxware() + self.assertEqual(first["id"], second["id"]) + self.assertEqual(second["status"], "already_loaded") + + def test_multi_instance_same_binary_with_force(self): + first = self._load_fauxware() + forced = _run_cli("load", str(FAUXWARE_PATH), "--backend", "angr", + "--force", "--json") + second = json.loads(forced.stdout) + self.assertNotEqual(first["id"], second["id"]) + + # Ambiguous selection should fail helpfully. + result = _run_cli("decompile", "main", check=False) + self.assertNotEqual(result.returncode, 0) + self.assertIn("Specify --id", result.stdout + result.stderr) + + # Selecting a specific id disambiguates. + ok = _run_cli("decompile", "main", "--id", first["id"]) + self.assertIn("main", ok.stdout) + + def test_load_replace_stops_old_server(self): + first = self._load_fauxware() + replaced_result = _run_cli("load", str(FAUXWARE_PATH), "--backend", "angr", + "--replace", "--json") + replaced = json.loads(replaced_result.stdout) + self.assertEqual(replaced["status"], "started") + self.assertNotEqual(replaced["id"], first["id"]) + listing = _run_cli("list", "--json") + servers = json.loads(listing.stdout)["servers"] + fauxware_servers = [s for s in servers if s["binary_path"] == str(FAUXWARE_PATH)] + self.assertEqual(len(fauxware_servers), 1) + self.assertEqual(fauxware_servers[0]["id"], replaced["id"]) + + def test_client_disconnect_does_not_tear_down_server(self): + """Regression: a client context-exiting must not close the server's project. + + Each `decompiler ` spawns a fresh client, uses it via `with`, and + exits. If the client's `shutdown()` sends `shutdown_deci` to the server, + the next invocation hits a closed program (ClosedException on ghidra). + """ + self._load_fauxware() + for _ in range(3): + result = _run_cli("decompile", "main", "--json") + payload = json.loads(result.stdout) + self.assertIn("text", payload) + + def test_decompile_not_a_function_start(self): + self._load_fauxware() + result = _run_cli("decompile", "0x71e", check=False) + self.assertEqual(result.returncode, 1) + self.assertIn("No function starts at", result.stdout + result.stderr) + + def test_rename_func(self): + self._load_fauxware() + result = _run_cli("rename", "func", "authenticate", "my_auth", "--json") + payload = json.loads(result.stdout) + self.assertTrue(payload["success"]) + + def test_rename_func_missing_exits_1(self): + self._load_fauxware() + result = _run_cli("rename", "func", "nonexistent_fn_xyz", "whatever", + check=False) + self.assertEqual(result.returncode, 1) + + def test_rename_var_missing_exits_1(self): + self._load_fauxware() + result = _run_cli("rename", "var", "no_such_var_xyz", "whatever", + "--function", "main", check=False) + self.assertEqual(result.returncode, 1) + + def test_rename_var(self): + self._load_fauxware() + record = server_registry.find_servers(binary_path=str(FAUXWARE_PATH))[0] + client = DecompilerClient(socket_path=record["socket_path"]) + try: + addrs = [a for a, f in client.functions.items() if f.name == "main"] + main_addr = addrs[0] + main_func = client.functions[main_addr] + names = client.local_variable_names(main_func) + target = next((n for n in names if n not in ("a0", "a1")), names[0]) + finally: + client.shutdown() + + result = _run_cli("rename", "var", target, "renamed_var", + "--function", "main", "--json") + payload = json.loads(result.stdout) + self.assertTrue(payload["success"]) + + def test_list_strings_min_length(self): + self._load_fauxware() + result = _run_cli("list_strings", "--min-length", "20", "--json") + entries = json.loads(result.stdout) + for e in entries: + self.assertGreaterEqual(len(e["string"]), 20) + + def test_stop(self): + loaded = self._load_fauxware() + stop = _run_cli("stop", "--id", loaded["id"], "--json") + payload = json.loads(stop.stdout) + self.assertTrue(payload["stopped"][0]["stopped"]) + listing = _run_cli("list", "--json") + ids = {s["id"] for s in json.loads(listing.stdout)["servers"]} + self.assertNotIn(loaded["id"], ids) + + @unittest.skipUnless(POSIX_SYSCALL_PATH.exists(), f"Missing: {POSIX_SYSCALL_PATH}") + def test_two_binaries_concurrent(self): + first = self._load_fauxware() + second_result = _run_cli("load", str(POSIX_SYSCALL_PATH), "--backend", "angr", "--json") + second = json.loads(second_result.stdout) + self.assertNotEqual(first["id"], second["id"]) + fauxware_strings = _run_cli("list_strings", "--id", first["id"], "--json") + self.assertTrue(any("Welcome" in s["string"] + for s in json.loads(fauxware_strings.stdout))) + + +@unittest.skipUnless(_backend_available("ghidra"), + "ghidra backend not available (no GHIDRA_INSTALL_DIR or pyghidra missing)") +class TestDecompilerCLIGhidra(_CLIBackendTestBase): + """Ghidra backend: same suite as angr, running against real Ghidra.""" + backend = "ghidra" + _persists_project_files = True # Ghidra writes its project under --project-dir + + def test_list_strings_picks_up_uchar_array(self): + """Regression: Ghidra auto-types the base64 alphabet as `uchar[64]` + rather than a string, so ``getDefinedData`` misses it. The + supplemental StringSearcher pass should surface it anyway. + + Skips when the challenge binary isn't checked in (it only ships in + the repo for local reproduction). Using ``pathlib`` rather than + copying the binary into TEST_BINARIES_DIR keeps the repo tidy. + """ + challenge = Path(__file__).parent.parent / "challenge" / "rpc.out" + if not challenge.exists(): + self.skipTest(f"challenge binary missing: {challenge}") + _run_cli("load", str(challenge), "--backend", "ghidra", "--json") + result = _run_cli("list_strings", "--filter", "ABCDEFGHIJKLMN", "--json") + payload = json.loads(result.stdout) + self.assertTrue( + any("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + in s["string"] for s in payload), + f"Ghidra list_strings missed the base64 alphabet: {payload!r}" + ) + + +@unittest.skipUnless(_backend_available("ida"), + "ida backend not available (idapro module missing)") +class TestDecompilerCLIIDA(_CLIBackendTestBase): + """IDA (via idalib) backend: same suite as angr, running against real IDA. + + Mostly a regression test for main-thread dispatch: idalib rejects every + cross-thread API call with ``Function can be called from the main thread + only``, so every CLI round-trip here exercises the dispatcher path — + the client's ``server_info`` handshake included. + """ + backend = "ida" + _persists_project_files = True # .id0/.id1/.id2/.nam/.til + + +# --------------------------------------------------------------------------- +# Cross-decompiler sync: push edits made in IDA into a running Ghidra instance. +# Standalone (not backend-parametrized) because it needs two specific backends. +# --------------------------------------------------------------------------- + +@unittest.skipUnless( + _backend_available("ida") and _backend_available("ghidra"), + "sync IDA->Ghidra tests need both ida (idapro) and ghidra (GHIDRA_INSTALL_DIR)", +) +class TestDecompilerSyncIDAtoGhidra(unittest.TestCase): + """`decompiler sync` copies a function's work from a source server (IDA) + into a destination server (Ghidra) for the same binary.""" + + @classmethod + def setUpClass(cls): + if not FAUXWARE_PATH.exists(): + raise unittest.SkipTest(f"Missing test binary: {FAUXWARE_PATH}") + os.environ["DECLIB_SERVER_REGISTRY"] = _REGISTRY_DIR + _stop_all_servers() + + @classmethod + def tearDownClass(cls): + _stop_all_servers() + + def setUp(self): + # Fresh, isolated project dir per test so a stale/locked backend + # database from a previous (possibly interrupted) run can't make a + # `load` hang or fail. Each backend writes into its own subdir. + self._proj_dir = tempfile.mkdtemp(prefix="declib_sync_proj_") + + def tearDown(self): + _stop_all_servers() + shutil.rmtree(self._proj_dir, ignore_errors=True) + + # -- helpers ----------------------------------------------------------- + + def _load(self, backend): + # `load` blocks until the server is ready (Ghidra analysis included). + out = _run_cli("load", str(FAUXWARE_PATH), "--backend", backend, + "--force", "--project-dir", self._proj_dir, "--json").stdout + payload = json.loads(out) + self.assertIn(payload["status"], ("started", "already_loaded")) + return payload["id"] + + def _client_for(self, server_id): + rec = server_registry.find_server(server_id=server_id) + self.assertIsNotNone(rec, f"no server record for id={server_id}") + return DecompilerClient(socket_path=rec["socket_path"]) + + def _main_addr(self, client): + addrs = [a for a, f in client.functions.items() + if f.name in ("main", "_main")] + if not addrs: + addrs = [a for a in client.functions.keys() if a == 0x71d] + self.assertTrue(addrs, "could not find main on server") + return addrs[0] + + # -- tests ------------------------------------------------------------- + + def test_sync_names_ida_to_ghidra(self): + ida_id = self._load("ida") + ghidra_id = self._load("ghidra") + + # Pick a stack var on the IDA side to rename. + ida = self._client_for(ida_id) + try: + main_addr = self._main_addr(ida) + main_func = ida.functions[main_addr] + self.assertTrue(main_func.stack_vars, "IDA main has no stack vars") + target_off = sorted(main_func.stack_vars.keys())[0] + old_var_name = main_func.stack_vars[target_off].name + finally: + ida.shutdown() + + # Edit in IDA via the CLI: rename the function and the stack var. + # Reference the function by address (stable) rather than by its new + # name, since the light function list can lag a header rename. + main_hex = _format_hex(main_addr) + r1 = _run_cli("rename", "func", main_hex, "synced_main", "--id", ida_id, "--json") + self.assertTrue(json.loads(r1.stdout)["success"]) + r2 = _run_cli("rename", "var", old_var_name, "synced_var", + "--function", main_hex, "--id", ida_id, "--json") + self.assertTrue(json.loads(r2.stdout)["success"]) + + # Sync IDA -> Ghidra (sync takes a function address). + rs = _run_cli("sync", main_hex, "--from-id", ida_id, + "--id", ghidra_id, "--json") + sync_payload = json.loads(rs.stdout) + self.assertTrue(sync_payload["success"], f"sync failed: {sync_payload}") + + # Verify on Ghidra. The function is keyed by addr (Ghidra still calls + # it "main"); the renamed var is matched by canonical stack offset. + ghidra = self._client_for(ghidra_id) + try: + gfunc = ghidra.functions[sync_payload["addr"]] + self.assertEqual(gfunc.name, "synced_main", + f"function name not synced: {gfunc.name}") + var_names = {sv.name for sv in gfunc.stack_vars.values()} + self.assertIn("synced_var", var_names, + f"variable name not synced; ghidra vars: {var_names}") + finally: + ghidra.shutdown() + + def test_sync_types_ida_to_ghidra(self): + ida_id = self._load("ida") + ghidra_id = self._load("ghidra") + + # Pick a stack var on IDA to retype. Use the largest so there's room + # for an 8-byte `Point *` without overlapping the adjacent slot. + ida = self._client_for(ida_id) + try: + main_addr = self._main_addr(ida) + main_func = ida.functions[main_addr] + self.assertTrue(main_func.stack_vars) + biggest = max(main_func.stack_vars.values(), key=lambda v: (v.size or 0)) + target_off = biggest.offset + target_var_name = biggest.name + finally: + ida.shutdown() + + # Feature 1 in IDA: create a struct, then retype a var to a Point pointer. + rc = _run_cli("create-type", "struct Point { int x; int y; }", + "--id", ida_id, "--json") + self.assertEqual(rc.returncode, 0, rc.stderr) + self.assertTrue(json.loads(rc.stdout)["success"]) + main_hex = _format_hex(main_addr) + rt = _run_cli("retype", main_hex, target_var_name, "Point *", + "--id", ida_id, "--json") + self.assertEqual(rt.returncode, 0, rt.stderr) + self.assertTrue(json.loads(rt.stdout)["success"]) + + # Sync IDA -> Ghidra (sync takes a function address). + rs = _run_cli("sync", main_hex, "--from-id", ida_id, "--id", ghidra_id, "--json") + sync_payload = json.loads(rs.stdout) + self.assertTrue(sync_payload["success"], f"sync failed: {sync_payload}") + + # Verify on Ghidra: the struct exists and the var references it. + ghidra = self._client_for(ghidra_id) + try: + self.assertIn("Point", ghidra.structs, + f"Point not in ghidra structs: {list(ghidra.structs.keys())}") + gfunc = ghidra.functions[sync_payload["addr"]] + point_typed = [sv for sv in gfunc.stack_vars.values() + if "Point" in str(sv.type or "")] + self.assertTrue(point_typed, + "no ghidra var references Point: " + f"{[(sv.name, sv.type) for sv in gfunc.stack_vars.values()]}") + finally: + ghidra.shutdown() + + +# --------------------------------------------------------------------------- +# Type-definition parser unit tests: backend-free, cheap to iterate on. +# --------------------------------------------------------------------------- + +class TestTypeDefinitionParser(unittest.TestCase): + def test_struct_offsets_and_size(self): + from declib.api.type_definition_parser import parse_type_definition + s = parse_type_definition("struct Point { int x; int y; }") + self.assertEqual(s.name, "Point") + self.assertEqual(s.members[0].name, "x") + self.assertEqual(s.members[0].size, 4) + self.assertEqual(s.members[4].name, "y") + self.assertEqual(s.size, 8) + + def test_struct_pointer_and_array_members(self): + from declib.api.type_definition_parser import parse_type_definition + s = parse_type_definition("struct S { char *name; int arr[4]; struct Foo *fp; }") + types = {m.name: m.type for m in s.members.values()} + self.assertEqual(types["name"], "char *") + self.assertEqual(types["arr"], "int [4]") + self.assertEqual(types["fp"], "struct Foo *") + + def test_enum(self): + from declib.api.type_definition_parser import parse_type_definition + e = parse_type_definition("enum Color { RED, GREEN=5, BLUE }") + self.assertEqual(dict(e.members), {"RED": 0, "GREEN": 5, "BLUE": 6}) + + def test_typedef(self): + from declib.api.type_definition_parser import parse_type_definition + t = parse_type_definition("typedef char *str_t") + self.assertEqual(t.name, "str_t") + self.assertEqual(t.type, "char *") + + def test_bad_input_raises(self): + from declib.api.type_definition_parser import ( + parse_type_definition, TypeDefinitionParseError, + ) + for bad in ["struct {", "not c @#", "", "struct Empty {}", + "struct A { int a; }; struct B { int b; };"]: + with self.assertRaises(TypeDefinitionParseError): + parse_type_definition(bad) + + +# --------------------------------------------------------------------------- +# Artifact-serialization unit tests: keep these separate from the CLI +# subprocess tests so they run in isolation and are cheap to iterate on. +# --------------------------------------------------------------------------- + +class TestArtifactWireSerialization(unittest.TestCase): + """The client↔server wire format must survive tricky decompilation text. + + Regression for the Ghidra `Reserved escape sequence used` failure: the + `toml` encoder mangles literal `\\x01` escapes that show up in C char + literals. The server now emits JSON on the wire; JSON is stricter about + backslash escaping, so this test locks that behavior in. + """ + + def test_decompilation_with_backslash_x_roundtrip_json(self): + from declib.artifacts import Decompilation + from declib.artifacts.formatting import ArtifactFormat + + # Exactly the kind of text Ghidra emits when decompiling code that + # compares a byte to a control character: `if (c == '\x01')`. + text = "if (c == '\\x01') { return 42; }" + dec = Decompilation(addr=0x1000, text=text, decompiler="ghidra") + + encoded = dec.dumps(fmt=ArtifactFormat.JSON) + decoded = Decompilation.loads(encoded, fmt=ArtifactFormat.JSON) + self.assertEqual(decoded.text, text) + self.assertEqual(decoded.addr, 0x1000) + + def test_decompilation_toml_still_fails_on_backslash_x(self): + """Document WHY we moved off TOML — if this ever starts working we + can reconsider, but in the meantime it's load-bearing for the fix.""" + from declib.artifacts import Decompilation + from declib.artifacts.formatting import ArtifactFormat + import toml + + text = "if (c == '\\x01') { return 42; }" + dec = Decompilation(addr=0x1000, text=text, decompiler="ghidra") + encoded = dec.dumps(fmt=ArtifactFormat.TOML) + with self.assertRaises(toml.decoder.TomlDecodeError): + Decompilation.loads(encoded, fmt=ArtifactFormat.TOML) + + +class TestCLIFormatters(unittest.TestCase): + """Sanity tests for the small pure-Python helpers in the CLI.""" + + def test_format_addr_hex_handles_negative(self): + """Regression for Ghidra surfacing negative-signed-long section addrs.""" + from declib.cli.decompiler_cli import _format_addr_hex + + # Positive values render as-is. + self.assertEqual(_format_addr_hex(0x400), "0x400") + # Negative values wrap to unsigned 64-bit, never emit '0x-...'. + rendered = _format_addr_hex(-0x100000) + self.assertTrue(rendered.startswith("0x")) + self.assertNotIn("-", rendered) + self.assertEqual(rendered, f"0x{((-0x100000) & ((1 << 64) - 1)):x}") + + def test_annotate_addrs_uses_safe_hex(self): + from declib.cli.decompiler_cli import _annotate_addrs + + payload = {"addr": -0x100000, "target_addr": 0x1000} + annotated = _annotate_addrs(payload) + self.assertNotIn("-", annotated["addr_hex"]) + self.assertEqual(annotated["target_addr_hex"], "0x1000") + + +# --------------------------------------------------------------------------- +# Skill installer tests +# --------------------------------------------------------------------------- + +class TestSkillInstaller(unittest.TestCase): + """The bundled `decompiler` skill should ship with the package and install cleanly.""" + + def test_bundled_skill_present(self): + from declib import skills + + names = skills.available_skills() + self.assertIn("decompiler", names) + skill = skills.skill_path("decompiler") / "SKILL.md" + content = skill.read_text() + self.assertIn("name: decompiler", content) + self.assertIn("decompiler load", content) + + def test_install_skill_via_cli(self): + with tempfile.TemporaryDirectory() as dest: + result = _run_cli("install-skill", "--dest", dest, "--json") + payload = json.loads(result.stdout) + self.assertEqual(len(payload["installed"]), 1) + installed_path = Path(payload["installed"][0]["path"]) + self.assertEqual(payload["installed"][0]["agent"], "custom") + self.assertTrue((installed_path / "SKILL.md").is_file()) + + # Re-install without --force should fail helpfully. + again = _run_cli("install-skill", "--dest", dest, "--json", check=False) + self.assertNotEqual(again.returncode, 0) + + # --force overwrites. + forced = _run_cli("install-skill", "--dest", dest, "--json", "--force") + self.assertEqual(len(json.loads(forced.stdout)["installed"]), 1) + + def test_install_skill_text_output_is_parsable(self): + with tempfile.TemporaryDirectory() as dest: + result = _run_cli("install-skill", "--dest", dest) + self.assertNotIn("[{'name'", result.stdout) + self.assertIn("decompiler", result.stdout) + + def test_install_skill_agent_destinations(self): + with tempfile.TemporaryDirectory() as home, tempfile.TemporaryDirectory() as codex_home: + result = _run_cli( + "install-skill", + "--agent", "all", + "--json", + env_overrides={"HOME": home, "CODEX_HOME": codex_home}, + ) + payload = json.loads(result.stdout) + installed = {entry["agent"]: Path(entry["path"]) for entry in payload["installed"]} + self.assertEqual(set(installed), {"claude", "codex"}) + self.assertEqual(installed["claude"], + (Path(home) / ".claude" / "skills" / "decompiler").resolve()) + self.assertEqual(installed["codex"], + (Path(codex_home) / "skills" / "decompiler").resolve()) + + def test_install_skill_default_prefers_codex_under_codex(self): + with tempfile.TemporaryDirectory() as home, tempfile.TemporaryDirectory() as codex_home: + result = _run_cli( + "install-skill", + "--json", + env_overrides={"HOME": home, "CODEX_HOME": codex_home, "CODEX_CI": "1"}, + ) + installed = json.loads(result.stdout)["installed"] + self.assertEqual(len(installed), 1) + self.assertEqual(installed[0]["agent"], "codex") + + def test_install_skill_default_falls_back_to_claude(self): + codex_vars = { + "CODEX_CI": None, "CODEX_HOME": None, "CODEX_MANAGED_BY_NPM": None, + "CODEX_SANDBOX": None, "CODEX_SANDBOX_NETWORK_DISABLED": None, + "CODEX_THREAD_ID": None, + } + with tempfile.TemporaryDirectory() as home: + result = _run_cli( + "install-skill", + "--json", + env_overrides={"HOME": home, **codex_vars}, + ) + installed = json.loads(result.stdout)["installed"] + self.assertEqual(len(installed), 1) + self.assertEqual(installed[0]["agent"], "claude") + + def test_install_skill_dest_and_agent_are_mutually_exclusive(self): + with tempfile.TemporaryDirectory() as dest: + result = _run_cli("install-skill", "--dest", dest, "--agent", "codex", + check=False) + self.assertNotEqual(result.returncode, 0) + self.assertIn("--dest cannot be combined with --agent", + result.stdout + result.stderr) + + +# --------------------------------------------------------------------------- +# Direct library-level tests (don't need the CLI + subprocess machinery) +# --------------------------------------------------------------------------- + +@unittest.skipUnless(FAUXWARE_PATH.exists(), f"Missing test binary: {FAUXWARE_PATH}") +class TestNewDecLibFeatures(unittest.TestCase): + """Direct tests for list_strings, get_callers, disassemble, xref_from, xref_to_addr.""" + + @classmethod + def setUpClass(cls): + cls.deci = DecompilerInterface.discover( + force_decompiler="angr", + headless=True, + binary_path=str(FAUXWARE_PATH), + ) + + def test_list_strings(self): + strings = self.deci.list_strings() + self.assertGreater(len(strings), 0) + + welcome = self.deci.list_strings(filter=r"Welcome") + self.assertEqual(len(welcome), 1) + self.assertIn("Welcome", welcome[0][1]) + self.assertEqual(self.deci.list_strings(filter=r"zzz_no_match"), []) + + def test_disassemble(self): + addrs = [a for a, f in self.deci.functions.items() if f.name == "main"] + text = self.deci.disassemble(addrs[0]) + self.assertTrue(any(mnem in text for mnem in ("push", "mov", "call"))) + + def test_get_callers_by_addr_name_and_function(self): + addrs_by_name = {f.name: a for a, f in self.deci.functions.items()} + auth_addr = addrs_by_name["authenticate"] + + by_addr = self.deci.get_callers(auth_addr) + by_name = self.deci.get_callers("authenticate") + self.assertEqual({f.addr for f in by_addr}, {f.addr for f in by_name}) + with self.assertRaises(ValueError): + self.deci.get_callers("no_such_function_xyz") + + def test_xrefs_from_returns_callees(self): + """xrefs_from(main) should include authenticate, puts, read, etc.""" + addrs_by_name = {f.name: a for a, f in self.deci.functions.items()} + main_addr = addrs_by_name["main"] + callees = self.deci.xrefs_from(main_addr) + callee_names = {c.name for c in callees if c.name} + self.assertTrue( + callee_names & {"authenticate", "puts", "read", "accepted", "rejected"}, + f"expected a known callee in {callee_names}" + ) + + def test_xrefs_to_addr_on_string(self): + """xrefs_to_addr on the SOSNEAKY constant should point at authenticate.""" + strings = self.deci.list_strings(filter=r"SOSNEAKY") + self.assertTrue(strings, "SOSNEAKY not found in angr strings") + str_addr = strings[0][0] + refs = self.deci.xrefs_to_addr(str_addr) + ref_names = {getattr(r, "name", None) for r in refs} + self.assertIn("authenticate", ref_names, + f"expected 'authenticate' in xrefs_to_addr(SOSNEAKY): {ref_names}") + + def test_read_memory(self): + """read_memory should return the ELF magic at the binary's base.""" + # ELF magic at lifted addr 0 + elf = self.deci.read_memory(0, 4) + self.assertEqual(elf, b"\x7fELF") + + # Welcome string — find via list_strings, then read its bytes. + strings = self.deci.list_strings(filter=r"Welcome") + self.assertTrue(strings, "Welcome string not found") + welcome_addr = strings[0][0] + bytes_ = self.deci.read_memory(welcome_addr, 7) + self.assertEqual(bytes_, b"Welcome") + + # Out-of-range read should return None. + self.assertIsNone(self.deci.read_memory(0xdeadbeef00, 16)) + + # Zero/negative size short-circuit. + self.assertEqual(self.deci.read_memory(0, 0), b"") + self.assertEqual(self.deci.read_memory(0, -5), b"") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_decompilers.py b/tests/test_decompilers.py new file mode 100644 index 00000000..1f5f560f --- /dev/null +++ b/tests/test_decompilers.py @@ -0,0 +1,922 @@ +import json +import logging +import subprocess +import tempfile +import time +import unittest +from pathlib import Path +from collections import defaultdict +import os + +from declib.api import DecompilerInterface +from declib.artifacts import FunctionHeader, StackVariable, Struct, GlobalVariable, Enum, Comment, ArtifactFormat, \ + Decompilation, Function, StructMember, Typedef, Segment +from declib.decompilers import IDA_DECOMPILER, ANGR_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER + +GHIDRA_HEADLESS_PATH = Path(os.environ.get('GHIDRA_INSTALL_DIR', "")) / "support" / "analyzeHeadless" +IDA_HEADLESS_PATH = Path(os.environ.get('IDA_HEADLESS_PATH', "")) + +if os.getenv("TEST_BINARIES_DIR"): + TEST_BINARIES_DIR = Path(os.getenv("TEST_BINARIES_DIR")) +else: + # default assumes its a git repo that is above this one + TEST_BINARIES_DIR = Path(__file__).parent.parent.parent / "bs-artifacts" / "binaries" + +assert TEST_BINARIES_DIR.exists(), f"Test binaries dir {TEST_BINARIES_DIR} does not exist" + + +_l = logging.getLogger(__name__) + +def custom_load_ida(binary_path: Path, extra_args: list[str] | None = None, delete_old_idb=True) -> None: + try: + import idapro + except ImportError: + import ida as idapro + idat_path = Path(idapro.__file__).parent / "bin/idat64" + assert idat_path.exists(), "IDA executable not found, this cannot run" + + # first, assure no idb currently reside there + idb_path = binary_path.with_name(binary_path.name + ".i64") + if delete_old_idb: + if idb_path.exists(): + idb_path.unlink() + + # Command: idat64 -A -B /path/to/binary [extra args] + # construct the command + cmd = [str(idat_path), "-A"] + if extra_args: + cmd.extend(extra_args) + cmd.extend(["-B", str(binary_path)]) + + subprocess.check_call(cmd) + + # verify the idb was created + assert idb_path.exists(), "IDA database was not created" + + +class TestHeadlessInterfaces(unittest.TestCase): + FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" + RENAMED_NAME = "binsync_main" + + def setUp(self): + self.deci = None + + def tearDown(self): + if self.deci is not None: + self.deci.shutdown() + + def test_readme_example(self): + # TODO: add angr + for dec_name in [IDA_DECOMPILER, GHIDRA_DECOMPILER, BINJA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=TEST_BINARIES_DIR / "posix_syscall", + ) + self.deci = deci + changed_addrs = set() + # set it + for addr in deci.functions: + function = deci.functions[addr] + if function.header.type == "void": + function.header.type = "int" + deci.functions[function.addr] = function + changed_addrs.add(function.addr) + + # now check that it really was set for AT LEAST one + # note: this is not a guarantee that it was set for all, type setting can fail + success = 0 + no_voids = not bool(changed_addrs) + for addr in deci.functions: + if addr not in changed_addrs: + continue + + function = deci.functions[addr] + if function.type == "int": + success += 1 + + assert no_voids | success > 0, "Failed to set function type for any functions" + deci.shutdown() + + def test_getting_artifacts(self): + # TODO: add angr + for dec_name in [IDA_DECOMPILER, GHIDRA_DECOMPILER, BINJA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=TEST_BINARIES_DIR / "posix_syscall", + ) + self.deci = deci + + # list all the different artifacts + json_strings = [] + for func in deci.functions.values(): + json_strings.append(func.dumps(fmt=ArtifactFormat.JSON)) + # verify decompilation works + dec_func: Function = deci.functions[func.addr] + assert dec_func is not None + dec_json: dict = json.loads(dec_func.dumps(fmt=ArtifactFormat.JSON)) + assert dec_json.get("header", {}).get("type", None) is not None + + for struct in deci.structs.values(): + json_strings.append(struct.dumps(fmt=ArtifactFormat.JSON)) + for enum in deci.enums.values(): + json_strings.append(enum.dumps(fmt=ArtifactFormat.JSON)) + for gvar in deci.global_vars.values(): + json_strings.append(gvar.dumps(fmt=ArtifactFormat.JSON)) + for comment in deci.comments.values(): + json_strings.append(comment.dumps(fmt=ArtifactFormat.JSON)) + for typedef in deci.typedefs.values(): + json_strings.append(typedef.dumps(fmt=ArtifactFormat.JSON)) + + # validate each one is not corrupted + for json_str in json_strings: + json.loads(json_str) + + deci.shutdown() + + def test_ghidra_types(self): + with tempfile.TemporaryDirectory() as temp_dir: + proj_name = "fdupes_ghidra" + + deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / 'fdupes', + project_location=Path(temp_dir), + project_name=proj_name, + ) + self.deci = deci + + # get decompiled function 'getcrcsignatureuntil' + func = deci.functions[0x1d66] + + # verify that the second argument is just a normal type name, and not a 'typedef ...' + type_name, _ = deci.art_lifter.parse_scoped_type(func.header.args[1].type) + assert type_name == "off_t" + assert "typedef" not in func.header.args[1].type + + # grab the size of a type that is used as an arg in function '_init' + func = deci.functions[0x1120] + arg0 = func.header.args[0] + assert arg0.size == 8, "Unexpected arg size for _init arg0, it is a pointer!" + + def test_ghidra_artifact_dependency_resolving(self): + with tempfile.TemporaryDirectory() as temp_dir: + proj_name = "fdupes_ghidra" + + deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / 'fdupes', + project_location=Path(temp_dir), + project_name=proj_name, + ) + self.deci = deci + light_funcs = {addr: func for addr, func in deci.functions.items()} + md5_process_func = deci.art_lifter.lift_addr(0x1036f4) + + # dont decompile the function to test it is decompiled on demand, however + # a normal use case would be to decompile it first + auth_func = light_funcs[md5_process_func] + initial_deps = deci.get_dependencies(auth_func) + for art in initial_deps: + assert art is not None + assert art.dumps(fmt=ArtifactFormat.JSON) is not None + + assert len(initial_deps) == 4 + # check the deps + struct_cnt = 0 + typedef_cnt = 0 + for dep in initial_deps: + if isinstance(dep, Struct): + struct_cnt += 1 + assert dep.name == "md5_state_s", "Unexpected struct" + assert len(dep.members) == 3, "Unexpected number of members" + elif isinstance(dep, Typedef): + typedef_cnt += 1 + assert dep.name in {"md5_word_t", "md5_state_t", "md5_byte_t"}, "Unexpected typedef" + assert struct_cnt == 1 + assert typedef_cnt == 3 + + # test a case of dependency resolving where we have a func arg with a multi-defined type + # the type in this case is '__off64_t' which is defined in types.h and DWARF + # the correct one to be used is the one from DWARF + func = deci.functions[0x1d66] + deps = deci.get_dependencies(func) + off64t_types = [d for d in deps if isinstance(d, Typedef) and d.name.endswith("__off64_t")] + assert len(off64t_types) == 1 + off64t_type = off64t_types[0] + assert off64t_type.scope == "DWARF" + + + # TODO: right now in headless Ghidra you cant ever set structs to variable types. + # This is a limitation of the headless decompiler, not the API. + # now create two structs that reference each other + # + # struct A { + # struct B *b; + # }; + # + # struct B { + # struct A *a; + # int size; + # }; + # + + #struct_a = Struct( + # name="A", + # members={ + # 0: StructMember(name="b", type_="B*", offset=0, size=8) + # }, + # size=8 + #) + #struct_b = Struct( + # name="B", + # members={ + # 0: StructMember(name="a", type_="A*", offset=0, size=8), + # 1: StructMember(name="size", type_="int", offset=8, size=4) + # }, + # size=12 + #) + + ## first add the structs to the decompiler, empty, so both names can exist + #deci.structs[struct_a.name] = Struct(name=struct_a.name, size=struct_a.size) + #deci.structs[struct_b.name] = Struct(name=struct_b.name, size=struct_b.size) + + ## now add the members to the structs + #deci.structs[struct_a.name] = struct_a + #deci.structs[struct_b.name] = struct_b + + ## now change a stack variable to be of type A + #auth_func = deci.functions[auth_func_addr] + #auth_func.stack_vars[-24].type = "A*" + #deci.functions[auth_func_addr] = auth_func + ## refresh the decompilation + #auth_func = deci.functions[auth_func_addr] + + ## now get the dependencies again + #new_deps = deci.get_dependencies(auth_func) + #assert len(new_deps) == 3 + deci.shutdown() + + # Test another case of dependency resolving where we have a function that looks like this: + # 1. A custom-typed function argument (typedef) + # 2. The typedef points to a struct + # 3. The pointed to struct is empty + with tempfile.TemporaryDirectory() as temp_dir: + deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "posix_syscall", + project_location=Path(temp_dir), + project_name="posix_syscall_ghidra", + ) + self.deci = deci + + start_func = deci.functions[deci.art_lifter.lift_addr(0x100740)] + deps = deci.get_dependencies(start_func) + assert len(deps) == 3 + typdefs = [d for d in deps if isinstance(d, Typedef)] + assert len(typdefs) == 1 + typdef = typdefs[0] + assert typdef.name == "EVP_PKEY_CTX" + type_name, type_scope = self.deci.art_lifter.parse_scoped_type(typdef.type) + assert type_name == "evp_pkey_ctx_st" + structs = [d for d in deps if isinstance(d, Struct)] + assert len(structs) == 1 + struct = structs[0] + assert struct.name == "evp_pkey_ctx_st" + + deci.shutdown() + + def test_fauxware(self): + # TODO: add support for everyone else, but more specifically, IDA! + # there is a problem right now with how function args are set in IDA + for dec_name in [GHIDRA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=self.FAUXWARE_PATH, + ) + self.deci = deci + + func_addr = deci.art_lifter.lift_addr(0x400664) + main = deci.functions[func_addr] + main.name = self.RENAMED_NAME + deci.functions[func_addr] = main + assert deci.functions[func_addr].name == self.RENAMED_NAME + + # + # Structs + # + + func_args = main.header.args + func_args[0].name = "new_name_1" + func_args[0].type = "int" + func_args[0].size = 4 # set manually to avoid resetting the size in the caller + func_args[1].name = "new_name_2" + func_args[1].type = "double" + func_args[1].size = 8 + deci.functions[func_addr] = main + assert deci.functions[func_addr].header.args == func_args + + eh_hdr_struct = deci.structs['eh_frame_hdr'] + eh_hdr_struct.name = "my_struct_name" + eh_hdr_struct.members[0].type = 'char' + eh_hdr_struct.members[1].type = 'char' + deci.structs['eh_frame_hdr'] = eh_hdr_struct + updated = deci.structs[eh_hdr_struct.name] + assert updated.name == eh_hdr_struct.name + assert updated.members[0].type == 'char' + assert updated.members[1].type == 'char' + + # + # Enums + # + + elf_dyn_tag_enum: Enum = deci.enums['ELF::Elf64_DynTag'] + elf_dyn_tag_enum.members['DT_YEET'] = elf_dyn_tag_enum.members['DT_FILTER'] + 1 + deci.enums[elf_dyn_tag_enum.name] = elf_dyn_tag_enum + assert deci.enums[elf_dyn_tag_enum.scoped_name] == elf_dyn_tag_enum + + enum = Enum("my_enum", {"member1": 0, "member2": 1}) + deci.enums[enum.name] = enum + assert deci.enums[enum.name] == enum + + nested_enum = Enum("nested_enum", {"field": 0, "another_field": 2, "third_field": 3}, scope="SomeEnums") + deci.enums[nested_enum.scoped_name] = nested_enum + assert deci.enums[nested_enum.scoped_name] == nested_enum + + # + # Typedefs + # + + # simple typedef + typedef = Typedef("my_typedef", "int") + deci.typedefs[typedef.name] = typedef + assert deci.typedefs[typedef.name] == typedef + + # typedef to a struct + typedef = Typedef("my_eh_frame_hdr", eh_hdr_struct.scoped_name) + deci.typedefs[typedef.name] = typedef + assert deci.typedefs[typedef.name] == typedef + + # typedef to an enum + typedef = Typedef("my_elf_dyn_tag", elf_dyn_tag_enum.scoped_name) + deci.typedefs[typedef.name] = typedef + updated_typedef = deci.typedefs[typedef.name] + assert updated_typedef.name == typedef.name + + # gvar_addr = deci.art_lifter.lift_addr(0x4008e0) + # g1 = deci.global_vars[gvar_addr] + # g1.name = "gvar1" + # deci.global_vars[gvar_addr] = g1 + # assert deci.global_vars[gvar_addr] == g1 + + stack_var = main.stack_vars[-24] + stack_var.name = "named_char_array" + stack_var.type = 'double' + deci.functions[func_addr] = main + assert deci.functions[func_addr].stack_vars[-24] == stack_var + + # + # Test Random APIs + # + + func_size = deci.get_func_size(func_addr) + assert func_size != -1 + + # + # Test Artifact Deletion + # + + eh_hdr_struct = deci.structs['my_struct_name'] + del deci.structs['my_struct_name'] + struct_items = deci.structs.items() + struct_keys = [k for k, v in struct_items] + struct_values = [v for k, v in struct_items] + assert eh_hdr_struct.name not in struct_keys and eh_hdr_struct not in struct_values + + deci.shutdown() + + def test_ghidra_project_loading(self): + with tempfile.TemporaryDirectory() as tmpdir: + proj_name = "posix_syscall_ghidra" + binary_path = TEST_BINARIES_DIR / "posix_syscall" + + start_load = time.time() + deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=binary_path, + project_location=tmpdir, + project_name=proj_name, + ) + slow_load_time = time.time() - start_load + first_funcs = list(deci.functions.values()) + deci.shutdown() + + start_load = time.time() + # load it by just reading the project + deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + project_location=tmpdir, + project_name=proj_name, + analyze=False, + ) + fast_load_time = time.time() - start_load + self.deci = deci + second_funcs = list(deci.functions.values()) + + assert first_funcs == second_funcs + assert slow_load_time > fast_load_time + + def test_angr(self): + deci = DecompilerInterface.discover( + force_decompiler=ANGR_DECOMPILER, + headless=True, + binary_path=self.FAUXWARE_PATH + ) + self.deci = deci + func_addr = deci.art_lifter.lift_addr(0x400664) + main = deci.functions[func_addr] + main.name = self.RENAMED_NAME + deci.functions[func_addr] = main + assert deci.functions[func_addr].name == self.RENAMED_NAME + assert self.RENAMED_NAME in deci.main_instance.project.kb.functions + + # + # Struct support + # + + # test struct creation + new_struct = Struct() + new_struct.name = "my_angr_struct" + new_struct.add_struct_member('char_member', 0, 'char', 1) + new_struct.add_struct_member('int_member', 1, 'int', 4) + deci.structs[new_struct.name] = new_struct + + # verify struct was created + updated = deci.structs[new_struct.name] + assert updated is not None, "Struct was not created" + assert updated.name == new_struct.name + + # verify members are present + assert 0 in updated.members, "First member not found" + assert 1 in updated.members, "Second member not found" + + # test struct listing + struct_items = list(deci.structs.items()) + struct_names = [k for k, v in struct_items] + assert new_struct.name in struct_names, "Struct not found in listing" + + # + # Stack variable type setting + # + + # Get the main function which has stack variables + main_func_addr = deci.art_lifter.lift_addr(0x40071d) + main_func = deci.functions[main_func_addr] + + # Check that we have stack variables + assert len(main_func.stack_vars) > 0, "No stack variables found in main function" + + # Get the first stack variable and change its type to a primitive + first_offset = list(main_func.stack_vars.keys())[0] + original_svar = main_func.stack_vars[first_offset] + + # Set a new type (change to int) + original_svar.type = "int" + deci.functions[main_func_addr] = main_func + + # Verify the type was set by re-fetching the function + updated_func = deci.functions[main_func_addr] + updated_svar = updated_func.stack_vars.get(first_offset) + assert updated_svar is not None, "Stack variable not found after update" + # The type should contain "int" (may be formatted differently by angr) + assert "int" in updated_svar.type.lower() if updated_svar.type else False, \ + f"Stack variable type was not updated to int, got: {updated_svar.type}" + + # + # Stack variable type setting with struct type + # + + # Re-fetch the function to get fresh stack variables + main_func = deci.functions[main_func_addr] + + # Get a stack variable (use the same one or another if available) + svar_offsets = list(main_func.stack_vars.keys()) + struct_test_offset = svar_offsets[0] if len(svar_offsets) == 1 else svar_offsets[1] + struct_test_svar = main_func.stack_vars[struct_test_offset] + + # Set the type to a pointer to our struct + struct_ptr_type = f"struct {new_struct.name} *" + struct_test_svar.type = struct_ptr_type + deci.functions[main_func_addr] = main_func + + # Verify the struct type was set + updated_func = deci.functions[main_func_addr] + updated_svar = updated_func.stack_vars.get(struct_test_offset) + assert updated_svar is not None, "Stack variable not found after struct type update" + assert updated_svar.type is not None, "Stack variable type is None after struct type update" + # The type should contain the struct name + assert new_struct.name in updated_svar.type, \ + f"Stack variable type was not updated to struct pointer, got: {updated_svar.type}" + + # Now test struct deletion (after we're done using it for stack var types) + del deci.structs[new_struct.name] + struct_items = list(deci.structs.items()) + struct_keys = [k for k, v in struct_items] + assert new_struct.name not in struct_keys, "Struct was not deleted" + + deci.shutdown() + + def test_binja(self): + deci = DecompilerInterface.discover( + force_decompiler=BINJA_DECOMPILER, + headless=True, + binary_path=self.FAUXWARE_PATH + ) + func_addr = deci.art_lifter.lift_addr(0x400664) + func_authenticate = deci.functions[func_addr] + func_authenticate.name = self.RENAMED_NAME + + # test renaming a function + deci.functions[func_addr] = func_authenticate + assert deci.functions[func_addr].name == self.RENAMED_NAME + + # test strucr creation + new_struct = Struct() + new_struct.name = "my_new_struct" + new_struct.add_struct_member('char_member', 0, 'char', 1) + new_struct.add_struct_member('int_member', 1, 'int', 4) + deci.structs[new_struct.name] = new_struct + + updated = deci.structs[new_struct.name] + assert updated.name == new_struct.name + assert updated.members[0].type == 'char' + assert updated.members[1].type == 'int' + + # test some typedef stuff + new_typedef = Typedef(name="my_int", type_="int") + deci.typedefs[new_typedef.name] = new_typedef + assert deci.typedefs[new_typedef.name] == new_typedef + + new_typedef = Typedef(name="my_int_t", type_="my_int") + deci.typedefs[new_typedef.name] = new_typedef + assert deci.typedefs[new_typedef.name] == new_typedef + + # test function arg change + func_main = deci.functions[deci.art_lifter.lift_addr(0x40071d)] + func_main.header.args[0].name = "my_arg" + # this arg is normally char** argv, so we can retype to another pointer + new_struct_type = new_struct.name + "*" + func_main.header.args[1].type = new_struct_type + + deci.functions[func_main.addr] = func_main + assert deci.functions[func_main.addr].header.args[0].name == "my_arg" + current_struct_type = deci.functions[func_main.addr].header.args[1].type + current_struct_type = current_struct_type.replace("struct ", "").replace(" ", "") + assert current_struct_type == new_struct_type + + # test struct deletion + del deci.structs[new_struct.name] + struct_items = deci.structs.items() + struct_keys = [k for k, v in struct_items] + struct_values = [v for k, v in struct_items] + assert new_struct.name not in struct_keys and new_struct not in struct_values + + def test_decompile_api(self): + # TODO: put angr back when it is greater than 9.2.165 + for dec_name in [IDA_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=TEST_BINARIES_DIR / "fauxware", + ) + self.deci = deci + main_func_addr = deci.art_lifter.lift_addr(0x40071d) + decompilation = deci.decompile(main_func_addr, map_lines=True) + + assert decompilation is not None, f"Decompilation failed for {dec_name}" + assert decompilation.decompiler == deci.name + assert decompilation.addr == main_func_addr + assert decompilation.text is not None + print_username_line = 'puts("Username: ");' + assert print_username_line in decompilation.text + + line_no = [line.strip() for line in decompilation.text.splitlines()].index(print_username_line) + assert bool(decompilation.line_map) is True + + correct_addr = deci.art_lifter.lift_addr(0x400739) + # TODO: fix the mapping for everyone except IDA... everything is off-by-one in some way + if dec_name == BINJA_DECOMPILER: + line_no -= 1 + if dec_name in [GHIDRA_DECOMPILER, ANGR_DECOMPILER]: + line_no += 1 + + assert line_no in decompilation.line_map + assert correct_addr in decompilation.line_map[line_no] + + self.deci.shutdown() + + def test_fast_function_api(self): + for dec_name in [GHIDRA_DECOMPILER, BINJA_DECOMPILER, ANGR_DECOMPILER, IDA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=TEST_BINARIES_DIR / "fauxware", + ) + self.deci = deci + main_func_addr = deci.art_lifter.lift_addr(0x40071d) + main_func = deci.fast_get_function(main_func_addr) + assert main_func is not None + assert main_func.name is not None + + self.deci.shutdown() + + def test_ghidra_type_scoping(self): + """ + Scopes help distinguish between types with the same name but different namespaces. + In most decompilers, there is no such thing as a type scope, i.e., a type is always scoped to the global + and the name must be unique. + + In Ghidra, however, types can be scoped to a specific category, which is a way to group types together. + Types looks like this: `/CategoryLayer1/CategorLayery2/my_type`. + We need to save types from Ghidra in such a way that: + 1. The scope is preserved when the type is saved, so that other Ghidra instances can load it and use it. + 2. The type can be used without the scope, i.e., the type can be used as if it was a global type. + """ + # first use ghidra to load types from a debug sym binary + ghidra_deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "fauxware", + ) + self.deci = ghidra_deci + + # + # Typedefs + # + + custom_type = Typedef(name="my_int", type_="int", scope="MyCategory") + ghidra_deci.typedefs[custom_type.scoped_name] = custom_type + assert custom_type.scope == "MyCategory" + # use special scoped name to access the scoped type + assert ghidra_deci.typedefs[custom_type.scoped_name] == custom_type + + # define another custom type that has no scope + custom_type_no_scope = Typedef(name="my_int", type_="int") + ghidra_deci.typedefs[custom_type_no_scope.name] = custom_type_no_scope + assert custom_type_no_scope.scope is None + # use the normal name to access the type, since it has no scope (scoped works too) + assert ghidra_deci.typedefs[custom_type_no_scope.name] == custom_type_no_scope + assert ghidra_deci.typedefs[custom_type_no_scope.scoped_name] == custom_type_no_scope + + # make sure both are in the lister + all_typedefs = list(ghidra_deci.typedefs.items()) + assert (custom_type.scoped_name, custom_type) in all_typedefs + assert (custom_type_no_scope.scoped_name, custom_type_no_scope) in all_typedefs + + # + # Structs + # + + custom_type = Struct(name="my_struct", size=0, scope="MyCategory") + ghidra_deci.structs[custom_type.scoped_name] = custom_type + assert custom_type.scope == "MyCategory" + assert ghidra_deci.structs[custom_type.scoped_name] == custom_type + + custom_type_no_scope = Struct(name="my_struct", size=0) + ghidra_deci.structs[custom_type_no_scope.name] = custom_type_no_scope + assert custom_type_no_scope.scope is None + assert ghidra_deci.structs[custom_type_no_scope.name] == custom_type_no_scope + + all_structs = list(ghidra_deci.structs.items()) + assert (custom_type.scoped_name, custom_type) in all_structs + assert (custom_type_no_scope.scoped_name, custom_type_no_scope) in all_structs + + # + # Enums + # + + custom_type = Enum(name="my_enum", members={}, scope="MyCategory") + ghidra_deci.enums[custom_type.scoped_name] = custom_type + assert custom_type.scope == "MyCategory" + assert ghidra_deci.enums[custom_type.scoped_name] == custom_type + custom_type_no_scope = Enum(name="my_enum", members={}) + ghidra_deci.enums[custom_type_no_scope.name] = custom_type_no_scope + assert custom_type_no_scope.scope is None + assert ghidra_deci.enums[custom_type_no_scope.name] == custom_type_no_scope + all_enums = list(ghidra_deci.enums.values()) + assert custom_type in all_enums + assert custom_type_no_scope in all_enums + + # + # Get dependencies check for overlapping name use + # + + custom_type = Typedef(name="my_int", type_="int", scope="MyCategory") + custom_type_no_scope = Typedef(name="my_int", type_="int") + main_func = ghidra_deci.functions[0x71d] + main_func.type = custom_type.scoped_name + main_func.stack_vars[-0x2c].type = custom_type_no_scope.scoped_name + + # refresh the function to be sure it set + ghidra_deci.functions[main_func.addr] = main_func + main_func = ghidra_deci.functions[main_func.addr] + + deps = ghidra_deci.get_dependencies(main_func) + assert len(deps) == 2 + + ghidra_deci.shutdown() + + def test_ghidra_to_ida_transfer(self): + # first use ghidra to load types from a debug sym binary + ghidra_deci = DecompilerInterface.discover( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "debug_symbol", + ) + debug_func = ghidra_deci.functions[0x1249] + debug_types = ghidra_deci.get_dependencies(debug_func) + for debug_type in debug_types: + if isinstance(debug_type, Typedef) and debug_type.name.endswith("_IO_lock_t"): + break + else: + raise RuntimeError("Failed to find the expected typedef") + ghidra_deci.shutdown() + + ida_deci = DecompilerInterface.discover( + force_decompiler=IDA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "debug_symbol_mod_stripped", + ) + # since this type is already native to IDA, even without symbols, we need to change the name + debug_type.name += "_new" + assert debug_type.name not in ida_deci.typedefs + + # now add the type to IDA + ida_deci.typedefs[debug_type.name] = debug_type + + # verify it was added + assert debug_type.name in ida_deci.typedefs + ida_deci.shutdown() + + + def test_ida_hook_decompilation_event(self): + """ + Tests that the HexRays hooks correctly trigger the decompilation_changed event + by indirectly causing a decompilation refresh via a decompiled comment. + """ + ida_deci = DecompilerInterface.discover( + force_decompiler=IDA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "fauxware", + ) + self.deci = ida_deci + + # initialize hooks + ida_deci.start_artifact_watchers() + ida_deci._thread_artifact_callbacks = False + + # register a callback to observe decompilation changes + event_triggered = False + + def on_decompilation_change(decompilation, **kwargs): + nonlocal event_triggered + event_triggered = True + assert decompilation.addr is not None + assert decompilation.text is not None + assert decompilation.decompiler == "ida" + + ida_deci.artifact_change_callbacks[Decompilation].append(on_decompilation_change) + + # TODO: uncomment the below when IDA 9.2 is put in CI so comment setting works headlessly + # trigger a decompilation update indirectly through a decompiled comment + #ida_deci.comments[1821] = Comment(addr=1821, comment="test comment!", func_addr=1821, decompiled=True) + #ida_deci.shutdown() + #assert event_triggered, "Decompilation change event was not triggered" + + def test_ida_segment(self): + """ + Test segment CRUD operations specifically for IDA Pro. + This tests the new segment syncing functionality. + """ + ida_deci = DecompilerInterface.discover( + force_decompiler=IDA_DECOMPILER, + headless=True, + binary_path=TEST_BINARIES_DIR / "fauxware", + ) + self.deci = ida_deci + + # Get initial segments to avoid conflicts + initial_segments = list(ida_deci.segments.keys()) + + # Create a test segment + test_segment_name = "BSSEG" + test_segment = Segment( + name=test_segment_name, + # addresses + start_addr=0x6010c0, + end_addr=0x6010c0 + 0x40, + ) + + # Ensure test segment doesn't exist initially + assert test_segment_name not in ida_deci.segments.keys() + + # Set the segment (this should create it in IDA) + ida_deci.segments[test_segment_name] = test_segment + + # Test 2: Read the segment back + retrieved_segment = ida_deci.segments[test_segment_name] + assert retrieved_segment is not None + assert retrieved_segment.name == test_segment.name + assert retrieved_segment.start_addr == test_segment.start_addr + assert retrieved_segment.end_addr == test_segment.end_addr + # Note: permissions might not be preserved exactly in IDA, so we don't assert on them + + # Test 3: List all segments (should include our new one) + all_segments = ida_deci.segments.keys() + assert test_segment_name in all_segments + assert len(all_segments) == len(initial_segments) + 1 + + # Test 4: Modify the segment + modified_segment = retrieved_segment.copy() + modified_segment.end_addr = modified_segment.end_addr + 0x40 + ida_deci.segments[test_segment_name] = modified_segment + + # Verify modification + updated_segment = ida_deci.segments[test_segment_name] + assert updated_segment.end_addr == modified_segment.end_addr + assert updated_segment.size == test_segment.size + 0x40 + + # Test 5: Test serialization of IDA segments + segment_toml = updated_segment.dumps() + loaded_segment = Segment.loads(segment_toml) + assert loaded_segment == updated_segment + + # Test 6: Delete the segment + del ida_deci.segments[test_segment_name] + + # Verify deletion + final_segments = ida_deci.segments._artifact_lister() + assert test_segment_name not in final_segments + assert len(final_segments) == len(initial_segments) + + # Test 7: Try to access deleted segment (should raise KeyError) + try: + _ = ida_deci.segments[test_segment_name] + assert False, "Expected KeyError when accessing deleted segment" + except KeyError: + pass # Expected behavior + + ida_deci.shutdown() + + def test_firmware_base_addrs(self): + binary_path = TEST_BINARIES_DIR / "i2c_master_read-arduino_mzero.hex" + + # Load an armel binary in the hex format. Because IDA 9 (the version we are stuck on) does not support + # loading armel binaries in headless mode, we first need to load it with idat64 and save it. + custom_load_ida(binary_path, extra_args=["-pARM"], delete_old_idb=True) + + # function for setting at already known lifted addr + my_func = Function(addr=0x214, name="my_func") + + # In https://github.com/binsync/binsync/issues/425 we found that IDA and Binja would report different + # base addresses (0 and non-0) for the same firmware, along with static-looking addresses for this firmware + # Now, this testcase should verify that both IDA and Binja report the same base address (0x4000). + for dec_name in [IDA_DECOMPILER, BINJA_DECOMPILER, GHIDRA_DECOMPILER]: + deci = DecompilerInterface.discover( + force_decompiler=dec_name, + headless=True, + binary_path=binary_path, + language="ARM:LE:32:v7" + ) + self.deci = deci + assert deci.binary_base_addr == 0x4000, f"Unexpected base addr {hex(deci.binary_base_addr)} for {dec_name}" + + # check a few function addresses to be sure, Ghidra does no identify the same functions, so we choose alternates that overla[ + expected_func_addrs = {0x40DC, 0x41d8, 0x4214} + for lower_addr in expected_func_addrs: + if dec_name == GHIDRA_DECOMPILER and lower_addr == 0x40DC: + # skip this one, Ghidra does not identify it for some reason + continue + + lifted_addr = deci.art_lifter.lift_addr(lower_addr) + try: + exists = self.deci.functions[lifted_addr] # verify it exists + except KeyError: + exists = None + assert exists is not None, f"Failed to find function at {hex(lower_addr)} in {dec_name}" + + # attempt to set a function at a known lifted addr + deci.functions[my_func.addr] = my_func + assert deci.functions[my_func.addr].name == my_func.name + + deci.shutdown() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file