From b72c54ae9f125a0ba9b661487a44b621692518aa Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Wed, 22 Apr 2026 10:08:56 -0700 Subject: [PATCH 1/3] ci: add type hints to FTB --- .github/workflows/ci.yml | 3 + .pre-commit-config.yaml | 9 + FTB/AssertionHelper.py | 63 ++--- FTB/ConfigurationFiles.py | 12 +- FTB/CoverageHelper.py | 27 ++- FTB/ProgramConfiguration.py | 19 +- FTB/Running/AutoRunner.py | 85 ++++--- FTB/Running/GDB.py | 31 ++- FTB/Running/PersistentApplication.py | 120 +++++---- FTB/Running/StreamCollector.py | 17 +- FTB/Running/WaitpidMonitor.py | 8 +- FTB/Signatures/CrashInfo.py | 348 +++++++++++++++++---------- FTB/Signatures/CrashSignature.py | 98 ++++---- FTB/Signatures/JSONHelper.py | 39 ++- FTB/Signatures/Matchers.py | 74 +++--- FTB/Signatures/RegisterHelper.py | 16 +- FTB/Signatures/Symptom.py | 118 +++++---- pyproject.toml | 13 +- tox.ini | 7 + 19 files changed, 674 insertions(+), 433 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 698ee451..5dfbb643 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,9 @@ jobs: - name: Install pre-commit run: pipx install pre-commit + - name: Install tox + run: pipx install tox + - name: Run linters run: pre-commit run -a diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 52669f9a..1f7d0506 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,3 +23,12 @@ repos: language: system files: ^server/frontend/.*\.(js|mjs|cjs|vue)$ stages: [ pre-commit ] + - id: mypy + name: mypy + entry: tox -e mypy -- + language: system + require_serial: true + files: ^FTB/ + exclude: (^|/)tests/ + types: [python] + pass_filenames: false diff --git a/FTB/AssertionHelper.py b/FTB/AssertionHelper.py index 75b7d57f..df5e260e 100644 --- a/FTB/AssertionHelper.py +++ b/FTB/AssertionHelper.py @@ -25,7 +25,7 @@ RE_V8_END = re.compile(r"^") -def getAssertion(output): +def getAssertion(output: list[str]) -> str | list[str] | None: """ This 
helper method provides a way to extract and process the different types of assertions from a given buffer. @@ -35,8 +35,8 @@ def getAssertion(output): @type output: list @param output: List of strings to be searched """ - lastLine = None - endRegex = None + lastLine: str | list[str] | None = None + endRegex: re.Pattern[str] | None = None # Use this to ignore the ASan head line in case of an assertion haveFatalAssertion = False @@ -53,6 +53,7 @@ def getAssertion(output): line = re.sub(RE_PID, "", line, count=1) if endRegex is not None: + assert isinstance(lastLine, list) lastLine.append(line) if endRegex.search(line) is not None: endRegex = None @@ -129,7 +130,7 @@ def getAssertion(output): return lastLine -def getAuxiliaryAbortMessage(output): +def getAuxiliaryAbortMessage(output: list[str]) -> str | list[str] | None: """ This helper method provides a way to extract and process additional abort messages or other useful messages produced by helper tools like @@ -139,7 +140,7 @@ def getAuxiliaryAbortMessage(output): @type output: list @param output: List of strings to be searched """ - lastLine = None + lastLine: str | list[str] | None = None needASanRW = False needTSanRW = False @@ -161,6 +162,7 @@ def getAuxiliaryAbortMessage(output): lastLine = line.strip() needASanRW = True elif needASanRW and ("READ of size" in line or "WRITE of size" in line): + assert isinstance(lastLine, str) lastLine = [lastLine] lastLine.append(line) needASanRW = False @@ -177,6 +179,7 @@ def getAuxiliaryAbortMessage(output): elif needTSanRW and re.match( r"\s*(?:Previous )?(?:[Aa]tomic )?(?:[Rr]ead|[Ww]rite) of size", line ): + assert isinstance(lastLine, list) lastLine.append(line.strip()) elif "glibc detected" in line: # Aborts caused by glibc runtime error detection @@ -188,7 +191,7 @@ def getAuxiliaryAbortMessage(output): return lastLine -def getSanitizedAssertionPattern(msgs): +def getSanitizedAssertionPattern(msgs: str | list[str]) -> str | list[str]: """ This method provides a way 
to strip out unwanted dynamic information from assertions and replace it with pattern matching elements, e.g. @@ -211,7 +214,7 @@ def getSanitizedAssertionPattern(msgs): for msg in msgs: # remember the position of all backslashes in the input - bsPositions = [] + bsPositions: list[int] = [] for chunk in msg.split("\\"): if not bsPositions: bsPositions.append(len(chunk)) @@ -231,31 +234,28 @@ def getSanitizedAssertionPattern(msgs): idx += len(chunk) + 1 bsPositions = [bs + 1 if bs > idx else bs for bs in bsPositions] - # Each entry is (match_regex, replacement_text). For most patterns the - # two are identical, but path patterns match with a boundary-restricted - # class while writing a clean `.+/` into the sanitized output. - replacementPatterns = [] + # Each entry is (match_regex, replacement_text). A replacement_text of + # None means "use the match_regex as the replacement" + replacementPatterns: list[tuple[str, str | None]] = [] # Specific TSan patterns - replacementPatterns.append("(Previous )?[Aa]tomic [Rr]ead of size") - replacementPatterns.append("(Previous )?[Aa]tomic [Ww]rite of size") - replacementPatterns.append("(Previous )?[Rr]ead of size") - replacementPatterns.append("(Previous )?[Ww]rite of size") + replacementPatterns.append(("(Previous )?[Aa]tomic [Rr]ead of size", None)) + replacementPatterns.append(("(Previous )?[Aa]tomic [Ww]rite of size", None)) + replacementPatterns.append(("(Previous )?[Rr]ead of size", None)) + replacementPatterns.append(("(Previous )?[Ww]rite of size", None)) # We avoid the use of parentheses here because they would be double-escaped - replacementPatterns.append("thread T[0-9]+( .+mutexes: .+)?:") - replacementPatterns.append("by main thread( .+mutexes: .+)?:") + replacementPatterns.append(("thread T[0-9]+( .+mutexes: .+)?:", None)) + replacementPatterns.append(("by main thread( .+mutexes: .+)?:", None)) # Replace everything that looks like a memory address - replacementPatterns.append("0x[0-9a-fA-F]+") + 
replacementPatterns.append(("0x[0-9a-fA-F]+", None)) # Strip line numbers as they can easily change across versions - replacementPatterns.append("(:[0-9]+)+") - replacementPatterns.append(", line [0-9]+") + replacementPatterns.append(("(:[0-9]+)+", None)) + replacementPatterns.append((", line [0-9]+", None)) # Replace rust thread #s - replacementPatterns.append("Thread#[0-9]+' panicked") - - replacementPatterns = [(p, p) for p in replacementPatterns] + replacementPatterns.append(("Thread#[0-9]+' panicked", None)) # Strip full paths. Match using a boundary-restricted class so we don't # greedily consume text preceding the path, but emit `.+/` as the @@ -268,16 +268,21 @@ def getSanitizedAssertionPattern(msgs): # spaces, quotes and comma are the only things used in the assertions # we support so far. However, we don't want to group these characters # into a regex so avoid cluttering the signature too much. - for prefix in (" ", "'", '"', ","): - replacementPatterns.append((prefix + pathMatch, prefix + pathReplace)) + replacementPatterns.extend( + (prefix + pathMatch, prefix + pathReplace) + for prefix in (" ", "'", '"', ",") + ) # Replace larger numbers, assuming that 1-digit numbers are likely # some constant that doesn't need sanitizing. - replacementPatterns.append(("[0-9]{2,}", "[0-9]{2,}")) + replacementPatterns.append(("[0-9]{2,}", None)) - for matchPattern, replacementPattern in replacementPatterns: + for matchPattern, replacementText in replacementPatterns: + replacementPattern = ( + matchPattern if replacementText is None else replacementText + ) - def _handleMatch(match): + def _handleMatch(match: re.Match[str]) -> str: start = match.start(0) end = match.end(0) lengthDiff = len(replacementPattern) - len(match.group(0)) @@ -324,7 +329,7 @@ def _handleMatch(match): return sanitizedMsgs -def escapePattern(msg): +def escapePattern(msg: str) -> str: """ This method escapes regular expression characters in the string. 
And no, this is not re.escape, which would escape many more characters. diff --git a/FTB/ConfigurationFiles.py b/FTB/ConfigurationFiles.py index e1f7ca18..c6f7120b 100755 --- a/FTB/ConfigurationFiles.py +++ b/FTB/ConfigurationFiles.py @@ -18,15 +18,15 @@ class ConfigurationFiles: - def __init__(self, configFiles): - self.mainConfig = {} - self.metadataConfig = {} + def __init__(self, configFiles: list[str] | None) -> None: + self.mainConfig: dict[str, str] = {} + self.metadataConfig: dict[str, str] = {} if configFiles: self.parser = configparser.ConfigParser() # Make sure keys are kept case-sensitive - self.parser.optionxform = str + self.parser.optionxform = str # type: ignore[method-assign,assignment] self.parser.read(configFiles) self.mainConfig = self.getSectionMap("Main") @@ -46,8 +46,8 @@ def __init__(self, configFiles): file=sys.stderr, ) - def getSectionMap(self, section): - ret = {} + def getSectionMap(self, section: str) -> dict[str, str]: + ret: dict[str, str] = {} try: options = self.parser.options(section) except configparser.NoSectionError: diff --git a/FTB/CoverageHelper.py b/FTB/CoverageHelper.py index 2bc54778..435b3547 100644 --- a/FTB/CoverageHelper.py +++ b/FTB/CoverageHelper.py @@ -13,9 +13,10 @@ """ import re +from typing import Any -def merge_coverage_data(r, s): +def merge_coverage_data(r: dict[str, Any], s: dict[str, Any]) -> dict[str, int]: # These variables are mainly for debugging purposes. We count the number # of warnings we encounter during merging, which are mostly due to # bugs in GCOV. 
These statistics can be included in the report description @@ -26,7 +27,7 @@ def merge_coverage_data(r, s): "coverable_mismatch_count": 0, } - def merge_recursive(r, s): + def merge_recursive(r: dict[str, Any], s: dict[str, Any]) -> None: assert r["name"] == s["name"] if "children" in s: @@ -110,7 +111,7 @@ def merge_recursive(r, s): return stats -def calculate_summary_fields(node, name=None): +def calculate_summary_fields(node: dict[str, Any], name: str | None = None) -> None: node["name"] = name node["linesTotal"] = 0 node["linesCovered"] = 0 @@ -145,7 +146,9 @@ def calculate_summary_fields(node, name=None): node["coveragePercent"] = 0.0 -def apply_include_exclude_directives(node, directives): +def apply_include_exclude_directives( + node: dict[str, Any], directives: list[str] +) -> None: """ Applies the given include and exclude directives to the given nodeself. Directives either start with a + or a - for include or exclude, followed @@ -174,7 +177,7 @@ def apply_include_exclude_directives(node, directives): # # ** are left as a string # patterns are converted to regex and compile - directives_new = [ + directives_new: list[tuple[str, list[str | re.Pattern[str]]]] = [ ("+", ["**"]) ] # start with an implicit +:** so we don't have to handle the empty case for directive in directives: @@ -188,7 +191,7 @@ def apply_include_exclude_directives(node, directives): what, pattern = directive.split(":", 1) if what not in "+-": raise RuntimeError("Unexpected directive prefix: " + what) - parts = [] + parts: list[str | re.Pattern[str]] = [] for part in pattern.split("/"): if part == "**": parts.append(part) @@ -206,10 +209,12 @@ def apply_include_exclude_directives(node, directives): parts.append(re.compile(part)) directives_new.append((what, parts)) - def _is_dir(node): + def _is_dir(node: dict[str, Any]) -> bool: return "children" in node - def __apply_include_exclude_directives(node, directives): + def __apply_include_exclude_directives( + node: dict[str, Any], 
directives: list[tuple[str, Any]] + ) -> None: if not _is_dir(node): return @@ -332,7 +337,7 @@ def __apply_include_exclude_directives(node, directives): __apply_include_exclude_directives(node, directives_new) -def get_flattened_names(node, prefix=""): +def get_flattened_names(node: dict[str, str | None], prefix: str = "") -> set[str]: """ Returns a list of flattened paths (files and directories) of the given node. @@ -349,7 +354,9 @@ def get_flattened_names(node, prefix=""): @rtype: list(str) """ - def __get_flattened_names(node, prefix, result): + def __get_flattened_names( + node: dict[str, Any], prefix: str, result: set[str] + ) -> set[str]: current_name = node["name"] if current_name is None: new_prefix = "" diff --git a/FTB/ProgramConfiguration.py b/FTB/ProgramConfiguration.py index 42dda91d..a8c824a6 100644 --- a/FTB/ProgramConfiguration.py +++ b/FTB/ProgramConfiguration.py @@ -23,8 +23,15 @@ class ProgramConfiguration: def __init__( - self, product, platform, os, version=None, env=None, args=None, metadata=None - ): + self, + product: str, + platform: str, + os: str, + version: str | None = None, + env: dict[str, str] | None = None, + args: list[str] | None = None, + metadata: dict[str, str] | None = None, + ) -> None: """ @type product: string @param product: The name of the product/program/branch tested @@ -56,7 +63,7 @@ def __init__( self.metadata = metadata @staticmethod - def fromBinary(binaryPath): + def fromBinary(binaryPath: str) -> "ProgramConfiguration | None": binaryConfig = f"{binaryPath}.fuzzmanagerconf" if not os.path.exists(binaryConfig): print( @@ -87,7 +94,7 @@ def fromBinary(binaryPath): metadata=config.metadataConfig, ) - def addEnvironmentVariables(self, env): + def addEnvironmentVariables(self, env: dict[str, str]) -> None: """ Add (additional) environment variable definitions. Existing definitions will be overwritten if they are redefined in the given environment. 
@@ -98,7 +105,7 @@ def addEnvironmentVariables(self, env): assert isinstance(env, dict) self.env.update(env) - def addProgramArguments(self, args): + def addProgramArguments(self, args: list[str]) -> None: """ Add (additional) program arguments. @@ -108,7 +115,7 @@ def addProgramArguments(self, args): assert isinstance(args, list) self.args.extend(args) - def addMetadata(self, metadata): + def addMetadata(self, metadata: dict[str, str]) -> None: """ Add (additional) metadata definitions. Existing definitions will be overwritten if they are redefined in the given metadata. diff --git a/FTB/Running/AutoRunner.py b/FTB/Running/AutoRunner.py index b2660677..7209d412 100644 --- a/FTB/Running/AutoRunner.py +++ b/FTB/Running/AutoRunner.py @@ -32,13 +32,17 @@ class AutoRunner(metaclass=ABCMeta): for running the given program and obtaining crash information. """ - def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: str | list[str] | None = None, + ) -> None: self.binary = binary self.cwd = cwd - self.stdin = stdin - - if self.stdin and isinstance(self.stdin, list): - self.stdin = "\n".join(self.stdin) + self.stdin = "\n".join(stdin) if isinstance(stdin, list) else stdin # Certain debuggers like GDB can run into problems when certain # environment variables are missing. 
Hence we copy the system environment @@ -51,28 +55,32 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): if "LD_LIBRARY_PATH" not in self.env: self.env["LD_LIBRARY_PATH"] = os.path.dirname(binary) - self.args = args - if self.args is None: - self.args = [] + self.args = args or [] assert isinstance(self.env, dict) assert isinstance(self.args, list) # The command that we will run for obtaining crash information - self.cmdArgs = [] + self.cmdArgs: list[str] = [] # These will hold our results from running - self.stdout = None - self.stderr = None - self.auxCrashData = None + self.stdout: str | None = None + self.stderr: str | None = None + self.auxCrashData: str | None = None - def getCrashInfo(self, configuration): + def getCrashInfo(self, configuration: ProgramConfiguration) -> CrashInfo: return CrashInfo.fromRawCrashData( self.stdout, self.stderr, configuration, self.auxCrashData ) @staticmethod - def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): + def fromBinaryArgs( + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: str | list[str] | None = None, + ) -> "AutoRunner": process = subprocess.Popen( ["nm", "-g", binary], stdin=subprocess.PIPE, @@ -82,8 +90,8 @@ def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): env=env, ) - (stdout, _) = process.communicate() - stdout = stdout.decode("utf-8", errors="ignore") + stdout_bytes, _ = process.communicate() + stdout = stdout_bytes.decode("utf-8", errors="ignore") force_gdb = bool(os.environ.get("FTB_FORCE_GDB", False)) @@ -91,13 +99,21 @@ def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): stdout.find(" __asan_init") >= 0 or stdout.find("__ubsan_default_options") >= 0 ): - return ASanRunner(binary, args, env, cwd, stdin) + return ASanRunner(binary, args=args, env=env, cwd=cwd, stdin=stdin) - return GDBRunner(binary, args, env, cwd, stdin) + return GDBRunner(binary, args=args, 
env=env, cwd=cwd, stdin=stdin) class GDBRunner(AutoRunner): - def __init__(self, binary, args=None, env=None, cwd=None, core=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + core: str | None = None, + stdin: str | list[str] | None = None, + ) -> None: AutoRunner.__init__(self, binary, args, env, cwd, stdin) # This can be used to force GDBRunner to first generate a core and then @@ -146,7 +162,7 @@ def __init__(self, binary, args=None, env=None, cwd=None, core=None, stdin=None) else: self.cmdArgs.extend(self.args) - def run(self): + def run(self) -> bool: if self.force_core: plainCmdArgs = [self.binary] plainCmdArgs.extend(self.args) @@ -162,7 +178,9 @@ def run(self): core = f"core.{process.pid}" - (plainStdout, plainStderr) = process.communicate(input=self.stdin) + plainStdout, plainStderr = process.communicate( + input=self.stdin.encode() if self.stdin else None + ) if os.path.isfile(core): self.cmdArgs.append(core) @@ -185,7 +203,9 @@ def run(self): env=self.env, ) - (stdout, stderr) = process.communicate(input=self.stdin) + stdout, stderr = process.communicate( + input=self.stdin.encode() if self.stdin else None + ) self.stdout = stdout.decode("utf-8", errors="ignore") self.stderr = stderr.decode("utf-8", errors="ignore") @@ -217,7 +237,14 @@ def run(self): class ASanRunner(AutoRunner): - def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: str | list[str] | None = None, + ) -> None: AutoRunner.__init__(self, binary, args, env, cwd, stdin) self.cmdArgs.append(self.binary) @@ -231,9 +258,10 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): os.path.dirname(binary), "llvm-symbolizer" ) if not os.path.isfile(self.env["ASAN_SYMBOLIZER_PATH"]): - self.env["ASAN_SYMBOLIZER_PATH"] = 
shutil.which("llvm-symbolizer") - if not self.env["ASAN_SYMBOLIZER_PATH"]: + llvm_symbolizer = shutil.which("llvm-symbolizer") + if not llvm_symbolizer: raise RuntimeError("Unable to locate llvm-symbolizer") + self.env["ASAN_SYMBOLIZER_PATH"] = llvm_symbolizer if not os.path.isfile(self.env["ASAN_SYMBOLIZER_PATH"]): raise RuntimeError( @@ -263,7 +291,7 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): # for bucketing. This is helpful when assertions are hit in debug builds self.env["ASAN_OPTIONS"] = "allocator_may_return_null=1:handle_abort=1" - def run(self): + def run(self) -> bool: tmpd = Path(mkdtemp(prefix="fm-autorun-")) try: env = self.env.copy() @@ -275,9 +303,10 @@ def run(self): # create a ProgramConfiguration just to create the temporary CrashInfo pc = ProgramConfiguration.fromBinary(self.binary) + assert pc is not None process = subprocess.run( self.cmdArgs, - stdin=self.stdin, + input=self.stdin, capture_output=True, text=True, cwd=self.cwd, @@ -286,8 +315,8 @@ def run(self): self.stdout = process.stdout self.stderr = process.stderr + self.auxCrashData = None first = True - self.auxCrashData = [] for crash in tmpd.iterdir(): self.auxCrashData = crash.read_text() if not first: diff --git a/FTB/Running/GDB.py b/FTB/Running/GDB.py index 02214775..814d1e94 100644 --- a/FTB/Running/GDB.py +++ b/FTB/Running/GDB.py @@ -12,34 +12,39 @@ @contact: choller@mozilla.com """ +from typing import TYPE_CHECKING -def is64bit(): - return not str(gdb.parse_and_eval("$rax")) == "void" # noqa @UndefinedVariable +if TYPE_CHECKING: + import gdb # noqa: TC004 -def isARM(): - return not str(gdb.parse_and_eval("$r0")) == "void" # noqa @UndefinedVariable +def is64bit() -> bool: + return str(gdb.parse_and_eval("$rax")) != "void" -def isARM64(): - return not str(gdb.parse_and_eval("$x0")) == "void" # noqa @UndefinedVariable +def isARM() -> bool: + return str(gdb.parse_and_eval("$r0")) != "void" -def regAsHexStr(reg): +def isARM64() -> bool: + return 
str(gdb.parse_and_eval("$x0")) != "void" + + +def regAsHexStr(reg: str) -> str: mask = 0xFFFFFFFFFFFFFFFF if is64bit() else 0xFFFFFFFF - val = int(str(gdb.parse_and_eval("$" + reg)), 0) & mask # noqa @UndefinedVariable + val = int(str(gdb.parse_and_eval("$" + reg)), 0) & mask return f"0x{val:x}" -def regAsIntStr(reg): - return str(int(str(gdb.parse_and_eval("$" + reg)), 0)) # noqa @UndefinedVariable +def regAsIntStr(reg: str) -> str: + return str(int(str(gdb.parse_and_eval("$" + reg)), 0)) -def regAsRaw(reg): - return str(gdb.parse_and_eval("$" + reg)) # noqa @UndefinedVariable +def regAsRaw(reg: str) -> str: + return str(gdb.parse_and_eval("$" + reg)) -def printImportantRegisters(): +def printImportantRegisters() -> None: if is64bit(): regs = [ "rax", diff --git a/FTB/Running/PersistentApplication.py b/FTB/Running/PersistentApplication.py index 6a2f2f68..430c507a 100644 --- a/FTB/Running/PersistentApplication.py +++ b/FTB/Running/PersistentApplication.py @@ -24,16 +24,20 @@ import subprocess import time from abc import ABCMeta +from enum import Enum, IntEnum, auto from FTB.Running.StreamCollector import StreamCollector from FTB.Running.WaitpidMonitor import WaitpidMonitor -class ApplicationStatus: - OK, ERROR, TIMEDOUT, CRASHED = range(1, 5) +class ApplicationStatus(IntEnum): + OK = 1 + ERROR = 2 + TIMEDOUT = 3 + CRASHED = 4 -class PersistentMode: +class PersistentMode(Enum): """ Persistent fuzzing mode - determines how the program synchronizes the execution of multiple testcases in one process. 
@@ -59,7 +63,9 @@ class PersistentMode: if no synchronization via stdin is possible """ - NONE, SPFP, SIGSTOP = range(1, 4) + NONE = auto() + SPFP = auto() + SIGSTOP = auto() class PersistentApplication(metaclass=ABCMeta): @@ -69,14 +75,14 @@ class PersistentApplication(metaclass=ABCMeta): def __init__( self, - binary, - args=None, - env=None, - cwd=None, - persistentMode=PersistentMode.NONE, - processingTimeout=10, - inputFile=None, - ): + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + persistentMode: PersistentMode = PersistentMode.NONE, + processingTimeout: int = 10, + inputFile: str | None = None, + ) -> None: self.binary = binary self.cwd = cwd @@ -86,9 +92,7 @@ def __init__( for envkey in env: self.env[envkey] = env[envkey] - self.args = args - if self.args is None: - self.args = [] + self.args = args or [] assert isinstance(self.env, dict) assert isinstance(self.args, list) @@ -103,10 +107,10 @@ def __init__( self.inputFile = inputFile # Various variables holding information about the program - self.process = None - self.stdout = None - self.stderr = None - self.testLog = None + self.process: subprocess.Popen[str] | None = None + self.stdout: list[str] | None = None + self.stderr: list[str] | None = None + self.testLog: list[str] | None = None # This string will be used to prefix spfp inputs and can be set # to e.g. 
a comment string prefix for the target input ('//') @@ -115,19 +119,20 @@ def __init__( self.spfpPrefix = "" self.spfpSuffix = "" # To support - def start(self, test=None): + def start(self, test: str | None = None) -> int | None: pass - def stop(self): + def stop(self) -> None: pass - def runTest(self, test): + def runTest(self, test: str) -> int | None: pass - def status(self): + def status(self) -> int | None: pass - def _crashed(self): + def _crashed(self) -> bool: + assert self.process is not None if self.process.returncode < 0: crashSignals = [ # POSIX.1-1990 signals @@ -151,32 +156,35 @@ def _crashed(self): class SimplePersistentApplication(PersistentApplication): def __init__( self, - binary, - args=None, - env=None, - cwd=None, - persistentMode=PersistentMode.NONE, - processingTimeout=10, - inputFile=None, - ): + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + persistentMode: PersistentMode = PersistentMode.NONE, + processingTimeout: int = 10, + inputFile: str | None = None, + ) -> None: PersistentApplication.__init__( self, binary, args, env, cwd, persistentMode, processingTimeout, inputFile ) # Used to store the second return value if waitpid, which has the real exit code - self.childExit = None + self.childExit: int | None = None # These will hold our StreamCollectors for stdout/err - self.outCollector = None - self.errCollector = None + self.outCollector: StreamCollector | None = None + self.errCollector: StreamCollector | None = None - def _write_log_test(self, test): + def _write_log_test(self, test: str) -> None: + assert self.testLog is not None self.testLog.append(test) if self.inputFile: with open(self.inputFile, "w") as inputFileFd: inputFileFd.write(test) elif self.persistentMode == PersistentMode.SPFP: + assert self.process is not None + assert self.process.stdin is not None # This won't work with pure binary data, but SPFP mode isn't suitable for # that in general print(test, 
file=self.process.stdin) @@ -185,16 +193,22 @@ def _write_log_test(self, test): file=self.process.stdin, ) elif self.persistentMode == PersistentMode.SIGSTOP: + assert self.process is not None + assert self.process.stdin is not None # Shameless copycat, oh hai lcamtuf ;) - os.ftruncate(self.process.stdin, len(test)) - os.lseek(self.process.stdin, 0, os.SEEK_SET) + + os.ftruncate(self.process.stdin.fileno(), len(test)) + os.lseek(self.process.stdin.fileno(), 0, os.SEEK_SET) self.process.stdin.write(test) self.process.stdin.flush() else: + assert self.process is not None + assert self.process.stdin is not None self.process.stdin.write(test) self.process.stdin.close() - def _wait_child_stopped(self): + def _wait_child_stopped(self) -> bool: + assert self.process is not None monitor = WaitpidMonitor(self.process.pid, os.WUNTRACED) monitor.start() monitor.join(self.processingTimeout) @@ -209,7 +223,7 @@ def _wait_child_stopped(self): return True - def start(self, test=None): + def start(self, test: str | None = None) -> int | None: assert self.process is None or self.process.poll() is not None # Reset the test log @@ -239,11 +253,13 @@ def start(self, test=None): # This queue is used to queue up responses that should be directly processed # by this class rather than being logged. 
- self.responseQueue = queue.Queue() + self.responseQueue: queue.Queue[str] = queue.Queue() + assert self.process.stdout is not None self.outCollector = StreamCollector( self.process.stdout, self.responseQueue, logResponses=False, maxBacklog=256 ) + assert self.process.stderr is not None self.errCollector = StreamCollector( self.process.stderr, self.responseQueue, logResponses=False, maxBacklog=256 ) @@ -292,10 +308,12 @@ def start(self, test=None): ) else: if not self.inputFile: + assert test is not None self._write_log_test(test) # Assume PersistentMode.NONE and expect the process to exit now - (maxSleepTime, pollInterval) = (self.processingTimeout, 0.2) + maxSleepTime = float(self.processingTimeout) + pollInterval = 0.2 while self.process.poll() is None and maxSleepTime > 0: maxSleepTime -= pollInterval time.sleep(pollInterval) @@ -314,25 +332,27 @@ def start(self, test=None): # Also terminates the process in case of a timeout. self.stop() - return ret + return int(ret) return None - def stop(self): + def stop(self) -> None: self._terminateProcess() # Ensure we leave no dangling threads when stopping if self.outCollector is not None: # errCollector is expected to be set when outCollector is self.outCollector.join() + assert self.errCollector is not None self.errCollector.join() # Make the output available self.stdout = self.outCollector.output self.stderr = self.errCollector.output - def runTest(self, test): + def runTest(self, test: str) -> int | None: if self.process is None or self.process.poll() is not None: self.start() + assert self.process is not None # Write test data and also log it self._write_log_test(test) @@ -374,6 +394,8 @@ def runTest(self, test): ) # Update stdout/err available for the last run + assert self.outCollector is not None + assert self.errCollector is not None self.stdout = self.outCollector.output self.stderr = self.errCollector.output @@ -396,10 +418,13 @@ def runTest(self, test): return ApplicationStatus.TIMEDOUT # Update stdout/err 
available for the last run + assert self.outCollector is not None + assert self.errCollector is not None self.stdout = self.outCollector.output self.stderr = self.errCollector.output if self.process.poll() is not None: + assert self.childExit is not None exitCode = self.childExit >> 8 signalNum = self.childExit & 0xFF @@ -415,14 +440,15 @@ def runTest(self, test): return ApplicationStatus.OK return None - def _terminateProcess(self): + def _terminateProcess(self) -> None: if self.process and self.process.poll() is None: # Try to terminate the process gracefully first self.process.terminate() # Emulate a wait() with timeout. Because wait() having # a timeout would be way too easy, wouldn't it? -.- - (maxSleepTime, pollInterval) = (3, 0.2) + maxSleepTime = 3.0 + pollInterval = 0.2 while self.process.poll() is None and maxSleepTime > 0: maxSleepTime -= pollInterval time.sleep(pollInterval) diff --git a/FTB/Running/StreamCollector.py b/FTB/Running/StreamCollector.py index bbd04265..0fec0b3f 100644 --- a/FTB/Running/StreamCollector.py +++ b/FTB/Running/StreamCollector.py @@ -14,10 +14,17 @@ import queue import threading +from typing import IO class StreamCollector(threading.Thread): - def __init__(self, fd, responseQueue, logResponses=False, maxBacklog=None): + def __init__( + self, + fd: IO[str], + responseQueue: queue.Queue[str], + logResponses: bool = False, + maxBacklog: int | None = None, + ) -> None: assert callable(fd.readline) assert isinstance(responseQueue, queue.Queue) @@ -25,12 +32,12 @@ def __init__(self, fd, responseQueue, logResponses=False, maxBacklog=None): self.fd = fd self.queue = responseQueue - self.output = [] - self.responsePrefixes = [] + self.output: list[str] = [] + self.responsePrefixes: list[str] = [] self.logResponses = logResponses self.maxBacklog = maxBacklog - def run(self): + def run(self) -> None: while True: line = self.fd.readline(4096) @@ -54,5 +61,5 @@ def run(self): self.fd.close() - def addResponsePrefix(self, prefix): + def 
addResponsePrefix(self, prefix: str) -> None: self.responsePrefixes.append(prefix) diff --git a/FTB/Running/WaitpidMonitor.py b/FTB/Running/WaitpidMonitor.py index b92ac4bc..8ca7f4c6 100644 --- a/FTB/Running/WaitpidMonitor.py +++ b/FTB/Running/WaitpidMonitor.py @@ -18,15 +18,15 @@ class WaitpidMonitor(threading.Thread): - def __init__(self, pid, options): + def __init__(self, pid: int, options: int) -> None: threading.Thread.__init__(self) self.pid = pid self.options = options - self.childPid = None - self.childExit = None + self.childPid: int | None = None + self.childExit: int | None = None - def run(self): + def run(self) -> None: while not self.childPid: (self.childPid, self.childExit) = os.waitpid(self.pid, self.options) diff --git a/FTB/Signatures/CrashInfo.py b/FTB/Signatures/CrashInfo.py index 456245a7..291819c7 100644 --- a/FTB/Signatures/CrashInfo.py +++ b/FTB/Signatures/CrashInfo.py @@ -21,8 +21,10 @@ import sys import unicodedata from abc import ABCMeta +from collections.abc import Callable, Mapping from contextlib import suppress from functools import wraps +from typing import Any from FTB import AssertionHelper from FTB.ProgramConfiguration import ProgramConfiguration @@ -30,7 +32,7 @@ from FTB.Signatures.CrashSignature import CrashSignature -def unicode_escape_result(func): +def unicode_escape_result(func: Callable[..., str]) -> Callable[..., str]: r"""Decorator to escape control and special block unicode characters in a function returning untrusted str values. 
@@ -38,25 +40,24 @@ def unicode_escape_result(func): """ class unicode_cc_map: - def __getitem__(self, char): + def __getitem__(self, char: int) -> str: if unicodedata.category(chr(char)) in {"Cc", "So"}: return f"\\u{{{char:x}}}" raise LookupError() @wraps(func) - def wrapped(*args, **kwds): + def wrapped(*args: Any, **kwds: Any) -> str: result = func(*args, **kwds) return result.translate(unicode_cc_map()) return wrapped -def _is_unfinished(symbol, operators): - start, end = operators - return bool(symbol.count(start) > symbol.count(end)) +def _is_unfinished(symbol: str, operators: str) -> bool: + return symbol.count(operators[0]) > symbol.count(operators[1]) -def uint32(val): +def uint32(val: int) -> int: """Force `val` into unsigned 32-bit range. Note that the input is returned as an int, therefore @@ -76,7 +77,7 @@ def uint32(val): return val & 0xFFFFFFFF -def int32(val): +def int32(val: int) -> int: """Force `val` into signed 32-bit range. Note that the input is returned as an int, therefore @@ -99,7 +100,7 @@ def int32(val): return val -def uint64(val): +def uint64(val: int) -> int: """Force `val` into unsigned 64-bit range. Note that the input is returned as an int, therefore @@ -119,7 +120,7 @@ def uint64(val): return val & 0xFFFFFFFFFFFFFFFF -def int64(val): +def int64(val: int) -> int: """Force `val` into signed 64-bit range. Note that the input is returned as an int, therefore @@ -145,7 +146,7 @@ def int64(val): class TraceParsingError(RuntimeError): __slots__ = ("line_no",) - def __init__(self, *args, **kwds): + def __init__(self, *args: Any, **kwds: Any) -> None: self.line_no = kwds.pop("line_no") super().__init__(*args, **kwds) @@ -156,31 +157,31 @@ class CrashInfo(metaclass=ABCMeta): It also supports generating a CrashSignature based on the stored information. 
""" - def __init__(self): + def __init__(self) -> None: # Store the raw data - self.rawStdout = [] - self.rawStderr = [] - self.rawCrashData = [] + self.rawStdout: list[str] = [] + self.rawStderr: list[str] = [] + self.rawCrashData: list[str] = [] # Store processed data - self.backtrace = [] - self.registers = {} - self.crashAddress = None - self.crashInstruction = None + self.backtrace: list[str] = [] + self.registers: dict[str, int] = {} + self.crashAddress: int | None = None + self.crashInstruction: str | None = None # Store configuration data (platform, product, os, etc.) - self.configuration = None + self.configuration: ProgramConfiguration | None = None # This is an optional testcase that is not stored with the crashInfo but # can be "attached" before matching signatures that might require the # testcase. - self.testcase = None + self.testcase: str | None = None # This can be used to record failures during signature creation - self.failureReason = None + self.failureReason: str | None = None - def __str__(self): - buf = [] + def __str__(self) -> str: + buf: list[str] = [] buf.append("Crash trace:") buf.append("") for idx, frame in enumerate(self.backtrace): @@ -201,7 +202,7 @@ def __str__(self): return "\n".join(buf) - def toCacheObject(self): + def toCacheObject(self) -> dict[str, Any]: """ Create a cache object for restoring the class instance later on without parsing the crash data again. 
This object includes all class fields except for the @@ -210,7 +211,7 @@ def toCacheObject(self): @rtype: dict @return: Dictionary containing expensive class fields """ - cacheObject = {} + cacheObject: dict[str, Any] = {} cacheObject["backtrace"] = self.backtrace cacheObject["registers"] = self.registers @@ -226,8 +227,12 @@ def toCacheObject(self): @staticmethod def fromRawCrashData( - stdout, stderr, configuration, auxCrashData=None, cacheObject=None - ): + stdout: str | list[str] | None, + stderr: str | list[str] | None, + configuration: ProgramConfiguration, + auxCrashData: str | list[str] | None = None, + cacheObject: Mapping[str, Any] | None = None, + ) -> "CrashInfo": """ Create appropriate CrashInfo instance from raw crash data @@ -250,19 +255,19 @@ def fromRawCrashData( @return: Crash information object """ - assert stdout is None or isinstance(stdout, (list, str, bytes)) - assert stderr is None or isinstance(stderr, (list, str, bytes)) - assert auxCrashData is None or isinstance(auxCrashData, (list, str, bytes)) - + # TODO: These checks should raise ValueError instead of being asserts + assert stdout is None or isinstance(stdout, (list, str)) + assert stderr is None or isinstance(stderr, (list, str)) + assert auxCrashData is None or isinstance(auxCrashData, (list, str)) assert isinstance(configuration, ProgramConfiguration) - if isinstance(stdout, (str, bytes)): + if isinstance(stdout, str): stdout = stdout.splitlines() - if isinstance(stderr, (str, bytes)): + if isinstance(stderr, str): stderr = stderr.splitlines() - if isinstance(auxCrashData, (str, bytes)): + if isinstance(auxCrashData, str): auxCrashData = auxCrashData.splitlines() if cacheObject is not None: @@ -314,13 +319,13 @@ def fromRawCrashData( minidumpFirstDetected = False # Search both crashData and stderr, but prefer crashData - lines = [] + lines: list[str] = [] if auxCrashData is not None: lines.extend(auxCrashData) if stderr is not None: lines.extend(stderr) - result = None + result: 
CrashInfo | None = None for line in lines: if ubsanString in line and re.match(ubsanRegex, line) is not None: result = UBSanCrashInfo(stdout, stderr, configuration, auxCrashData) @@ -385,7 +390,7 @@ def fromRawCrashData( return result @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) @@ -409,11 +414,11 @@ def createShortSignature(self): def createCrashSignature( self, - forceCrashAddress=False, - forceCrashInstruction=False, - maxFrames=8, - minimumSupportedVersion=13, - ): + forceCrashAddress: bool = False, + forceCrashInstruction: bool = False, + maxFrames: int = 8, + minimumSupportedVersion: int = 13, + ) -> CrashSignature | None: """ @param forceCrashAddress: If True, the crash address will be included in any case @@ -438,7 +443,7 @@ def createCrashSignature( else: numFrames = len(self.backtrace) - symptomArr = [] + symptomArr: list[dict[str, Any]] = [] # Memorize where we find our abort messages abortMsgInCrashdata = False @@ -471,7 +476,7 @@ def createCrashSignature( abortMsgs = [abortMsgs] for abortMsg in abortMsgs: - abortMsg = AssertionHelper.getSanitizedAssertionPattern(abortMsg) + sanAbortMsg = AssertionHelper.getSanitizedAssertionPattern(abortMsg) abortMsgSrc = "stderr" if abortMsgInCrashdata: abortMsgSrc = "crashdata" @@ -481,19 +486,22 @@ def createCrashSignature( # for anything newer, use the short form with forward slashes # to increase the readability of the signatures. 
if minimumSupportedVersion < 12: - stringObj = {"value": abortMsg, "matchType": "pcre"} - symptomObj = { - "type": "output", - "src": abortMsgSrc, - "value": stringObj, - } + stringObj = {"value": sanAbortMsg, "matchType": "pcre"} + symptomArr.append( + { + "type": "output", + "src": abortMsgSrc, + "value": stringObj, + } + ) else: - symptomObj = { - "type": "output", - "src": abortMsgSrc, - "value": f"/{abortMsg}/", - } - symptomArr.append(symptomObj) + symptomArr.append( + { + "type": "output", + "src": abortMsgSrc, + "value": f"/{sanAbortMsg}/", + } + ) # Consider the first four frames as top stack topStackLimit = 4 @@ -516,17 +524,18 @@ def createCrashSignature( for idx in range(0, numFrames): functionName = self.backtrace[idx] if functionName != "??": - symptomObj = { - "type": "stackFrame", - "frameNumber": idx, - "functionName": functionName, - } - symptomArr.append(symptomObj) + symptomArr.append( + { + "type": "stackFrame", + "frameNumber": idx, + "functionName": functionName, + } + ) elif idx < 4: # If we're in the top 4, we count this as a miss topStackMissCount += 1 else: - framesArray = [] + framesArray: list[str] = [] for idx in range(0, numFrames): functionName = self.backtrace[idx] @@ -540,7 +549,7 @@ def createCrashSignature( lastSymbolizedFrame = None for frameIdx in range(0, len(framesArray)): - if str(framesArray[frameIdx]) != "?": + if framesArray[frameIdx] != "?": lastSymbolizedFrame = frameIdx if lastSymbolizedFrame is not None: @@ -604,7 +613,7 @@ def createCrashSignature( return CrashSignature(json.dumps(sigObj, indent=2, sort_keys=True)) @staticmethod - def sanitizeStackFrame(frame): + def sanitizeStackFrame(frame: str) -> str: """ This function removes function arguments and other non-generic parts of the function frame, returning a (hopefully) generic function name. 
@@ -648,7 +657,13 @@ def sanitizeStackFrame(frame): class NoCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -668,7 +683,13 @@ def __init__(self, stdout, stderr, configuration, crashData=None): class ASanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -688,6 +709,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the ASan trace, otherwise use stderr asanOutput = crashData if crashData else stderr + assert asanOutput is not None asanCrashAddressPattern = r"""(?x) [A-Za-z]+Sanitizer.*\s @@ -740,6 +762,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): index, parts = self.split_frame(traceLine) if index is None: continue + assert parts is not None # We may see multiple traces in ASAN if index == 0: @@ -781,7 +804,9 @@ def __init__(self, stdout, stderr, configuration, crashData=None): expectedIndex += 1 @staticmethod - def split_frame(line): + def split_frame( + line: str, + ) -> tuple[int | None, list[str] | None]: parts = line.strip().split() # We only want stack frames @@ -823,7 +848,7 @@ def split_frame(line): return frame_no, parts @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) @@ -897,7 +922,13 @@ def createShortSignature(self): class LSanCrashInfo(CrashInfo): 
- def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -917,6 +948,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the LSan trace, otherwise use stderr lsanOutput = crashData if crashData else stderr + assert lsanOutput is not None lsanErrorPattern = "ERROR: LeakSanitizer:" lsanPatternSeen = False @@ -930,6 +962,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): index, parts = ASanCrashInfo.split_frame(traceLine) if index is None: continue + assert parts is not None if expectedIndex != index: raise TraceParsingError( @@ -962,7 +995,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.backtrace.append("??") @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) @@ -986,7 +1019,13 @@ def createShortSignature(self): class UBSanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
@@ -1006,6 +1045,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the UBSan trace, otherwise use stderr ubsanOutput = crashData if crashData else stderr + assert ubsanOutput is not None ubsanErrorPattern = r":\d+:\d+:\s+runtime\s+error:\s+" ubsanPatternSeen = False @@ -1019,6 +1059,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): index, parts = ASanCrashInfo.split_frame(traceLine) if index is None: continue + assert parts is not None if expectedIndex != index: raise TraceParsingError( @@ -1045,7 +1086,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): expectedIndex += 1 @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) @@ -1069,7 +1110,13 @@ def createShortSignature(self): class GDBCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1089,6 +1136,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the GDB trace, otherwise use stderr gdbOutput = crashData or stderr + assert gdbOutput is not None gdbFramePatterns = [ "\\s*#(\\d+)\\s+(0x[0-9a-f]+) in (.+?) 
\\(.*?\\)( at .+)?", @@ -1149,7 +1197,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): lastLineBuf += traceLine - functionName = None + functionName: str | None = None frameIndex = None gdbDebugInfoMismatch = False @@ -1189,6 +1237,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): line_no=line_no, ) + assert functionName is not None # This is a workaround for GDB throwing an error while resolving # function arguments in the trace and aborting. We try to remove the # error message to at least recover the function name properly. @@ -1227,7 +1276,9 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.crashAddress = uint64(self.crashAddress) @staticmethod - def calculateCrashAddress(crashInstruction, registerMap): + def calculateCrashAddress( + crashInstruction: str, registerMap: Mapping[str, int] + ) -> int | str | None: """ Calculate the crash address given the crash instruction and register contents @@ -1248,7 +1299,7 @@ def calculateCrashAddress(crashInstruction, registerMap): # that this caused our crash. return RegisterHelper.getInstructionPointer(registerMap) - parts = crashInstruction.split(None, 1) + parts = crashInstruction.split(maxsplit=1) if len(parts) == 1: # Single instruction without any operands? @@ -1297,10 +1348,10 @@ def calculateCrashAddress(crashInstruction, registerMap): # When we fail, try storing a reason here failureReason = "Unknown failure." 
- def isDerefOp(op): + def isDerefOp(op: str) -> bool: return "(" in op and ")" in op - def calculateDerefOpAddress(derefOp): + def calculateDerefOpAddress(derefOp: str) -> tuple[int | None, str | None]: match = re.match("\\*?((?:\\-?0x[0-9a-f]+)?)\\(%([a-z0-9]+)\\)", derefOp) if match is not None: offset = 0 @@ -1330,11 +1381,11 @@ def calculateDerefOpAddress(derefOp): # TODO: Fix this properly by including readability information in # GDB output if isDerefOp(parts[0]): - (val, failed) = calculateDerefOpAddress(parts[0]) - if failed: - failureReason = failed - else: + val, failed = calculateDerefOpAddress(parts[0]) + if val is not None: return val + assert failed is not None + failureReason = failed else: # No deref, so the stack access must be failing return RegisterHelper.getStackPointer(registerMap) @@ -1345,11 +1396,11 @@ def calculateDerefOpAddress(derefOp): # we don't mix them with instructions that also # interacts with the stack pointer. if instruction.startswith("set"): - (val, failed) = calculateDerefOpAddress(parts[0]) - if failed: - failureReason = failed - else: + val, failed = calculateDerefOpAddress(parts[0]) + if val is not None: return val + assert failed is not None + failureReason = failed else: failureReason = "Unsupported single-operand instruction." 
elif len(parts) == 2: @@ -1391,11 +1442,11 @@ def calculateDerefOpAddress(derefOp): derefOp = parts[1] if derefOp is not None: - (val, failed) = calculateDerefOpAddress(derefOp) - if failed: - failureReason = failed - else: + val, failed = calculateDerefOpAddress(derefOp) + if val is not None: return val + assert failed is not None + failureReason = failed else: failureReason = ( "Failed to decode two-operand instruction: No dereference " @@ -1418,14 +1469,13 @@ def calculateDerefOpAddress(derefOp): if "(" in parts[0] and ")" in parts[2]: complexDerefOp = parts[0] + "," + parts[1] + "," + parts[2] - (result, reason) = GDBCrashInfo.calculateComplexDerefOpAddress( + val, failed = GDBCrashInfo.calculateComplexDerefOpAddress( complexDerefOp, registerMap ) - - if result is None: - failureReason = reason - else: - return result + if val is not None: + return val + assert failed is not None + failureReason = failed else: raise RuntimeError( f"Unexpected instruction pattern: {crashInstruction}" @@ -1436,14 +1486,13 @@ def calculateDerefOpAddress(derefOp): elif "(" not in parts[0] and ")" not in parts[0]: complexDerefOp = parts[1] + "," + parts[2] + "," + parts[3] - (result, reason) = GDBCrashInfo.calculateComplexDerefOpAddress( + val, failed = GDBCrashInfo.calculateComplexDerefOpAddress( complexDerefOp, registerMap ) - - if result is None: - failureReason = reason - else: - return result + if val is not None: + return val + assert failed is not None + failureReason = failed else: raise RuntimeError( "Unexpected length after splitting operands of this instruction: " @@ -1454,13 +1503,15 @@ def calculateDerefOpAddress(derefOp): # Anything that is not explicitly handled now is considered unsupported failureReason = "Unsupported instruction in incomplete ARM/ARM64 support." 
- def getARMImmConst(val): + def getARMImmConst(val: str) -> int: val = val.replace("#", "").strip() if val.startswith("0x"): return int(val, 16) return int(val) - def calculateARMDerefOpAddress(derefOp): + def calculateARMDerefOpAddress( + derefOp: str, + ) -> tuple[int | None, str | None]: derefOps = derefOp.split(",") if len(derefOps) > 2: @@ -1510,11 +1561,11 @@ def calculateARMDerefOpAddress(derefOp): # Load/Store instruction match = re.match("^\\s*\\[(.*)\\]$", parts[1]) if match is not None: - (result, reason) = calculateARMDerefOpAddress(match.group(1)) - if result is None: - failureReason += f" ({reason})" - else: - return result + val, failed = calculateARMDerefOpAddress(match.group(1)) + if val is not None: + return val + assert failed is not None + failureReason += f" ({failed})" else: failureReason = "Architecture is not supported." @@ -1526,7 +1577,9 @@ def calculateARMDerefOpAddress(derefOp): return failureReason @staticmethod - def calculateComplexDerefOpAddress(complexDerefOp, registerMap): + def calculateComplexDerefOpAddress( + complexDerefOp: str, registerMap: Mapping[str, int] + ) -> tuple[int | None, str | None]: match = re.match( "((?:\\-?0x[0-9a-f]+)?)\\(%([a-z0-9]+),%([a-z0-9]+),([0-9]+)\\)", complexDerefOp, @@ -1558,7 +1611,13 @@ def calculateComplexDerefOpAddress(complexDerefOp, registerMap): class MinidumpCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
@@ -1579,6 +1638,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the Minidump trace, otherwise use # stderr minidumpOuput = crashData or stderr + assert minidumpOuput is not None crashThread = None for traceLine in minidumpOuput: @@ -1608,7 +1668,13 @@ def __init__(self, stdout, stderr, configuration, crashData=None): class AppleCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1627,6 +1693,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.configuration = configuration apple_crash_data = crashData or stderr + assert apple_crash_data is not None inCrashingThread = False for line in apple_crash_data: @@ -1651,7 +1718,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): if inCrashingThread: # Example: # "0 js-dbg-64-dm-darwin-a523d4c7efe2 0x00000001004b04c4 js::jit::MacroAssembler::Pop(js::jit::Register) + 180 (MacroAssembler-inl.h:50)" # noqa: E501 - components = line.split(None, 3) + components = line.split(maxsplit=3) stackEntry = components[3] if stackEntry.startswith("0"): self.backtrace.append("??") @@ -1662,14 +1729,14 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.backtrace.append(stackEntry) @staticmethod - def removeFilename(stackEntry): + def removeFilename(stackEntry: str) -> str: match = re.match(r"(.*) \(\S+\)", stackEntry) if match: return match.group(1) return stackEntry @staticmethod - def removeOffset(stackEntry): + def removeOffset(stackEntry: str) -> str: match = re.match(r"(.*) \+ \d+", stackEntry) if match: return match.group(1) @@ -1677,7 +1744,13 @@ def removeOffset(stackEntry): class CDBCrashInfo(CrashInfo): - 
def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1701,10 +1774,11 @@ def __init__(self, stdout, stderr, configuration, crashData=None): inCrashingThread = False inCrashInstruction = False inEcxrData = False - ecxrData = [] + ecxrData: list[str] = [] cInstruction = "" cdb_crash_data = crashData or stderr + assert cdb_crash_data is not None for line in cdb_crash_data: # Start of .ecxr data @@ -1780,7 +1854,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # 00404c59 8b39 mov edi,dword ptr [ecx] # 64-bit example: # 00007ff7`4d469ff3 4c8b01 mov r8,qword ptr [rcx] - cInstruction = line.split(None, 2)[-1] + cInstruction = line.split(maxsplit=2)[-1] # There may be multiple spaces inside the faulting instruction cInstruction = " ".join(cInstruction.split()) self.crashInstruction = cInstruction @@ -1813,7 +1887,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.backtrace.append(stackEntry) @staticmethod - def removeFilenameAndOffset(stackEntry): + def removeFilenameAndOffset(stackEntry: str) -> str: # Extract only the function name if "0x" in stackEntry: result = ( @@ -1836,7 +1910,13 @@ class RustCrashInfo(CrashInfo): r"(::h[0-9a-f]{16})?|\s+at ([A-Za-z]:)?(/[A-Za-z0-9_ .]+)+:\d+)$" ) - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
@@ -1857,6 +1937,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the rust backtrace, otherwise use # stderr rustOutput = crashData or stderr + assert rustOutput is not None self.crashAddress = ( 0 # this is always an assertion, set to 0 to make matching more efficient @@ -1881,7 +1962,13 @@ def __init__(self, stdout, stderr, configuration, crashData=None): class TSanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1901,6 +1988,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the ASan trace, otherwise use stderr tsanOutput = crashData if crashData else stderr + assert tsanOutput is not None tsanWarningPattern = r"""WARNING: ThreadSanitizer:.*\s.+?\s+\(pid=\d+\)""" @@ -1930,6 +2018,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): index, parts = ASanCrashInfo.split_frame(traceLine) if index is None: continue + assert parts is not None # We may see multiple traces in TSAN if index == 0: @@ -1987,7 +2076,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.backtrace.extend(backtrace) @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) @@ -2041,7 +2130,13 @@ class ValgrindCrashInfo(CrashInfo): re.VERBOSE, ) - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ) -> None: """ Private constructor, called by 
L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -2062,6 +2157,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # If crashData is given, use that to find the Valgrind trace, otherwise use # stderr vgdOutput = crashData if crashData else stderr + assert vgdOutput is not None stackPattern = re.compile( r""" ^==\d+==\s+(at|by)\s+ # beginning of entry @@ -2105,7 +2201,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): break @unicode_escape_result - def createShortSignature(self): + def createShortSignature(self) -> str: """ @rtype: String @return: A string representing this crash (short signature) diff --git a/FTB/Signatures/CrashSignature.py b/FTB/Signatures/CrashSignature.py index f8090043..a10bea19 100644 --- a/FTB/Signatures/CrashSignature.py +++ b/FTB/Signatures/CrashSignature.py @@ -17,6 +17,7 @@ import difflib import json +from typing import TYPE_CHECKING, Any from FTB.Signatures import JSONHelper from FTB.Signatures.Symptom import ( @@ -26,9 +27,12 @@ TestcaseSymptom, ) +if TYPE_CHECKING: + from FTB.Signatures.CrashInfo import CrashInfo + class CrashSignature: - def __init__(self, rawSignature): + def __init__(self, rawSignature: str) -> None: """ Constructor @@ -42,7 +46,7 @@ def __init__(self, rawSignature): # serializes this object back to JSON as it is. 
# self.rawSignature = rawSignature - self.symptoms = [] + self.symptoms: list[Symptom] = [] try: obj = json.loads(rawSignature) @@ -66,14 +70,14 @@ def __init__(self, rawSignature): self.products = JSONHelper.getArrayChecked(obj, "products") @staticmethod - def fromFile(signatureFile): + def fromFile(signatureFile: str) -> "CrashSignature": with open(signatureFile) as sigFd: return CrashSignature(sigFd.read()) - def __str__(self): - return str(self.rawSignature) + def __str__(self) -> str: + return self.rawSignature - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Match this signature against the given crash information @@ -83,23 +87,18 @@ def matches(self, crashInfo): @rtype: bool @return: True if the signature matches, False otherwise """ - if ( - self.platforms is not None - and crashInfo.configuration.platform not in self.platforms - ): - return False - - if ( - self.operatingSystems is not None - and crashInfo.configuration.os not in self.operatingSystems - ): - return False - - if ( - self.products is not None - and crashInfo.configuration.product not in self.products - ): - return False + if self.platforms is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.platform not in self.platforms: + return False + if self.operatingSystems is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.os not in self.operatingSystems: + return False + if self.products is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.product not in self.products: + return False deferredSymptoms = [] @@ -115,7 +114,7 @@ def matches(self, crashInfo): return all(symptom.matches(crashInfo) for symptom in deferredSymptoms) - def matchRequiresTest(self): + def matchRequiresTest(self) -> bool: """ Check if the signature requires a testcase to match. 
@@ -127,7 +126,7 @@ def matchRequiresTest(self): """ return any(isinstance(symptom, TestcaseSymptom) for symptom in self.symptoms) - def getRequiredOutputSources(self): + def getRequiredOutputSources(self) -> list[str]: """ Return a list of output sources required by this signature for matching. @@ -138,7 +137,7 @@ def getRequiredOutputSources(self): @return: A list of output identifiers (e.g. stdout, stderr or crashdata) required by this signature. """ - ret = [] + ret: list[str] = [] for symptom in self.symptoms: if isinstance(symptom, OutputSymptom): @@ -150,7 +149,7 @@ def getRequiredOutputSources(self): return ret - def getDistance(self, crashInfo): + def getDistance(self, crashInfo: "CrashInfo") -> int: distance = 0 for symptom in self.symptoms: @@ -165,29 +164,24 @@ def getDistance(self, crashInfo): if not symptom.matches(crashInfo): distance += 1 - if ( - self.platforms is not None - and crashInfo.configuration.platform not in self.platforms - ): - distance += 1 - - if ( - self.operatingSystems is not None - and crashInfo.configuration.os not in self.operatingSystems - ): - distance += 1 - - if ( - self.products is not None - and crashInfo.configuration.product not in self.products - ): - distance += 1 + if self.platforms is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.platform not in self.platforms: + distance += 1 + if self.operatingSystems is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.os not in self.operatingSystems: + distance += 1 + if self.products is not None: + assert crashInfo.configuration is not None + if crashInfo.configuration.product not in self.products: + distance += 1 return distance - def fit(self, crashInfo): - sigObj = {} - sigSymptoms = [] + def fit(self, crashInfo: "CrashInfo") -> "CrashSignature | None": + sigObj: dict[str, Any] = {} + sigSymptoms: list[Any] = [] sigObj["symptoms"] = sigSymptoms @@ -214,8 +208,8 @@ def fit(self, crashInfo): return 
CrashSignature(json.dumps(sigObj, indent=2, sort_keys=True)) - def getSymptomsDiff(self, crashInfo): - symptomsDiff = [] + def getSymptomsDiff(self, crashInfo: "CrashInfo") -> list[dict[str, Any]]: + symptomsDiff: list[dict[str, Any]] = [] for symptom in self.symptoms: if symptom.matches(crashInfo): symptomsDiff.append({"offending": False, "symptom": symptom}) @@ -239,8 +233,10 @@ def getSymptomsDiff(self, crashInfo): symptomsDiff.append({"offending": True, "symptom": symptom}) return symptomsDiff - def getSignatureUnifiedDiffTuples(self, crashInfo): - diffTuples = [] + def getSignatureUnifiedDiffTuples( + self, crashInfo: "CrashInfo" + ) -> list[tuple[str, str]]: + diffTuples: list[tuple[str, str]] = [] # go through dumps(loads()) to standardize the format. # the dumps args here must match what is returned by `fit()` diff --git a/FTB/Signatures/JSONHelper.py b/FTB/Signatures/JSONHelper.py index 95a5842d..17df4803 100644 --- a/FTB/Signatures/JSONHelper.py +++ b/FTB/Signatures/JSONHelper.py @@ -15,9 +15,13 @@ """ import numbers +from collections.abc import Mapping +from typing import Any -def getArrayChecked(obj, key, mandatory=False): +def getArrayChecked( + obj: Mapping[str, Any], key: str, mandatory: bool = False +) -> list[Any] | None: """ Retrieve a list from the given object using the given key @@ -33,10 +37,12 @@ def getArrayChecked(obj, key, mandatory=False): @rtype: list @return: List retrieved from object """ - return __getTypeChecked(obj, key, [list], mandatory) + return __getTypeChecked(obj, key, [list], mandatory) # type: ignore[no-any-return] -def getStringChecked(obj, key, mandatory=False): +def getStringChecked( + obj: Mapping[str, Any], key: str, mandatory: bool = False +) -> str | None: """ Retrieve a string from the given object using the given key @@ -52,10 +58,12 @@ def getStringChecked(obj, key, mandatory=False): @rtype: string @return: String retrieved from object """ - return __getTypeChecked(obj, key, [str, bytes], mandatory) + return 
__getTypeChecked(obj, key, [str], mandatory) # type: ignore[no-any-return] -def getNumberChecked(obj, key, mandatory=False): +def getNumberChecked( + obj: Mapping[str, Any], key: str, mandatory: bool = False +) -> int | None: """ Retrieve an integer from the given object using the given key @@ -71,10 +79,12 @@ def getNumberChecked(obj, key, mandatory=False): @rtype: int @return: Number retrieved from object """ - return __getTypeChecked(obj, key, [numbers.Integral], mandatory) + return __getTypeChecked(obj, key, [numbers.Integral], mandatory) # type: ignore[no-any-return] -def getObjectOrStringChecked(obj, key, mandatory=False): +def getObjectOrStringChecked( + obj: Mapping[str, Any], key: str, mandatory: bool = False +) -> str | dict[str, Any] | None: """ Retrieve an object or string from the given object using the given key @@ -90,10 +100,12 @@ def getObjectOrStringChecked(obj, key, mandatory=False): @rtype: string or dict @return: String/Object object retrieved from object """ - return __getTypeChecked(obj, key, [str, bytes, dict], mandatory) + return __getTypeChecked(obj, key, [str, dict], mandatory) # type: ignore[no-any-return] -def getNumberOrStringChecked(obj, key, mandatory=False): +def getNumberOrStringChecked( + obj: Mapping[str, Any], key: str, mandatory: bool = False +) -> str | int | None: """ Retrieve a number or string from the given object using the given key @@ -109,10 +121,15 @@ def getNumberOrStringChecked(obj, key, mandatory=False): @rtype: string or number @return: String/Number object retrieved from object """ - return __getTypeChecked(obj, key, [str, bytes, numbers.Integral], mandatory) + return __getTypeChecked(obj, key, [str, numbers.Integral], mandatory) # type: ignore[no-any-return] -def __getTypeChecked(obj, key, valTypes, mandatory=False): +def __getTypeChecked( + obj: Mapping[str, Any], + key: str, + valTypes: list[type], + mandatory: bool = False, +) -> Any: if key not in obj: if mandatory: raise RuntimeError(f'Expected key "{key}" 
in object') diff --git a/FTB/Signatures/Matchers.py b/FTB/Signatures/Matchers.py index 58d49810..d740031f 100644 --- a/FTB/Signatures/Matchers.py +++ b/FTB/Signatures/Matchers.py @@ -14,23 +14,24 @@ @contact: choller@mozilla.com """ -import numbers import re from abc import ABCMeta, abstractmethod +from enum import StrEnum +from typing import Any from FTB.Signatures import JSONHelper class Match(metaclass=ABCMeta): @abstractmethod - def matches(self, value): + def matches(self, value: Any) -> bool: pass class StringMatch(Match): - def __init__(self, obj): + def __init__(self, obj: str | bytes | dict[str, Any]) -> None: self.isPCRE = False - self.compiledValue = None + self.compiledValue: re.Pattern[str] | None = None self.patternContainsSlash = False if isinstance(obj, bytes): @@ -49,7 +50,9 @@ def __init__(self, obj): except re.error as e: raise RuntimeError(f"Error in regular expression: {e}") else: - self.value = JSONHelper.getStringChecked(obj, "value", True) + value = JSONHelper.getStringChecked(obj, "value", True) + assert value is not None + self.value = value matchType = JSONHelper.getStringChecked(obj, "matchType", False) if matchType is not None: @@ -64,7 +67,7 @@ def __init__(self, obj): else: raise RuntimeError(f"Unknown match operator specified: {matchType}") - def matches(self, value, windowsSlashWorkaround=False): + def matches(self, value: str | bytes, windowsSlashWorkaround: bool = False) -> bool: if isinstance(value, bytes): # If the input is not already unicode, try to interpret it as UTF-8 # If there are errors, replace them with U+FFFD so we neither raise nor @@ -72,6 +75,7 @@ def matches(self, value, windowsSlashWorkaround=False): value = value.decode("utf-8", errors="replace") if self.isPCRE: + assert self.compiledValue is not None if self.compiledValue.search(value) is not None: return True if windowsSlashWorkaround and self.patternContainsSlash: @@ -81,47 +85,43 @@ def matches(self, value, windowsSlashWorkaround=False): return False return 
self.value in value - def __str__(self): + def __str__(self) -> str: return self.value - def __repr__(self): + def __repr__(self) -> str: if self.isPCRE: return f"/{self.value}/" return self.value -class NumberMatchType: - GE, GT, LE, LT = range(4) +class NumberMatchType(StrEnum): + EQ = "==" + GE = ">=" + GT = ">" + LE = "<=" + LT = "<" class NumberMatch(Match): - def __init__(self, obj): - self.matchType = None + def __init__(self, obj: str | bytes | int) -> None: + self.matchType: NumberMatchType | None = None + self.value: int | None = None if isinstance(obj, bytes): obj = obj.decode("utf-8") if isinstance(obj, str): if len(obj) > 0: - numberMatchComponents = obj.split(None, 1) + numberMatchComponents = obj.split(maxsplit=1) numIdx = 0 if len(numberMatchComponents) > 1: numIdx = 1 matchType = numberMatchComponents[0] - - if matchType == "==": - pass - elif matchType == "<": - self.matchType = NumberMatchType.LT - elif matchType == "<=": - self.matchType = NumberMatchType.LE - elif matchType == ">": - self.matchType = NumberMatchType.GT - elif matchType == ">=": - self.matchType = NumberMatchType.GE - else: + try: + self.matchType = NumberMatchType(matchType) + except ValueError: raise RuntimeError( f"Unknown match operator specified: {matchType}" ) @@ -139,21 +139,25 @@ def __init__(self, obj): # address self.value = None - elif isinstance(obj, numbers.Integral): + elif isinstance(obj, int): self.value = obj else: raise RuntimeError(f"Invalid type {type(obj)} in NumberMatch.") - def matches(self, value): + def matches(self, value: int | None) -> bool: if value is None: return self.value is None - if self.matchType == NumberMatchType.GE: - return value >= self.value - if self.matchType == NumberMatchType.GT: - return value > self.value - if self.matchType == NumberMatchType.LE: - return value <= self.value - if self.matchType == NumberMatchType.LT: - return value < self.value + # self.matchType is only set when __init__ parses a non-empty string that also + # sets
self.value, so these comparisons are safe. + if self.matchType is not None: + assert self.value is not None + if self.matchType == NumberMatchType.GE: + return value >= self.value + if self.matchType == NumberMatchType.GT: + return value > self.value + if self.matchType == NumberMatchType.LE: + return value <= self.value + if self.matchType == NumberMatchType.LT: + return value < self.value return value == self.value diff --git a/FTB/Signatures/RegisterHelper.py b/FTB/Signatures/RegisterHelper.py index 774364a2..f9f895c2 100644 --- a/FTB/Signatures/RegisterHelper.py +++ b/FTB/Signatures/RegisterHelper.py @@ -12,6 +12,8 @@ @contact: choller@mozilla.com """ +from collections.abc import Mapping + x86Registers = ["eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp", "eip"] x64Registers = [ @@ -77,7 +79,7 @@ } -def getRegisterPattern(): +def getRegisterPattern() -> str: """ Return a pattern including all register names that are considered valid """ @@ -89,7 +91,7 @@ def getRegisterPattern(): ) -def getStackPointer(registerMap): +def getStackPointer(registerMap: Mapping[str, int]) -> int: """ Return the stack pointer value from the given register map @@ -107,7 +109,7 @@ def getStackPointer(registerMap): raise RuntimeError("Register map does not contain a usable stack pointer!") -def getInstructionPointer(registerMap): +def getInstructionPointer(registerMap: Mapping[str, int]) -> int: """ Return the instruction pointer value from the given register map @@ -125,7 +127,7 @@ def getInstructionPointer(registerMap): raise RuntimeError("Register map does not contain a usable instruction pointer!") -def getRegisterValue(register, registerMap): +def getRegisterValue(register: str, registerMap: Mapping[str, int]) -> int | None: """ Return the value of the specified register using the provided register map. This method also works for getting lower register parts out of higher ones. 
@@ -196,7 +198,7 @@ def getRegisterValue(register, registerMap): return None -def getBitWidth(registerMap): +def getBitWidth(registerMap: Mapping[str, int]) -> int: """ Return the bit width (32 or 64 bit) given the registers @@ -212,7 +214,7 @@ def getBitWidth(registerMap): return 32 -def isX86Compatible(registerMap): +def isX86Compatible(registerMap: Mapping[str, int]) -> bool: """ Return true, if the the given registers are X86 compatible, such as x86 or x86-64. ARM, PPC and your PDP-15 will fail this check and we don't support it right now. @@ -226,7 +228,7 @@ def isX86Compatible(registerMap): return any(register in registerMap for register in x86OnlyRegisters) -def isARMCompatible(registerMap): +def isARMCompatible(registerMap: Mapping[str, int]) -> bool: """ Return true, if the the given registers are either ARM or ARM64. diff --git a/FTB/Signatures/Symptom.py b/FTB/Signatures/Symptom.py index ffe4e2ba..813d86aa 100644 --- a/FTB/Signatures/Symptom.py +++ b/FTB/Signatures/Symptom.py @@ -16,10 +16,14 @@ import json from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Any from FTB.Signatures import JSONHelper from FTB.Signatures.Matchers import NumberMatch, StringMatch +if TYPE_CHECKING: + from FTB.Signatures.CrashInfo import CrashInfo + class Symptom(metaclass=ABCMeta): """ @@ -27,16 +31,16 @@ class Symptom(metaclass=ABCMeta): It also supports generating a CrashSignature based on the stored information. 
""" - def __init__(self, jsonObj): + def __init__(self, jsonObj: dict[str, Any]) -> None: # Store the original source so we can return it if someone wants to stringify us self.jsonsrc = json.dumps(jsonObj, indent=2) self.jsonobj = jsonObj - def __str__(self): + def __str__(self) -> str: return self.jsonsrc @staticmethod - def fromJSONObject(obj): + def fromJSONObject(obj: dict[str, Any]) -> "Symptom": """ Create the appropriate Symptom based on the given object (decoded from JSON) @@ -68,7 +72,7 @@ def fromJSONObject(obj): raise RuntimeError(f"Unknown symptom type: {stype}") @abstractmethod - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -78,18 +82,18 @@ def matches(self, crashInfo): @rtype: bool @return: True if the symptom matches, False otherwise """ - return + return False class OutputSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
""" Symptom.__init__(self, obj) - self.output = StringMatch( - JSONHelper.getObjectOrStringChecked(obj, "value", True) - ) + checked = JSONHelper.getObjectOrStringChecked(obj, "value", True) + assert checked is not None + self.output = StringMatch(checked) self.src = JSONHelper.getStringChecked(obj, "src") if self.src is not None: @@ -101,7 +105,7 @@ def __init__(self, obj): ): raise RuntimeError(f"Invalid source specified: {self.src}") - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -111,7 +115,7 @@ def matches(self, crashInfo): @rtype: bool @return: True if the symptom matches, False otherwise """ - checkedOutput = [] + checkedOutput: list[str] = [] if self.src is None: checkedOutput.extend(crashInfo.rawStdout) @@ -124,6 +128,7 @@ def matches(self, crashInfo): else: checkedOutput = crashInfo.rawCrashData + assert crashInfo.configuration is not None windowsSlashWorkaround = crashInfo.configuration.os == "windows" for line in reversed(checkedOutput): if self.output.matches(line, windowsSlashWorkaround=windowsSlashWorkaround): @@ -133,23 +138,22 @@ def matches(self, crashInfo): class StackFrameSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
""" Symptom.__init__(self, obj) - self.functionName = StringMatch( - JSONHelper.getNumberOrStringChecked(obj, "functionName", True) - ) - self.frameNumber = JSONHelper.getNumberOrStringChecked(obj, "frameNumber") - - if self.frameNumber is not None: - self.frameNumber = NumberMatch(self.frameNumber) + func = JSONHelper.getStringChecked(obj, "functionName", True) + assert func is not None + self.functionName = StringMatch(func) + frame = JSONHelper.getNumberOrStringChecked(obj, "frameNumber") + if frame is not None: + self.frameNumber = NumberMatch(frame) else: # Default to 0 self.frameNumber = NumberMatch(0) - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -171,16 +175,16 @@ def matches(self, crashInfo): class StackSizeSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ Symptom.__init__(self, obj) - self.stackSize = NumberMatch( - JSONHelper.getNumberOrStringChecked(obj, "size", True) - ) + checked = JSONHelper.getNumberOrStringChecked(obj, "size", True) + assert checked is not None + self.stackSize = NumberMatch(checked) - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -194,16 +198,16 @@ def matches(self, crashInfo): class CrashAddressSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
""" Symptom.__init__(self, obj) - self.address = NumberMatch( - JSONHelper.getNumberOrStringChecked(obj, "address", True) - ) + checked = JSONHelper.getNumberOrStringChecked(obj, "address", True) + assert checked is not None + self.address = NumberMatch(checked) - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -219,24 +223,23 @@ def matches(self, crashInfo): class InstructionSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ Symptom.__init__(self, obj) + self.instructionName: StringMatch | None = None self.registerNames = JSONHelper.getArrayChecked(obj, "registerNames") - self.instructionName = JSONHelper.getObjectOrStringChecked( - obj, "instructionName" - ) - if self.instructionName is not None: - self.instructionName = StringMatch(self.instructionName) - elif self.registerNames is None or len(self.registerNames) == 0: + instr = JSONHelper.getObjectOrStringChecked(obj, "instructionName") + if instr is not None: + self.instructionName = StringMatch(instr) + elif not self.registerNames: raise RuntimeError( "Must provide at least instruction name or register names" ) - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -261,16 +264,16 @@ def matches(self, crashInfo): class TestcaseSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
""" Symptom.__init__(self, obj) - self.output = StringMatch( - JSONHelper.getObjectOrStringChecked(obj, "value", True) - ) + checked = JSONHelper.getObjectOrStringChecked(obj, "value", True) + assert checked is not None + self.output = StringMatch(checked) - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -291,19 +294,19 @@ def matches(self, crashInfo): class StackFramesSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, Any]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ Symptom.__init__(self, obj) - self.functionNames = [] + self.functionNames: list[StringMatch] = [] rawFunctionNames = JSONHelper.getArrayChecked(obj, "functionNames", True) + if rawFunctionNames is not None: + for fn in rawFunctionNames: + self.functionNames.append(StringMatch(fn)) - for fn in rawFunctionNames: - self.functionNames.append(StringMatch(fn)) - - def matches(self, crashInfo): + def matches(self, crashInfo: "CrashInfo") -> bool: """ Check if the symptom matches the given crash information @@ -316,7 +319,9 @@ def matches(self, crashInfo): return StackFramesSymptom._match(crashInfo.backtrace, self.functionNames) - def diff(self, crashInfo): + def diff( + self, crashInfo: "CrashInfo" + ) -> tuple[int | None, "StackFramesSymptom | None"]: if self.matches(crashInfo): return (0, None) @@ -325,6 +330,7 @@ def diff(self, crashInfo): crashInfo.backtrace, self.functionNames, 0, 1, depth ) if bestDepth is not None: + assert bestGuess is not None guessedFunctionNames = [repr(x) for x in bestGuess] # Remove trailing wildcards as they are of no use @@ -348,7 +354,13 @@ def diff(self, crashInfo): return (None, None) @staticmethod - def _diff(stack, signatureGuess, startIdx, depth, maxDepth): + def _diff( + stack: list[str], + signatureGuess: list[StringMatch], + startIdx: int, + depth: int, + maxDepth: int, + ) -> tuple[int | 
None, list[StringMatch] | None]: singleWildcardMatch = StringMatch("?") newSignatureGuess = [] @@ -442,7 +454,9 @@ def _diff(stack, signatureGuess, startIdx, depth, maxDepth): return (bestDepth, bestGuess) @staticmethod - def _match(partialStack, partialFunctionNames): + def _match( + partialStack: list[str], partialFunctionNames: list[StringMatch] + ) -> bool: while True: # Process as many non-wildcard chars as we can find iteratively for # performance reasons diff --git a/pyproject.toml b/pyproject.toml index 9d8b0fd6..03432c5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,13 @@ norecursedirs = [ "dist", ] +[tool.mypy] +files = ["FTB"] +exclude = ['/tests/'] +ignore_missing_imports = true +show_error_codes = true +strict = true + [tool.ruff] extend-exclude = ["**/migrations/*.py"] fix = true @@ -62,7 +69,7 @@ select = [ # flake8-simplify "SIM", # flake8-type-checking - "TCH", + "TC", # pyupgrade "UP", # pycodestyle @@ -70,8 +77,6 @@ select = [ ] ignore = ["RUF012"] -[tool.setuptools_scm] - [tool.ruff.lint.isort] known-first-party = [ "Collector", @@ -85,3 +90,5 @@ known-first-party = [ "ec2spotmanager", "taskmanager", ] + +[tool.setuptools_scm] diff --git a/tox.ini b/tox.ini index e85b11c2..282f0879 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,13 @@ passenv = extras = test +[testenv:mypy] +commands = + mypy --install-types --non-interactive {posargs} +deps = + mypy==v1.20.1 +usedevelop = true + [testenv:update-reqs] skip_install = true basepython = python3.10 From 4426f67f9c5c962a5161cf97c1fa6119a309f6c3 Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Thu, 23 Apr 2026 10:36:20 -0700 Subject: [PATCH 2/3] ci: add type hints to Collector and Reporter --- .pre-commit-config.yaml | 2 +- Collector/Collector.py | 83 ++++++++++++++++++-------------- FTB/ProgramConfiguration.py | 4 +- FTB/Running/AutoRunner.py | 13 +++-- FTB/Signatures/CrashInfo.py | 10 ++-- FTB/Signatures/CrashSignature.py | 14 +++--- FTB/Signatures/Matchers.py | 5 +- 
FTB/Signatures/Symptom.py | 24 ++++----- Reporter/Reporter.py | 54 +++++++++++++-------- pyproject.toml | 4 +- 10 files changed, 126 insertions(+), 87 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f7d0506..2d5192d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: entry: tox -e mypy -- language: system require_serial: true - files: ^FTB/ + files: ^(Collector|FTB|Reporter)/ exclude: (^|/)tests/ types: [python] pass_filenames: false diff --git a/Collector/Collector.py b/Collector/Collector.py index 4ddcec1a..0cce2670 100755 --- a/Collector/Collector.py +++ b/Collector/Collector.py @@ -23,7 +23,9 @@ import os import shutil import sys +from collections.abc import Iterator, Mapping from tempfile import mkstemp +from typing import Any from zipfile import ZipFile from FTB.ProgramConfiguration import ProgramConfiguration @@ -38,16 +40,16 @@ signature_checks, ) -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2014-10-01" -__updated__ = "2025-04-08" +__updated__ = "2026-04-23" class Collector(Reporter): @remote_checks @signature_checks - def refresh(self): + def refresh(self) -> None: """ Refresh signatures by contacting the server, downloading new signatures and invalidating old ones. @@ -68,12 +70,13 @@ def refresh(self): os.remove(zipFileName) @signature_checks - def refreshFromZip(self, zipFileName): + def refreshFromZip(self, zipFileName: str) -> None: """ Refresh signatures from a local zip file, adding new signatures and invalidating old ones. (This is a non-standard use case; you probably want to use refresh() instead.) 
""" + assert self.sigCacheDir is not None with ZipFile(zipFileName, "r") as zipFile: if zipFile.testzip(): raise InvalidDataError(f"Bad CRC for downloaded zipfile {zipFileName}") @@ -94,12 +97,12 @@ def refreshFromZip(self, zipFileName): @remote_checks def submit( self, - crashInfo, - testCase=None, - testCaseQuality=0, - testCaseSize=None, - metaData=None, - ): + crashInfo: CrashInfo, + testCase: str | None = None, + testCaseQuality: int = 0, + testCaseSize: int | None = None, + metaData: Mapping[str, Any] | None = None, + ) -> Any: """ Submit the given crash information and an optional testcase/metadata to the server for processing and storage. @@ -131,7 +134,7 @@ def submit( # Serialize our crash information, testcase and metadata into a dictionary to # POST - data = {} + data: dict[str, Any] = {} data["rawStdout"] = os.linesep.join(crashInfo.rawStdout) data["rawStderr"] = os.linesep.join(crashInfo.rawStderr) @@ -154,6 +157,7 @@ def submit( if testcase_ext: data["testcase_ext"] = testcase_ext + assert crashInfo.configuration is not None data["platform"] = crashInfo.configuration.platform data["product"] = crashInfo.configuration.product data["os"] = crashInfo.configuration.os @@ -165,7 +169,7 @@ def submit( data["tool"] = self.tool if crashInfo.configuration.metadata or metaData: - aggrMetaData = {} + aggrMetaData: dict[str, Any] = {} if crashInfo.configuration.metadata: aggrMetaData.update(crashInfo.configuration.metadata) @@ -184,7 +188,7 @@ def submit( return self.post(url, data).json() @signature_checks - def search(self, crashInfo): + def search(self, crashInfo: CrashInfo) -> tuple[str | None, dict[str, Any] | None]: """ Searches within the local signature cache directory for a signature matching the given crash. @@ -196,7 +200,7 @@ def search(self, crashInfo): @return: Tuple containing filename of the signature and metadata matching, or None if no match. 
""" - + assert self.sigCacheDir is not None cachedSigFiles = os.listdir(self.sigCacheDir) for sigFile in cachedSigFiles: @@ -210,7 +214,7 @@ def search(self, crashInfo): crashSig = CrashSignature(sigData) if crashSig.matches(crashInfo): metadataFile = sigFile.replace(".signature", ".metadata") - metadata = None + metadata: dict[str, Any] | None = None if os.path.exists(metadataFile): with open(metadataFile) as m: metadata = json.loads(m.read()) @@ -222,11 +226,11 @@ def search(self, crashInfo): @signature_checks def generate( self, - crashInfo, - forceCrashAddress=None, - forceCrashInstruction=None, - numFrames=None, - ): + crashInfo: CrashInfo, + forceCrashAddress: bool = False, + forceCrashInstruction: bool = False, + numFrames: int = 8, + ) -> str | None: """ Generates a signature in the local cache directory. It will be deleted when L{refresh} is called on the same local cache directory. @@ -257,7 +261,7 @@ def generate( return self.__store_signature_hashed(sig) @remote_checks - def download(self, crashId): + def download(self, crashId: int) -> tuple[str, dict[str, Any]] | None: """ Download the testcase for the specified crashId. @@ -300,7 +304,7 @@ def download(self, crashId): return (local_filename, resp_json) @remote_checks - def download_all(self, bucketId): + def download_all(self, bucketId: int) -> Iterator[str]: """ Download all testcases for the specified bucketId. @@ -310,8 +314,10 @@ def download_all(self, bucketId): @rtype: generator @return: generator of filenames where tests were stored. 
""" - params = {"query": json.dumps({"op": "OR", "bucket": bucketId})} - next_url = ( + params: dict[str, str] | None = { + "query": json.dumps({"op": "OR", "bucket": bucketId}) + } + next_url: str | None = ( f"{self.serverProtocol}://{self.serverHost}:{self.serverPort}" "/crashmanager/rest/crashes/" ) @@ -350,7 +356,7 @@ def download_all(self, bucketId): yield local_filename - def __store_signature_hashed(self, signature): + def __store_signature_hashed(self, signature: CrashSignature) -> str: """ Store a signature, using the sha1 hash hex representation as filename. @@ -361,11 +367,9 @@ def __store_signature_hashed(self, signature): @return: Name of the file that the signature was written to """ + assert self.sigCacheDir is not None h = hashlib.new("sha1") - if str is bytes: - h.update(str(signature)) - else: - h.update(str(signature).encode("utf-8")) + h.update(str(signature).encode("utf-8")) sigfile = os.path.join(self.sigCacheDir, h.hexdigest() + ".signature") with open(sigfile, "w") as f: f.write(str(signature)) @@ -373,7 +377,7 @@ def __store_signature_hashed(self, signature): return sigfile @staticmethod - def read_testcase(testCase): + def read_testcase(testCase: str) -> tuple[bytes, bool]: """ Read a testcase file, return the content and indicate if it is binary or not. 
@@ -394,7 +398,7 @@ def read_testcase(testCase): return (testCaseData, isBinary) -def main(args=None): +def main(args: list[str] | None = None) -> int: """Command line options.""" sentry_init() @@ -686,7 +690,7 @@ def main(args=None): if opts.testcase: (testCaseData, isBinary) = Collector.read_testcase(opts.testcase) if not isBinary: - crashInfo.testcase = testCaseData + crashInfo.testcase = testCaseData.decode("utf-8") serverauthtoken = None if opts.serverauthtokenfile: @@ -708,6 +712,7 @@ def main(args=None): return 0 if opts.submit: + assert crashInfo is not None testcase = opts.testcase collector.submit( crashInfo, testcase, opts.testcasequality, opts.testcasesize, metadata @@ -715,16 +720,18 @@ def main(args=None): return 0 if opts.search: - (sig, metadata) = collector.search(crashInfo) + assert crashInfo is not None + (sig, sigMetadata) = collector.search(crashInfo) if sig is None: print("No match found", file=sys.stderr) return 3 print(sig) - if metadata: - print(json.dumps(metadata, indent=4)) + if sigMetadata: + print(json.dumps(sigMetadata, indent=4)) return 0 if opts.generate: + assert crashInfo is not None sigFile = collector.generate( crashInfo, opts.forcecrashaddr, opts.forcecrashinst, opts.numframes ) @@ -738,6 +745,7 @@ def main(args=None): return 0 if opts.autosubmit: + assert configuration is not None runner = AutoRunner.fromBinaryArgs(opts.rargs[0], opts.rargs[1:], env=env) if runner.run(): crashInfo = runner.getCrashInfo(configuration) @@ -752,10 +760,11 @@ def main(args=None): return 1 if opts.download: - (retFile, retJSON) = collector.download(opts.download) - if not retFile: + downloadResult = collector.download(opts.download) + if downloadResult is None: print("Specified crash entry does not have a testcase", file=sys.stderr) return 1 + retFile, retJSON = downloadResult if retJSON.get("args"): args = json.loads(retJSON["args"]) diff --git a/FTB/ProgramConfiguration.py b/FTB/ProgramConfiguration.py index a8c824a6..09456880 100644 --- 
a/FTB/ProgramConfiguration.py +++ b/FTB/ProgramConfiguration.py @@ -15,6 +15,8 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import os import sys @@ -63,7 +65,7 @@ def __init__( self.metadata = metadata @staticmethod - def fromBinary(binaryPath: str) -> "ProgramConfiguration | None": + def fromBinary(binaryPath: str) -> ProgramConfiguration | None: binaryConfig = f"{binaryPath}.fuzzmanagerconf" if not os.path.exists(binaryConfig): print( diff --git a/FTB/Running/AutoRunner.py b/FTB/Running/AutoRunner.py index 7209d412..663b1c07 100644 --- a/FTB/Running/AutoRunner.py +++ b/FTB/Running/AutoRunner.py @@ -13,11 +13,13 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import os import shutil import subprocess import sys -from abc import ABCMeta +from abc import ABC, abstractmethod from pathlib import Path from shutil import rmtree from tempfile import mkdtemp @@ -26,7 +28,7 @@ from FTB.Signatures.CrashInfo import CrashInfo, NoCrashInfo -class AutoRunner(metaclass=ABCMeta): +class AutoRunner(ABC): """ Abstract base class that provides a method to instantiate the right sub class for running the given program and obtaining crash information. 
@@ -57,7 +59,6 @@ def __init__( self.args = args or [] - assert isinstance(self.env, dict) assert isinstance(self.args, list) # The command that we will run for obtaining crash information @@ -80,7 +81,7 @@ def fromBinaryArgs( env: dict[str, str] | None = None, cwd: str | None = None, stdin: str | list[str] | None = None, - ) -> "AutoRunner": + ) -> AutoRunner: process = subprocess.Popen( ["nm", "-g", binary], stdin=subprocess.PIPE, @@ -103,6 +104,10 @@ def fromBinaryArgs( return GDBRunner(binary, args=args, env=env, cwd=cwd, stdin=stdin) + @abstractmethod + def run(self) -> bool: + pass + class GDBRunner(AutoRunner): def __init__( diff --git a/FTB/Signatures/CrashInfo.py b/FTB/Signatures/CrashInfo.py index 291819c7..22c98bdb 100644 --- a/FTB/Signatures/CrashInfo.py +++ b/FTB/Signatures/CrashInfo.py @@ -15,22 +15,26 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import json import os import re import sys import unicodedata from abc import ABCMeta -from collections.abc import Callable, Mapping from contextlib import suppress from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any from FTB import AssertionHelper from FTB.ProgramConfiguration import ProgramConfiguration from FTB.Signatures import RegisterHelper from FTB.Signatures.CrashSignature import CrashSignature +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + def unicode_escape_result(func: Callable[..., str]) -> Callable[..., str]: r"""Decorator to escape control and special block unicode @@ -232,7 +236,7 @@ def fromRawCrashData( configuration: ProgramConfiguration, auxCrashData: str | list[str] | None = None, cacheObject: Mapping[str, Any] | None = None, - ) -> "CrashInfo": + ) -> CrashInfo: """ Create appropriate CrashInfo instance from raw crash data diff --git a/FTB/Signatures/CrashSignature.py b/FTB/Signatures/CrashSignature.py index a10bea19..5afe759a 100644 --- a/FTB/Signatures/CrashSignature.py +++ 
b/FTB/Signatures/CrashSignature.py @@ -15,6 +15,8 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import difflib import json from typing import TYPE_CHECKING, Any @@ -70,14 +72,14 @@ def __init__(self, rawSignature: str) -> None: self.products = JSONHelper.getArrayChecked(obj, "products") @staticmethod - def fromFile(signatureFile: str) -> "CrashSignature": + def fromFile(signatureFile: str) -> CrashSignature: with open(signatureFile) as sigFd: return CrashSignature(sigFd.read()) def __str__(self) -> str: return self.rawSignature - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Match this signature against the given crash information @@ -149,7 +151,7 @@ def getRequiredOutputSources(self) -> list[str]: return ret - def getDistance(self, crashInfo: "CrashInfo") -> int: + def getDistance(self, crashInfo: CrashInfo) -> int: distance = 0 for symptom in self.symptoms: @@ -179,7 +181,7 @@ def getDistance(self, crashInfo: "CrashInfo") -> int: return distance - def fit(self, crashInfo: "CrashInfo") -> "CrashSignature | None": + def fit(self, crashInfo: CrashInfo) -> CrashSignature | None: sigObj: dict[str, Any] = {} sigSymptoms: list[Any] = [] @@ -208,7 +210,7 @@ def fit(self, crashInfo: "CrashInfo") -> "CrashSignature | None": return CrashSignature(json.dumps(sigObj, indent=2, sort_keys=True)) - def getSymptomsDiff(self, crashInfo: "CrashInfo") -> list[dict[str, Any]]: + def getSymptomsDiff(self, crashInfo: CrashInfo) -> list[dict[str, Any]]: symptomsDiff: list[dict[str, Any]] = [] for symptom in self.symptoms: if symptom.matches(crashInfo): @@ -234,7 +236,7 @@ def getSymptomsDiff(self, crashInfo: "CrashInfo") -> list[dict[str, Any]]: return symptomsDiff def getSignatureUnifiedDiffTuples( - self, crashInfo: "CrashInfo" + self, crashInfo: CrashInfo ) -> list[tuple[str, str]]: diffTuples: list[tuple[str, str]] = [] diff --git a/FTB/Signatures/Matchers.py b/FTB/Signatures/Matchers.py index 
d740031f..ba15dd5a 100644 --- a/FTB/Signatures/Matchers.py +++ b/FTB/Signatures/Matchers.py @@ -16,7 +16,7 @@ import re from abc import ABCMeta, abstractmethod -from enum import StrEnum +from enum import Enum from typing import Any from FTB.Signatures import JSONHelper @@ -95,7 +95,8 @@ def __repr__(self) -> str: return self.value -class NumberMatchType(StrEnum): +# TODO: Python >= 3.11: Enum -> StrEnum +class NumberMatchType(str, Enum): EQ = "==" GE = ">=" GT = ">" diff --git a/FTB/Signatures/Symptom.py b/FTB/Signatures/Symptom.py index 813d86aa..df8f445b 100644 --- a/FTB/Signatures/Symptom.py +++ b/FTB/Signatures/Symptom.py @@ -14,6 +14,8 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import json from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Any @@ -40,7 +42,7 @@ def __str__(self) -> str: return self.jsonsrc @staticmethod - def fromJSONObject(obj: dict[str, Any]) -> "Symptom": + def fromJSONObject(obj: dict[str, Any]) -> Symptom: """ Create the appropriate Symptom based on the given object (decoded from JSON) @@ -72,7 +74,7 @@ def fromJSONObject(obj: dict[str, Any]) -> "Symptom": raise RuntimeError(f"Unknown symptom type: {stype}") @abstractmethod - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -105,7 +107,7 @@ def __init__(self, obj: dict[str, Any]) -> None: ): raise RuntimeError(f"Invalid source specified: {self.src}") - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -153,7 +155,7 @@ def __init__(self, obj: dict[str, Any]) -> None: # Default to 0 self.frameNumber = NumberMatch(0) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -184,7 +186,7 @@ def 
__init__(self, obj: dict[str, Any]) -> None: assert checked is not None self.stackSize = NumberMatch(checked) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -207,7 +209,7 @@ def __init__(self, obj: dict[str, Any]) -> None: assert checked is not None self.address = NumberMatch(checked) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -239,7 +241,7 @@ def __init__(self, obj: dict[str, Any]) -> None: "Must provide at least instruction name or register names" ) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -273,7 +275,7 @@ def __init__(self, obj: dict[str, Any]) -> None: assert checked is not None self.output = StringMatch(checked) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -306,7 +308,7 @@ def __init__(self, obj: dict[str, Any]) -> None: for fn in rawFunctionNames: self.functionNames.append(StringMatch(fn)) - def matches(self, crashInfo: "CrashInfo") -> bool: + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information @@ -320,8 +322,8 @@ def matches(self, crashInfo: "CrashInfo") -> bool: return StackFramesSymptom._match(crashInfo.backtrace, self.functionNames) def diff( - self, crashInfo: "CrashInfo" - ) -> tuple[int | None, "StackFramesSymptom | None"]: + self, crashInfo: CrashInfo + ) -> tuple[int | None, StackFramesSymptom | None]: if self.matches(crashInfo): return (0, None) diff --git a/Reporter/Reporter.py b/Reporter/Reporter.py index 6773ef63..f861a766 100644 --- a/Reporter/Reporter.py +++ b/Reporter/Reporter.py @@ -18,6 +18,8 @@ 
import platform import time from abc import ABC +from collections.abc import Callable +from typing import Any, Concatenate, ParamSpec, TypeVar import requests import requests.exceptions @@ -33,6 +35,10 @@ LOG = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T", bound="Reporter") + # Inheriting from RuntimeError because of legacy code. # All of these exceptions used to be RuntimeError. @@ -52,11 +58,13 @@ class InvalidDataError(ReporterException): """Reporter data validation failures.""" -def remote_checks(wrapped): +def remote_checks( + wrapped: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R]: """Decorator to perform error checks before using remote features""" @functools.wraps(wrapped) - def decorator(self, *args, **kwargs): + def decorator(self: T, *args: P.args, **kwargs: P.kwargs) -> R: if not self.serverHost: raise ConfigurationError( "Must specify serverHost (configuration property: serverhost) to use " @@ -74,14 +82,16 @@ def decorator(self, *args, **kwargs): ) return wrapped(self, *args, **kwargs) - return decorator + return decorator # type: ignore[return-value] -def signature_checks(wrapped): +def signature_checks( + wrapped: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R]: """Decorator to perform error checks before using signature features""" @functools.wraps(wrapped) - def decorator(self, *args, **kwargs): + def decorator(self: T, *args: P.args, **kwargs: P.kwargs) -> R: if not self.sigCacheDir: raise ConfigurationError( "Must specify sigCacheDir (configuration property: sigdir) to use " @@ -89,15 +99,17 @@ def decorator(self, *args, **kwargs): ) return wrapped(self, *args, **kwargs) - return decorator + return decorator # type: ignore[return-value] -def requests_retry(wrapped): +def requests_retry( + wrapped: Callable[..., requests.Response], +) -> Callable[..., requests.Response]: """Wrapper around requests methods that retries up to 2 minutes if it's likely that the response 
codes indicate a temporary error""" @functools.wraps(wrapped) - def wrapper(*args, **kwds): + def wrapper(*args: Any, **kwds: Any) -> requests.Response: success = kwds.pop("expected") # max_sleep is the upper limit for exponential backoff, # which begins at 2s and doubles each retry @@ -117,7 +129,7 @@ def wrapper(*args, **kwds): if response.status_code != success: # Allow for a total sleep time of up to 2 minutes if it's # likely that the response codes indicate a temporary error - retry_codes = [429, 500, 502, 503, 504] + retry_codes = (429, 500, 502, 503, 504) if response.status_code in retry_codes and current_timeout <= max_sleep: LOG.warning( "in %s, server returned %s, retrying...", @@ -136,7 +148,7 @@ def wrapper(*args, **kwds): return wrapper -def sentry_init(): +def sentry_init() -> None: if HAVE_SENTRY: sentry_fuzzing_config.init() @@ -144,14 +156,14 @@ def sentry_init(): class Reporter(ABC): def __init__( self, - sigCacheDir=None, - serverHost=None, - serverPort=None, - serverProtocol=None, - serverAuthToken=None, - clientId=None, - tool=None, - ): + sigCacheDir: str | None = None, + serverHost: str | None = None, + serverPort: int | None = None, + serverProtocol: str | None = None, + serverAuthToken: str | None = None, + clientId: str | None = None, + tool: str | None = None, + ) -> None: """ Initialize the Reporter. This constructor will also attempt to read a configuration file to populate any missing properties that have not @@ -229,7 +241,7 @@ def __init__( if self.serverHost is not None and self.clientId is None: self.clientId = platform.node() - def get(self, *args, **kwds): + def get(self, *args: Any, **kwds: Any) -> requests.Response: """requests.get, with added support for FuzzManager authentication and retry on 5xx errors. 
@@ -243,7 +255,7 @@ def get(self, *args, **kwds): ) return requests_retry(self._session.get)(*args, **kwds) - def post(self, *args, **kwds): + def post(self, *args: Any, **kwds: Any) -> requests.Response: """requests.post, with added support for FuzzManager authentication and retry on 5xx errors. @@ -257,7 +269,7 @@ def post(self, *args, **kwds): ) return requests_retry(self._session.post)(*args, **kwds) - def patch(self, *args, **kwds): + def patch(self, *args: Any, **kwds: Any) -> requests.Response: """requests.patch, with added support for FuzzManager authentication and retry on 5xx errors. diff --git a/pyproject.toml b/pyproject.toml index 03432c5e..718f6c49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ norecursedirs = [ ] [tool.mypy] -files = ["FTB"] +files = ["Collector", "FTB", "Reporter"] exclude = ['/tests/'] ignore_missing_imports = true show_error_codes = true @@ -56,6 +56,8 @@ select = [ "E", # Pyflakes "F", + # flake8-future-annotations + "FA", # Flynt "FLY", # isort From 6570c6b14a637ce93ed25b3b9810b19422ad875c Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Tue, 28 Apr 2026 12:48:11 -0700 Subject: [PATCH 3/3] ci: add py.typed files --- Collector/py.typed | 0 FTB/py.typed | 0 Reporter/py.typed | 0 setup.cfg | 5 +++++ 4 files changed, 5 insertions(+) create mode 100644 Collector/py.typed create mode 100644 FTB/py.typed create mode 100644 Reporter/py.typed diff --git a/Collector/py.typed b/Collector/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/FTB/py.typed b/FTB/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/Reporter/py.typed b/Reporter/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/setup.cfg b/setup.cfg index 3d46da9a..fd7bb6a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,11 @@ packages = TaskStatusReporter python_requires = >=3.10 +[options.package_data] +Collector = py.typed +FTB = py.typed +Reporter = py.typed + [options.entry_points] console_scripts = 
collector = Collector:Collector.main