From bedaa199cc2ffb209f0034695ac5924931a60d7f Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 10:25:32 +0200 Subject: [PATCH 1/8] Run fuzz tests on pull requests to main Co-Authored-By: Claude Opus 4.6 --- .github/workflows/fuzz.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index 08ff6e4..9f1ef52 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -3,6 +3,8 @@ name: Fuzz Testing on: push: branches: [main] + pull_request: + branches: [main] jobs: fuzz: From f9ab0b91b7e33bb40a39c17f25a17549f3b79d4a Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 11:26:06 +0200 Subject: [PATCH 2/8] Apply Python coding style across entire codebase - Single quotes for strings, double quotes for docstrings - Type hints on all function signatures - X | Y unions, X | None instead of Optional - Native list, dict, tuple, set typing - StrEnum where applicable (DecoderType, _JSType) - Expanded abbreviated variable names - Imports at top of files, never inside functions - Reduced nesting via early returns and helper extraction - match statements replacing long if/elif chains Co-Authored-By: Claude Opus 4.6 --- pyjsclear/__init__.py | 12 +- pyjsclear/__main__.py | 7 +- pyjsclear/deobfuscator.py | 24 +- pyjsclear/generator.py | 221 ++++---- pyjsclear/parser.py | 4 +- pyjsclear/scope.py | 55 +- pyjsclear/transforms/aa_decode.py | 24 +- pyjsclear/transforms/anti_tamper.py | 27 +- pyjsclear/transforms/base.py | 20 +- pyjsclear/transforms/class_static_resolver.py | 20 +- pyjsclear/transforms/class_string_decoder.py | 127 ++--- pyjsclear/transforms/cleanup.py | 87 +-- pyjsclear/transforms/constant_prop.py | 68 ++- pyjsclear/transforms/control_flow.py | 191 +++---- pyjsclear/transforms/dead_branch.py | 18 +- pyjsclear/transforms/dead_class_props.py | 107 ++-- pyjsclear/transforms/dead_expressions.py | 10 +- pyjsclear/transforms/dead_object_props.py | 58 +- pyjsclear/transforms/else_if_flatten.py | 34 +- pyjsclear/transforms/enum_resolver.py | 30 +- pyjsclear/transforms/eval_unpack.py | 55 +- pyjsclear/transforms/expression_simplifier.py | 174 +++--- pyjsclear/transforms/global_alias.py | 127 +++-- pyjsclear/transforms/hex_escapes.py | 23 +- pyjsclear/transforms/hex_numerics.py | 4 +- pyjsclear/transforms/jj_decode.py | 316 ++++++----- pyjsclear/transforms/jsfuck_decode.py | 491 ++++++++-------- pyjsclear/transforms/logical_to_if.py | 114 ++-- pyjsclear/transforms/member_chain_resolver.py | 124 ++-- pyjsclear/transforms/noop_calls.py | 42 +- pyjsclear/transforms/nullish_coalescing.py | 67 ++- pyjsclear/transforms/object_packer.py | 24 +- pyjsclear/transforms/object_simplifier.py | 76 +-- pyjsclear/transforms/optional_chaining.py | 87 +-- pyjsclear/transforms/property_simplifier.py | 8 +- pyjsclear/transforms/proxy_functions.py | 68 +-- pyjsclear/transforms/reassignment.py | 38 +- pyjsclear/transforms/require_inliner.py | 30 +- pyjsclear/transforms/sequence_splitter.py | 120 ++-- pyjsclear/transforms/single_use_vars.py | 43 +- pyjsclear/transforms/string_revealer.py | 178 +++--- pyjsclear/transforms/unreachable_code.py | 27 +- pyjsclear/transforms/unused_vars.py | 36 +- pyjsclear/transforms/variable_renamer.py | 211 ++++--- pyjsclear/transforms/xor_string_decode.py | 78 +-- pyjsclear/traverser.py | 59 +- pyjsclear/utils/ast_helpers.py | 72 +-- pyjsclear/utils/string_decoders.py | 37 +- tests/conftest.py | 4 +- tests/fuzz/conftest_fuzz.py | 493 ++++++++-------- tests/fuzz/fuzz_deobfuscate.py | 12 +- tests/fuzz/fuzz_expression_simplifier.py | 8 +- tests/fuzz/fuzz_generator.py | 15 +- tests/fuzz/fuzz_parser.py | 12 +- tests/fuzz/fuzz_scope.py | 10 +- tests/fuzz/fuzz_string_decoders.py | 104 ++-- tests/fuzz/fuzz_transforms.py | 6 +- tests/fuzz/fuzz_traverser.py | 104 ++-- tests/test_regression.py | 53 +- tests/unit/conftest.py | 8 +- tests/unit/deobfuscator_test.py | 11 +- tests/unit/generator_test.py | 17 +- tests/unit/transforms/aa_decode_test.py | 12 +- tests/unit/transforms/anti_tamper_test.py | 32 +- tests/unit/transforms/base_test.py | 50 +- .../transforms/class_string_decoder_test.py | 40 +- tests/unit/transforms/constant_prop_test.py | 21 +- tests/unit/transforms/control_flow_test.py | 154 ++--- .../transforms/expression_simplifier_test.py | 18 +- tests/unit/transforms/hex_numerics_test.py | 8 +- tests/unit/transforms/logical_to_if_test.py | 15 +- .../transforms/member_chain_resolver_test.py | 9 +- tests/unit/transforms/object_packer_test.py | 6 +- .../unit/transforms/object_simplifier_test.py | 20 +- .../unit/transforms/optional_chaining_test.py | 70 +-- .../transforms/property_simplifier_test.py | 22 +- tests/unit/transforms/proxy_functions_test.py | 75 +-- tests/unit/transforms/reassignment_test.py | 48 +- tests/unit/transforms/require_inliner_test.py | 8 +- .../unit/transforms/sequence_splitter_test.py | 75 ++- tests/unit/transforms/single_use_vars_test.py | 36 +- tests/unit/transforms/string_revealer_test.py | 535 +++++++++--------- .../unit/transforms/unreachable_code_test.py | 18 +- tests/unit/transforms/unused_vars_test.py | 74 ++- .../unit/transforms/variable_renamer_test.py | 62 +- .../unit/transforms/xor_string_decode_test.py | 26 +- tests/unit/utils/ast_helpers_test.py | 240 ++++---- tests/unit/utils/string_decoders_test.py | 102 ++-- 88 files changed, 3212 insertions(+), 3294 deletions(-) diff --git a/pyjsclear/__init__.py b/pyjsclear/__init__.py index 39945b4..4e0b478 100644 --- a/pyjsclear/__init__.py +++ b/pyjsclear/__init__.py @@ -11,7 +11,7 @@ __version__ = '0.1.3' -def deobfuscate(code, max_iterations=50): +def deobfuscate(code: str, max_iterations: int = 50) -> str: """Deobfuscate JavaScript code. Returns cleaned source. Args: @@ -24,7 +24,7 @@ def deobfuscate(code, max_iterations=50): return Deobfuscator(code, max_iterations=max_iterations).execute() -def deobfuscate_file(input_path, output_path=None, max_iterations=50): +def deobfuscate_file(input_path: str, output_path: str | None = None, max_iterations: int = 50) -> str | bool: """Deobfuscate a JavaScript file. Args: @@ -35,13 +35,13 @@ def deobfuscate_file(input_path, output_path=None, max_iterations=50): Returns: True if content changed (when output_path given), or the deobfuscated string. """ - with open(input_path, 'r', errors='replace') as f: - code = f.read() + with open(input_path, 'r', errors='replace') as input_file: + code = input_file.read() result = deobfuscate(code, max_iterations=max_iterations) if output_path: - with open(output_path, 'w') as f: - f.write(result) + with open(output_path, 'w') as output_file: + output_file.write(result) return result != code return result diff --git a/pyjsclear/__main__.py b/pyjsclear/__main__.py index faf2262..1671410 100644 --- a/pyjsclear/__main__.py +++ b/pyjsclear/__main__.py @@ -6,7 +6,7 @@ from . import deobfuscate -def main(): +def main() -> None: parser = argparse.ArgumentParser(description='Deobfuscate JavaScript files.') parser.add_argument('input', help='Input JS file (use - for stdin)') parser.add_argument('-o', '--output', help='Output file (default: stdout)') @@ -29,8 +29,9 @@ def main(): if args.output: with open(args.output, 'w') as output_file: output_file.write(result) - else: - sys.stdout.write(result) + return + + sys.stdout.write(result) if __name__ == '__main__': diff --git a/pyjsclear/deobfuscator.py b/pyjsclear/deobfuscator.py index 8636cad..8e645fb 100644 --- a/pyjsclear/deobfuscator.py +++ b/pyjsclear/deobfuscator.py @@ -106,26 +106,26 @@ _NODE_COUNT_LIMIT = 50_000 # Skip ControlFlowRecoverer above this -def _count_nodes(ast): +def _count_nodes(ast: dict) -> int: """Count total AST nodes.""" count = 0 - def cb(node, parent): + def increment_count(node: dict, parent: dict | None) -> None: nonlocal count count += 1 - simple_traverse(ast, cb) + simple_traverse(ast, increment_count) return count class Deobfuscator: """Multi-pass JavaScript deobfuscator.""" - def __init__(self, code, max_iterations=50): + def __init__(self, code: str, max_iterations: int = 50) -> None: self.original_code = code self.max_iterations = max_iterations - def _run_pre_passes(self, code): + def _run_pre_passes(self, code: str) -> str | None: """Run encoding detection and eval unpacking pre-passes. Returns decoded code if an encoding/packing was detected and decoded, @@ -160,7 +160,7 @@ def _run_pre_passes(self, code): # Maximum number of outer re-parse cycles (generate → re-parse → re-transform) _MAX_OUTER_CYCLES = 5 - def execute(self): + def execute(self) -> str: """Run all transforms and return cleaned source.""" code = self.original_code @@ -238,7 +238,7 @@ def execute(self): # but also recursive. Return best result so far. return previous_code - def _run_ast_transforms(self, ast, code_size=0): + def _run_ast_transforms(self, ast: dict, code_size: int = 0) -> bool: """Run all AST transform passes. Returns True if any transform changed the AST.""" node_count = _count_nodes(ast) if code_size > _LARGE_FILE_SIZE else 0 @@ -261,7 +261,7 @@ def _run_ast_transforms(self, ast, code_size=0): # Multi-pass transform loop any_transform_changed = False - for i in range(max_iterations): + for iteration in range(max_iterations): modified = False for transform_class in transform_classes: if transform_class in skip_transforms: @@ -274,11 +274,9 @@ def _run_ast_transforms(self, ast, code_size=0): if result: modified = True any_transform_changed = True - else: - # If a transform didn't change anything after the first pass, - # skip it in subsequent iterations - if i > 0: - skip_transforms.add(transform_class) + elif iteration > 0: + # Skip transforms that haven't changed anything after the first pass + skip_transforms.add(transform_class) if not modified: break diff --git a/pyjsclear/generator.py b/pyjsclear/generator.py index 71f0053..c680fd9 100644 --- a/pyjsclear/generator.py +++ b/pyjsclear/generator.py @@ -1,4 +1,5 @@ """ESTree AST to JavaScript code generator.""" +from __future__ import annotations # Operator precedence (higher = binds tighter) _PRECEDENCE = { @@ -63,7 +64,7 @@ ) -def generate(node, indent=0): +def generate(node: dict | None, indent: int = 0) -> str: """Generate JavaScript source from an ESTree AST node.""" if node is None: return '' @@ -77,11 +78,11 @@ def generate(node, indent=0): return f'/* unknown: {node_type} */' -def _indent_str(level): +def _indent_str(level: int) -> str: return ' ' * level -def _is_directive(stmt): +def _is_directive(stmt: dict) -> bool: """Check if a statement is a string-literal directive (like 'use strict').""" return ( stmt.get('type') == 'ExpressionStatement' @@ -91,10 +92,10 @@ def _is_directive(stmt): ) -def _gen_program(node, indent): +def _gen_program(node: dict, indent: int) -> str: parts = [] body = node.get('body', []) - for i, stmt in enumerate(body): + for index, stmt in enumerate(body): if stmt is None: continue if stmt.get('type') == 'EmptyStatement': @@ -102,12 +103,12 @@ def _gen_program(node, indent): statement_code = _gen_stmt(stmt, indent) if statement_code.strip(): parts.append(statement_code) - if _is_directive(stmt) and i + 1 < len(body): + if _is_directive(stmt) and index + 1 < len(body): parts.append('') return '\n'.join(parts) -def _gen_stmt(node, indent): +def _gen_stmt(node: dict | None, indent: int) -> str: """Generate a statement with indentation.""" if node is None: return '' @@ -121,20 +122,20 @@ def _gen_stmt(node, indent): return prefix + code + ';' -def _gen_block(node, indent): +def _gen_block(node: dict, indent: int) -> str: if not node.get('body'): return '{}' lines = ['{'] body = node.get('body', []) - for i, stmt in enumerate(body): + for index, stmt in enumerate(body): lines.append(_gen_stmt(stmt, indent + 1)) - if _is_directive(stmt) and i + 1 < len(body): + if _is_directive(stmt) and index + 1 < len(body): lines.append('') lines.append(_indent_str(indent) + '}') return '\n'.join(lines) -def _gen_var_declaration(node, indent): +def _gen_var_declaration(node: dict, indent: int) -> str: kind = node.get('kind', 'var') declarations = [] for declaration in node.get('declarations', []): @@ -147,10 +148,10 @@ def _gen_var_declaration(node, indent): return f'{kind} {", ".join(declarations)}' -def _gen_function(node, indent, is_expression=False): +def _gen_function(node: dict, indent: int, is_expression: bool = False) -> str: """Shared generator for FunctionDeclaration and FunctionExpression.""" name = generate(node['id'], indent) if node.get('id') else '' - params = ', '.join(generate(p, indent) for p in node.get('params', [])) + params = ', '.join(generate(param, indent) for param in node.get('params', [])) async_prefix = 'async ' if node.get('async') else '' gen_prefix = '*' if node.get('generator') else '' body = generate(node['body'], indent) @@ -160,18 +161,18 @@ def _gen_function(node, indent, is_expression=False): return f'{async_prefix}function{gen_prefix} ({params}) {body}' -def _gen_function_decl(node, indent): +def _gen_function_decl(node: dict, indent: int) -> str: return _gen_function(node, indent) -def _gen_function_expr(node, indent): +def _gen_function_expr(node: dict, indent: int) -> str: return _gen_function(node, indent, is_expression=True) -def _gen_arrow(node, indent): +def _gen_arrow(node: dict, indent: int) -> str: params = node.get('params', []) async_prefix = 'async ' if node.get('async') else '' - param_str = '(' + ', '.join(generate(p, indent) for p in params) + ')' + param_str = '(' + ', '.join(generate(param, indent) for param in params) + ')' body = node.get('body', {}) body_str = generate(body, indent) # Wrap object literal in parens to avoid ambiguity with block @@ -180,14 +181,14 @@ def _gen_arrow(node, indent): return f'{async_prefix}{param_str} => {body_str}' -def _gen_return(node, indent): +def _gen_return(node: dict, indent: int) -> str: argument = node.get('argument') if argument: return f'return {generate(argument, indent)}' return 'return' -def _gen_if(node, indent): +def _gen_if(node: dict, indent: int) -> str: test = generate(node['test'], indent) consequent_code = generate(node['consequent'], indent) if node['consequent'].get('type') != 'BlockStatement': @@ -202,19 +203,19 @@ def _gen_if(node, indent): return f'if ({test}) {consequent_code}' -def _gen_while(node, indent): +def _gen_while(node: dict, indent: int) -> str: test = generate(node['test'], indent) body = generate(node['body'], indent) return f'while ({test}) {body}' -def _gen_do_while(node, indent): +def _gen_do_while(node: dict, indent: int) -> str: body = generate(node['body'], indent) test = generate(node['test'], indent) return f'do {body} while ({test})' -def _gen_for(node, indent): +def _gen_for(node: dict, indent: int) -> str: init = '' if node.get('init'): init = generate(node['init'], indent) @@ -224,21 +225,21 @@ def _gen_for(node, indent): return f'for ({init}; {test}; {update}) {body}' -def _gen_for_in(node, indent): +def _gen_for_in(node: dict, indent: int) -> str: left = generate(node['left'], indent) right = generate(node['right'], indent) body = generate(node['body'], indent) return f'for ({left} in {right}) {body}' -def _gen_for_of(node, indent): +def _gen_for_of(node: dict, indent: int) -> str: left = generate(node['left'], indent) right = generate(node['right'], indent) body = generate(node['body'], indent) return f'for ({left} of {right}) {body}' -def _gen_switch(node, indent): +def _gen_switch(node: dict, indent: int) -> str: discriminant = generate(node['discriminant'], indent) lines = [f'switch ({discriminant}) {{'] for case in node.get('cases', []): @@ -252,15 +253,15 @@ def _gen_switch(node, indent): return '\n'.join(lines) -def _gen_try(node, indent): +def _gen_try(node: dict, indent: int) -> str: block = generate(node['block'], indent) result = f'try {block}' handler = node.get('handler') if handler: - param = generate(handler.get('param'), indent) if handler.get('param') else '' + catch_param = generate(handler.get('param'), indent) if handler.get('param') else '' handler_body = generate(handler['body'], indent) - if param: - result += f' catch ({param}) {handler_body}' + if catch_param: + result += f' catch ({catch_param}) {handler_body}' else: result += f' catch {handler_body}' finalizer = node.get('finalizer') @@ -269,33 +270,33 @@ def _gen_try(node, indent): return result -def _gen_throw(node, indent): +def _gen_throw(node: dict, indent: int) -> str: return f'throw {generate(node["argument"], indent)}' -def _gen_break(node, indent): +def _gen_break(node: dict, indent: int) -> str: if node.get('label'): return f'break {generate(node["label"], indent)}' return 'break' -def _gen_continue(node, indent): +def _gen_continue(node: dict, indent: int) -> str: if node.get('label'): return f'continue {generate(node["label"], indent)}' return 'continue' -def _gen_labeled(node, indent): +def _gen_labeled(node: dict, indent: int) -> str: label = generate(node['label'], indent) body = _gen_stmt(node['body'], indent) return f'{label}:\n{body}' -def _gen_expr_stmt(node, indent): +def _gen_expr_stmt(node: dict, indent: int) -> str: return generate(node['expression'], indent) -def _gen_binary(node, indent): +def _gen_binary(node: dict, indent: int) -> str: operator = node.get('operator', '') left = generate(node['left'], indent) right = generate(node['right'], indent) @@ -309,11 +310,11 @@ def _gen_binary(node, indent): return f'{left} {operator} {right}' -def _gen_logical(node, indent): +def _gen_logical(node: dict, indent: int) -> str: return _gen_binary(node, indent) -def _gen_unary(node, indent): +def _gen_unary(node: dict, indent: int) -> str: operator = node.get('operator', '') operand = generate(node['argument'], indent) operand_prec = _expr_precedence(node['argument']) @@ -326,7 +327,7 @@ def _gen_unary(node, indent): return f'{operand}{operator}' -def _gen_update(node, indent): +def _gen_update(node: dict, indent: int) -> str: argument = generate(node['argument'], indent) operator = node.get('operator', '++') if node.get('prefix'): @@ -334,14 +335,14 @@ def _gen_update(node, indent): return f'{argument}{operator}' -def _gen_assignment(node, indent): +def _gen_assignment(node: dict, indent: int) -> str: left = generate(node['left'], indent) right = generate(node['right'], indent) operator = node.get('operator', '=') return f'{left} {operator} {right}' -def _gen_member(node, indent): +def _gen_member(node: dict, indent: int) -> str: object_code = generate(node['object'], indent) obj_type = node['object'].get('type', '') computed = node.get('computed') @@ -371,56 +372,56 @@ def _gen_member(node, indent): return f'{object_code}{dot}{property_code}' -def _gen_call(node, indent): +def _gen_call(node: dict, indent: int) -> str: callee = generate(node['callee'], indent) callee_type = node['callee'].get('type', '') if callee_type in ('FunctionExpression', 'ArrowFunctionExpression', 'SequenceExpression'): callee = f'({callee})' - args = ', '.join(generate(a, indent) for a in node.get('arguments', [])) + args = ', '.join(generate(argument, indent) for argument in node.get('arguments', [])) if node.get('optional'): return f'{callee}?.({args})' return f'{callee}({args})' -def _gen_new(node, indent): +def _gen_new(node: dict, indent: int) -> str: callee = generate(node['callee'], indent) args = node.get('arguments', []) if args: - arg_str = ', '.join(generate(a, indent) for a in args) + arg_str = ', '.join(generate(argument, indent) for argument in args) return f'new {callee}({arg_str})' return f'new {callee}()' -def _wrap_if_sequence(node, code): +def _wrap_if_sequence(node: dict | None, code: str) -> str: """Wrap code in parens if node is a SequenceExpression.""" if isinstance(node, dict) and node.get('type') == 'SequenceExpression': return f'({code})' return code -def _gen_conditional(node, indent): +def _gen_conditional(node: dict, indent: int) -> str: test = generate(node['test'], indent) consequent_code = _wrap_if_sequence(node.get('consequent'), generate(node['consequent'], indent)) alternate_code = _wrap_if_sequence(node.get('alternate'), generate(node['alternate'], indent)) return f'{test} ? {consequent_code} : {alternate_code}' -def _gen_sequence(node, indent): - exprs = ', '.join(generate(e, indent) for e in node.get('expressions', [])) +def _gen_sequence(node: dict, indent: int) -> str: + exprs = ', '.join(generate(expression, indent) for expression in node.get('expressions', [])) return exprs -def _gen_bracket_list(elements, indent): +def _gen_bracket_list(elements: list, indent: int) -> str: """Generate a bracketed list of elements, replacing None with empty slots.""" - elems = [generate(e, indent) if e is not None else '' for e in elements] + elems = [generate(element, indent) if element is not None else '' for element in elements] return '[' + ', '.join(elems) + ']' -def _gen_array(node, indent): +def _gen_array(node: dict, indent: int) -> str: return _gen_bracket_list(node.get('elements', []), indent) -def _gen_object_property(property_node, indent): +def _gen_object_property(property_node: dict, indent: int) -> str: """Generate a single object property string.""" if property_node.get('type') == 'SpreadElement': return '...' + generate(property_node['argument'], indent) @@ -432,7 +433,7 @@ def _gen_object_property(property_node, indent): kind = property_node.get('kind', 'init') if kind in ('get', 'set') or property_node.get('method'): prefix = f'{kind} ' if kind in ('get', 'set') else '' - params = ', '.join(generate(pp, indent) for pp in property_node['value'].get('params', [])) + params = ', '.join(generate(param, indent) for param in property_node['value'].get('params', [])) body = generate(property_node['value'].get('body'), indent) return f'{prefix}{key}({params}) {body}' @@ -443,7 +444,7 @@ def _gen_object_property(property_node, indent): return f'{key}: {value}' -def _gen_object(node, indent): +def _gen_object(node: dict, indent: int) -> str: properties = node.get('properties', []) if not properties: return '{}' @@ -454,23 +455,23 @@ def _gen_object(node, indent): return '{\n' + lines + '\n' + outer_indent + '}' -def _gen_property(node, indent): +def _gen_property(node: dict, indent: int) -> str: key = generate(node['key'], indent) value = generate(node['value'], indent) return f'{key}: {value}' -def _gen_spread(node, indent): +def _gen_spread(node: dict, indent: int) -> str: return '...' + generate(node['argument'], indent) -def _escape_string(val, raw): +def _escape_string(string_value: str, raw: str | None) -> str: """Escape a string value and wrap in the appropriate quotes.""" if raw and len(raw) >= 2 and raw[0] in ('"', "'"): quote = raw[0] else: quote = '"' - escaped = val.replace('\\', '\\\\') + escaped = string_value.replace('\\', '\\\\') escaped = escaped.replace('\n', '\\n') escaped = escaped.replace('\r', '\\r') escaped = escaped.replace('\t', '\\t') @@ -478,7 +479,7 @@ def _escape_string(val, raw): return f'{quote}{escaped}{quote}' -def _gen_literal(node, indent): +def _gen_literal(node: dict, indent: int) -> str: raw = node.get('raw') value = node.get('value') if isinstance(value, str): @@ -498,37 +499,37 @@ def _gen_literal(node, indent): return str(value) -def _gen_identifier(node, indent): +def _gen_identifier(node: dict, indent: int) -> str: return node.get('name', '') -def _gen_this(node, indent): +def _gen_this(node: dict, indent: int) -> str: return 'this' -def _gen_empty(node, indent): +def _gen_empty(node: dict, indent: int) -> str: return ';' -def _gen_template_literal(node, indent): +def _gen_template_literal(node: dict, indent: int) -> str: quasis = node.get('quasis', []) - exprs = node.get('expressions', []) + expressions = node.get('expressions', []) parts = [] - for i, quasi in enumerate(quasis): + for index, quasi in enumerate(quasis): raw = quasi.get('value', {}).get('raw', '') parts.append(raw) - if i < len(exprs): - parts.append('${' + generate(exprs[i], indent) + '}') + if index < len(expressions): + parts.append('${' + generate(expressions[index], indent) + '}') return '`' + ''.join(parts) + '`' -def _gen_tagged_template(node, indent): +def _gen_tagged_template(node: dict, indent: int) -> str: tag = generate(node['tag'], indent) quasi = generate(node['quasi'], indent) return f'{tag}{quasi}' -def _gen_class_decl(node, indent): +def _gen_class_decl(node: dict, indent: int) -> str: name = generate(node['id'], indent) if node.get('id') else '' superclass_clause = '' if node.get('superClass'): @@ -539,7 +540,7 @@ def _gen_class_decl(node, indent): return f'class{superclass_clause} {body}' -def _gen_class_body(node, indent): +def _gen_class_body(node: dict, indent: int) -> str: if not node.get('body'): return '{}' lines = ['{'] @@ -549,29 +550,29 @@ def _gen_class_body(node, indent): return '\n'.join(lines) -def _gen_method_def(node, indent): +def _gen_method_def(node: dict, indent: int) -> str: key = generate(node['key'], indent) if node.get('computed') or node['key'].get('type') == 'Literal': key = f'[{key}]' - static = 'static ' if node.get('static') else '' + static_prefix = 'static ' if node.get('static') else '' kind = node.get('kind', 'method') value = node.get('value', {}) - params = ', '.join(generate(p, indent) for p in value.get('params', [])) + params = ', '.join(generate(param, indent) for param in value.get('params', [])) body = generate(value.get('body'), indent) match kind: case 'constructor': - return f'{static}constructor({params}) {body}' + return f'{static_prefix}constructor({params}) {body}' case 'get': - return f'{static}get {key}({params}) {body}' + return f'{static_prefix}get {key}({params}) {body}' case 'set': - return f'{static}set {key}({params}) {body}' + return f'{static_prefix}set {key}({params}) {body}' case _: async_prefix = 'async ' if value.get('async') else '' gen_prefix = '*' if value.get('generator') else '' - return f'{static}{async_prefix}{gen_prefix}{key}({params}) {body}' + return f'{static_prefix}{async_prefix}{gen_prefix}{key}({params}) {body}' -def _gen_yield(node, indent): +def _gen_yield(node: dict, indent: int) -> str: argument = generate(node.get('argument'), indent) if node.get('argument') else '' delegate = '*' if node.get('delegate') else '' if argument: @@ -579,21 +580,21 @@ def _gen_yield(node, indent): return f'yield{delegate}' -def _gen_await(node, indent): +def _gen_await(node: dict, indent: int) -> str: return f'await {generate(node["argument"], indent)}' -def _gen_assignment_pattern(node, indent): +def _gen_assignment_pattern(node: dict, indent: int) -> str: left = generate(node['left'], indent) right = generate(node['right'], indent) return f'{left} = {right}' -def _gen_array_pattern(node, indent): +def _gen_array_pattern(node: dict, indent: int) -> str: return _gen_bracket_list(node.get('elements', []), indent) -def _gen_object_pattern_part(property_node, indent): +def _gen_object_pattern_part(property_node: dict, indent: int) -> str: """Generate a single destructuring pattern property.""" if property_node.get('type') == 'RestElement': return '...' + generate(property_node['argument'], indent) @@ -604,7 +605,7 @@ def _gen_object_pattern_part(property_node, indent): return f'{key}: {value}' -def _gen_object_pattern(node, indent): +def _gen_object_pattern(node: dict, indent: int) -> str: properties = [_gen_object_pattern_part(property_node, indent + 1) for property_node in node.get('properties', [])] if not properties: return '{}' @@ -614,75 +615,75 @@ def _gen_object_pattern(node, indent): return '{\n' + lines + '\n' + outer_indent + '}' -def _gen_rest_element(node, indent): +def _gen_rest_element(node: dict, indent: int) -> str: return '...' + generate(node['argument'], indent) -def _gen_import_specifier(spec, indent): +def _gen_import_specifier(specifier: dict, indent: int) -> str: """Generate a single import specifier.""" - spec_type = spec.get('type', '') - if spec_type == 'ImportDefaultSpecifier': - return generate(spec['local'], indent) - if spec_type == 'ImportNamespaceSpecifier': - return '* as ' + generate(spec['local'], indent) + specifier_type = specifier.get('type', '') + if specifier_type == 'ImportDefaultSpecifier': + return generate(specifier['local'], indent) + if specifier_type == 'ImportNamespaceSpecifier': + return '* as ' + generate(specifier['local'], indent) # ImportSpecifier - imported = generate(spec['imported'], indent) - local = generate(spec['local'], indent) + imported = generate(specifier['imported'], indent) + local = generate(specifier['local'], indent) if imported == local: return imported return f'{imported} as {local}' -def _gen_import_declaration(node, indent): +def _gen_import_declaration(node: dict, indent: int) -> str: source = generate(node['source'], indent) specifiers = node.get('specifiers', []) if not specifiers: return f'import {source}' - default_specs = [s for s in specifiers if s.get('type') == 'ImportDefaultSpecifier'] - namespace_specs = [s for s in specifiers if s.get('type') == 'ImportNamespaceSpecifier'] - named_specs = [s for s in specifiers if s.get('type') == 'ImportSpecifier'] + default_specifiers = [s for s in specifiers if s.get('type') == 'ImportDefaultSpecifier'] + namespace_specifiers = [s for s in specifiers if s.get('type') == 'ImportNamespaceSpecifier'] + named_specifiers = [s for s in specifiers if s.get('type') == 'ImportSpecifier'] parts = [] - if default_specs: - parts.append(_gen_import_specifier(default_specs[0], indent)) - if namespace_specs: - parts.append(_gen_import_specifier(namespace_specs[0], indent)) - if named_specs: - names = ', '.join(_gen_import_specifier(s, indent) for s in named_specs) + if default_specifiers: + parts.append(_gen_import_specifier(default_specifiers[0], indent)) + if namespace_specifiers: + parts.append(_gen_import_specifier(namespace_specifiers[0], indent)) + if named_specifiers: + names = ', '.join(_gen_import_specifier(specifier, indent) for specifier in named_specifiers) parts.append('{' + names + '}') return f'import {", ".join(parts)} from {source}' -def _gen_export_specifier(spec, indent): - exported = generate(spec['exported'], indent) - local = generate(spec['local'], indent) +def _gen_export_specifier(specifier: dict, indent: int) -> str: + exported = generate(specifier['exported'], indent) + local = generate(specifier['local'], indent) if exported == local: return exported return f'{local} as {exported}' -def _gen_export_named(node, indent): +def _gen_export_named(node: dict, indent: int) -> str: declaration = node.get('declaration') if declaration: return f'export {generate(declaration, indent)}' specifiers = node.get('specifiers', []) - names = ', '.join(_gen_export_specifier(s, indent) for s in specifiers) + names = ', '.join(_gen_export_specifier(specifier, indent) for specifier in specifiers) source = node.get('source') if source: return f'export {{{names}}} from {generate(source, indent)}' return f'export {{{names}}}' -def _gen_export_default(node, indent): +def _gen_export_default(node: dict, indent: int) -> str: declaration = node.get('declaration', {}) return f'export default {generate(declaration, indent)}' -def _gen_export_all(node, indent): +def _gen_export_all(node: dict, indent: int) -> str: source = generate(node['source'], indent) return f'export * from {source}' -def _expr_precedence(node): +def _expr_precedence(node: dict) -> int: """Get the precedence level of an expression node.""" if not isinstance(node, dict): return 20 diff --git a/pyjsclear/parser.py b/pyjsclear/parser.py index 049a87f..f794e72 100644 --- a/pyjsclear/parser.py +++ b/pyjsclear/parser.py @@ -8,7 +8,7 @@ _ASYNC_MAP = {'isAsync': 'async', 'allowAwait': 'await'} -def _fast_to_dict(obj): +def _fast_to_dict(obj: object) -> object: """Convert esprima AST objects to plain dicts, ~2x faster than toDict().""" if isinstance(obj, (str, int, float, bool, type(None))): return obj @@ -29,7 +29,7 @@ def _fast_to_dict(obj): return output -def parse(code): +def parse(code: str) -> dict: """Parse JavaScript code into an ESTree-compatible AST. Returns a Program node (dict). diff --git a/pyjsclear/scope.py b/pyjsclear/scope.py index f4ee9f5..7b97a3a 100644 --- a/pyjsclear/scope.py +++ b/pyjsclear/scope.py @@ -1,5 +1,8 @@ """Variable scope and binding analysis for ESTree ASTs.""" +from collections.abc import Callable +from typing import Any + from .utils.ast_helpers import _CHILD_KEYS from .utils.ast_helpers import get_child_keys @@ -9,16 +12,16 @@ class Binding: __slots__ = ('name', 'node', 'kind', 'scope', 'references', 'assignments') - def __init__(self, name, node, kind, scope): + def __init__(self, name: str, node: dict, kind: str, scope: 'Scope') -> None: self.name = name self.node = node # The declaration node self.kind = kind # 'var', 'let', 'const', 'function', 'param' self.scope = scope - self.references = [] # List of (node, parent, key, index) where name is referenced - self.assignments = [] # List of assignment nodes + self.references: list = [] # List of (node, parent, key, index) where name is referenced + self.assignments: list = [] # List of assignment nodes @property - def is_constant(self): + def is_constant(self) -> bool: """True if the binding is never reassigned after declaration.""" if self.kind == 'const': return True @@ -33,21 +36,21 @@ class Scope: __slots__ = ('parent', 'node', 'bindings', 'children', 'is_function') - def __init__(self, parent, node, is_function=False): + def __init__(self, parent: 'Scope | None', node: dict, is_function: bool = False) -> None: self.parent = parent self.node = node - self.bindings = {} # name -> Binding - self.children = [] + self.bindings: dict[str, Binding] = {} # name -> Binding + self.children: list['Scope'] = [] self.is_function = is_function if parent: parent.children.append(self) - def add_binding(self, name, node, kind): + def add_binding(self, name: str, node: dict, kind: str) -> Binding: binding = Binding(name, node, kind, self) self.bindings[name] = binding return binding - def get_binding(self, name): + def get_binding(self, name: str) -> Binding | None: """Look up a binding, walking up the scope chain.""" if name in self.bindings: return self.bindings[name] @@ -55,18 +58,18 @@ def get_binding(self, name): return self.parent.get_binding(name) return None - def get_own_binding(self, name): + def get_own_binding(self, name: str) -> Binding | None: return self.bindings.get(name) -def _nearest_function_scope(scope): +def _nearest_function_scope(scope: Scope | None) -> Scope | None: """Walk up to the nearest function (or root) scope.""" while scope and not scope.is_function: scope = scope.parent return scope -def _is_non_reference_identifier(parent, parent_key): +def _is_non_reference_identifier(parent: dict | None, parent_key: str | None) -> bool: """Return True if this Identifier usage is not a variable reference.""" if not parent: return False @@ -84,7 +87,9 @@ def _is_non_reference_identifier(parent, parent_key): return False -def _recurse_into_children(node, child_keys_map, callback): +def _recurse_into_children( + node: dict, child_keys_map: dict, callback: Callable[[dict], Any] +) -> None: """Walk child nodes, calling callback(child_node) for each dict with 'type'.""" node_type = node.get('type') child_keys = child_keys_map.get(node_type) @@ -102,18 +107,18 @@ def _recurse_into_children(node, child_keys_map, callback): callback(child) -def build_scope_tree(ast): +def build_scope_tree(ast: dict) -> tuple[Scope, dict[int, Scope]]: """Build a scope tree from an AST, collecting bindings and references. Returns the root Scope and a dict mapping node id -> Scope. """ root_scope = Scope(None, ast, is_function=True) # Maps id(node) -> scope for function/block scope nodes - node_scope = {id(ast): root_scope} + node_scope: dict[int, Scope] = {id(ast): root_scope} # We need to collect all declarations first, then references - all_scopes = [root_scope] + all_scopes: list[Scope] = [root_scope] - def _get_scope_for(node, current_scope): + def _get_scope_for(node: dict, current_scope: Scope) -> Scope: """Get or create the scope for a node.""" node_id = id(node) if node_id in node_scope: @@ -122,7 +127,7 @@ def _get_scope_for(node, current_scope): _child_keys_map = _CHILD_KEYS - def _collect_declarations(node, scope): + def _collect_declarations(node: dict, scope: Scope) -> None: """Walk the AST collecting variable declarations into scopes.""" if not isinstance(node, dict): return @@ -229,7 +234,7 @@ def _collect_declarations(node, scope): node, _child_keys_map, lambda child_node: _collect_declarations(child_node, scope) ) - def _collect_pattern_names(pattern, scope, kind, declaration): + def _collect_pattern_names(pattern: dict | None, scope: Scope, kind: str, declaration: dict) -> None: """Collect binding names from destructuring patterns.""" if not isinstance(pattern, dict): return @@ -263,7 +268,13 @@ def _collect_pattern_names(pattern, scope, kind, declaration): _collect_declarations(ast, root_scope) # Second pass: collect references and assignments - def _collect_references(node, scope, parent=None, parent_key=None, parent_index=None): + def _collect_references( + node: dict, + scope: Scope, + parent: dict | None = None, + parent_key: str | None = None, + parent_index: int | None = None, + ) -> None: if not isinstance(node, dict): return node_type = node.get('type') @@ -298,9 +309,9 @@ def _collect_references(node, scope, parent=None, parent_key=None, parent_index= if child is None: continue if isinstance(child, list): - for i, item in enumerate(child): + for child_index, item in enumerate(child): if isinstance(item, dict) and 'type' in item: - _collect_references(item, scope, node, key, i) + _collect_references(item, scope, node, key, child_index) elif isinstance(child, dict) and 'type' in child: _collect_references(child, scope, node, key, None) diff --git a/pyjsclear/transforms/aa_decode.py b/pyjsclear/transforms/aa_decode.py index 6f8d042..cec51a3 100644 --- a/pyjsclear/transforms/aa_decode.py +++ b/pyjsclear/transforms/aa_decode.py @@ -42,7 +42,7 @@ ] -def is_aa_encoded(code): +def is_aa_encoded(code: str) -> bool: """Check if *code* looks like AAEncoded JavaScript. Returns True when the characteristic execution pattern is found. @@ -52,7 +52,7 @@ def is_aa_encoded(code): return _SIGNATURE in code -def aa_decode(code): +def aa_decode(code: str) -> str | None: """Decode AAEncoded JavaScript. Returns the decoded source string, or ``None`` on any failure. @@ -72,7 +72,7 @@ def aa_decode(code): # --------------------------------------------------------------------------- -def _decode_impl(code): +def _decode_impl(code: str) -> str | None: """Core decoding logic.""" # 1. Isolate the data section. # AAEncode wraps data inside an execution pattern. The encoded payload @@ -84,8 +84,8 @@ def _decode_impl(code): # Find the data region: everything after the initial variable setup and # before the trailing execution portion. # The data starts at the first separator token. - sep_idx = code.find(_SEPARATOR) - if sep_idx == -1: + separator_index = code.find(_SEPARATOR) + if separator_index == -1: return None # The trailing execution wrapper varies but typically looks like: @@ -95,16 +95,16 @@ def _decode_impl(code): "(\uff9f\u0414\uff9f)['_']", '(\uff9f\u0414\uff9f)["_"]', ] - data = code[sep_idx:] - for pat in tail_patterns: - tail_pos = data.rfind(pat) - if tail_pos != -1: - data = data[:tail_pos] + data = code[separator_index:] + for tail_pattern in tail_patterns: + tail_position = data.rfind(tail_pattern) + if tail_position != -1: + data = data[:tail_position] break # 2. Apply emoticon-to-digit replacements. - for old, new in _REPLACEMENTS: - data = data.replace(old, new) + for original, replacement in _REPLACEMENTS: + data = data.replace(original, replacement) # 3. Split on the separator to get individual character segments. segments = data.split(_SEPARATOR) diff --git a/pyjsclear/transforms/anti_tamper.py b/pyjsclear/transforms/anti_tamper.py index 7781461..13d027d 100644 --- a/pyjsclear/transforms/anti_tamper.py +++ b/pyjsclear/transforms/anti_tamper.py @@ -38,7 +38,7 @@ class AntiTamperRemover(Transform): ] @staticmethod - def _extract_iife_call(expr): + def _extract_iife_call(expr: dict) -> dict | None: """Extract a CallExpression from an IIFE pattern.""" if expr.get('type') == 'CallExpression': return expr @@ -46,23 +46,20 @@ def _extract_iife_call(expr): return expr.get('argument') return None - def _matches_anti_tamper_pattern(self, src): + def _matches_anti_tamper_pattern(self, source: str) -> bool: """Check if source matches any anti-tamper pattern.""" - for pattern in self._SELF_DEFENDING_PATTERNS: - if pattern.search(src): - return True - if any(p.search(src) for p in self._DEBUG_PATTERNS): - if re.search(r'\bdebugger\b', src) and (re.search(r'\bwhile\b|\bfor\b|\bsetInterval\b', src)): - return True - for pattern in self._CONSOLE_PATTERNS: - if pattern.search(src): - return True + if any(pattern.search(source) for pattern in self._SELF_DEFENDING_PATTERNS): + return True + if re.search(r'\bdebugger\b', source) and re.search(r'\bwhile\b|\bfor\b|\bsetInterval\b', source): + return True + if any(pattern.search(source) for pattern in self._CONSOLE_PATTERNS): + return True return False - def execute(self): + def execute(self) -> bool: nodes_to_remove = [] - def enter(node, parent, key, index): + def enter(node: dict, parent: dict, key: str, index: int | None) -> None: if node.get('type') != 'ExpressionStatement': return expr = node.get('expression') @@ -94,9 +91,9 @@ def enter(node, parent, key, index): # Remove flagged nodes if nodes_to_remove: - remove_set = set(id(n) for n in nodes_to_remove) + remove_set = {id(node) for node in nodes_to_remove} - def remover(node, parent, key, index): + def remover(node: dict, parent: dict, key: str, index: int | None) -> object | None: if id(node) in remove_set: self.set_changed() return REMOVE diff --git a/pyjsclear/transforms/base.py b/pyjsclear/transforms/base.py index 35219dc..594b264 100644 --- a/pyjsclear/transforms/base.py +++ b/pyjsclear/transforms/base.py @@ -1,5 +1,12 @@ """Base transform class.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..scope import Scope + class Transform: """Base class for all AST transforms.""" @@ -7,18 +14,23 @@ class Transform: # Subclasses can set this to True to trigger scope rebuild after execution rebuild_scope = False - def __init__(self, ast, scope_tree=None, node_scope=None): + def __init__( + self, + ast: dict, + scope_tree: Scope | None = None, + node_scope: dict[int, Scope] | None = None, + ) -> None: self.ast = ast self.scope_tree = scope_tree self.node_scope = node_scope self._changed = False - def execute(self): + def execute(self) -> bool: """Execute the transform. Returns True if the AST was modified.""" raise NotImplementedError - def set_changed(self): + def set_changed(self) -> None: self._changed = True - def has_changed(self): + def has_changed(self) -> bool: return self._changed diff --git a/pyjsclear/transforms/class_static_resolver.py b/pyjsclear/transforms/class_static_resolver.py index 78bca39..75cb1ce 100644 --- a/pyjsclear/transforms/class_static_resolver.py +++ b/pyjsclear/transforms/class_static_resolver.py @@ -24,7 +24,7 @@ class ClassStaticResolver(Transform): """Inline class static constant properties and identity methods.""" - def execute(self): + def execute(self) -> bool: # Step 1: Find class variables (var X = class { ... }) class_vars = {} # name -> ClassExpression node @@ -146,7 +146,7 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) return self.has_changed() - def _get_prop_name(self, member_expr): + def _get_prop_name(self, member_expr: dict) -> str | None: """Get the property name from a MemberExpression.""" prop = member_expr.get('property') if not prop: @@ -159,7 +159,7 @@ def _get_prop_name(self, member_expr): return prop['name'] return None - def _is_identity_function(self, func_node): + def _is_identity_function(self, func_node: dict) -> bool: """Check if a function simply returns its first argument.""" params = func_node.get('params', []) if len(params) != 1: @@ -173,12 +173,12 @@ def _is_identity_function(self, func_node): stmts = body.get('body', []) if len(stmts) != 1 or stmts[0].get('type') != 'ReturnStatement': return False - arg = stmts[0].get('argument') - if not arg or not is_identifier(arg): + return_argument = stmts[0].get('argument') + if not return_argument or not is_identifier(return_argument): return False - return arg['name'] == param['name'] + return return_argument['name'] == param['name'] - def _try_inline_identity(self, member_expr, method_node): + def _try_inline_identity(self, member_expr: dict, method_node: dict) -> None: """Inline Class.identity(arg) → arg.""" result = find_parent(self.ast, member_expr) if not result: @@ -194,11 +194,11 @@ def _try_inline_identity(self, member_expr, method_node): grandparent_result = find_parent(self.ast, call_parent) if not grandparent_result: return - gp, gp_key, gp_index = grandparent_result - self._replace_in_parent(call_parent, replacement, gp, gp_key, gp_index) + grandparent, grandparent_key, grandparent_index = grandparent_result + self._replace_in_parent(call_parent, replacement, grandparent, grandparent_key, grandparent_index) self.set_changed() - def _replace_in_parent(self, target, replacement, parent, key, index): + def _replace_in_parent(self, target: dict, replacement: dict, parent: dict, key: str, index: int | None) -> None: """Replace target node in the AST using known parent info.""" if index is not None: parent[key][index] = replacement diff --git a/pyjsclear/transforms/class_string_decoder.py b/pyjsclear/transforms/class_string_decoder.py index 42e9961..5061b34 100644 --- a/pyjsclear/transforms/class_string_decoder.py +++ b/pyjsclear/transforms/class_string_decoder.py @@ -33,9 +33,9 @@ class ClassStringDecoder(Transform): """Resolve class-based string encoder patterns.""" - def execute(self): - class_props = {} - decoders = {} + def execute(self) -> bool: + class_props: dict = {} + decoders: dict = {} self._collect_class_props(class_props) self._find_decoders(class_props, decoders) @@ -48,7 +48,7 @@ def execute(self): self._resolve_calls(decoders) return self.has_changed() - def _collect_class_props(self, class_props): + def _collect_class_props(self, class_props: dict) -> None: """Collect static property assignments on class variables. Builds: class_props[var_name] = {prop_name: value, ...} @@ -56,7 +56,7 @@ def _collect_class_props(self, class_props): Handles assignments in ExpressionStatements and SequenceExpressions. """ - def visit(node, parent): + def visit(node: dict, parent: dict) -> None: if node.get('type') != 'AssignmentExpression': return if node.get('operator') != '=': @@ -78,24 +78,24 @@ def visit(node, parent): simple_traverse(self.ast, visit) - def _resolve_array(self, class_props, var_name, elements): + def _resolve_array(self, class_props: dict, var_name: str, elements: list) -> list | None: """Resolve an array of MemberExpression references to string values.""" props = class_props.get(var_name, {}) resolved = [] - for el in elements: - el_obj, el_prop = get_member_names(el) - if not el_obj or el_obj != var_name: + for element in elements: + element_object, element_property = get_member_names(element) + if not element_object or element_object != var_name: return None - value = props.get(el_prop) + value = props.get(element_property) if not isinstance(value, str): return None resolved.append(value) return resolved - def _find_decoders(self, class_props, decoders): + def _find_decoders(self, class_props: dict, decoders: dict) -> None: """Find decoder methods and their associated lookup tables.""" - def visit(node, parent): + def visit(node: dict, parent: dict) -> None: if node.get('type') != 'MethodDefinition': return if not node.get('static'): @@ -111,21 +111,21 @@ def visit(node, parent): case _: return - func = node.get('value') - if not func or func.get('type') != 'FunctionExpression': + function_node = node.get('value') + if not function_node or function_node.get('type') != 'FunctionExpression': return - params = func.get('params', []) + params = function_node.get('params', []) if len(params) != 1: return - body = func.get('body') + body = function_node.get('body') if not body or body.get('type') != 'BlockStatement': return - stmts = body.get('body', []) - if len(stmts) < 3: + statements = body.get('body', []) + if len(statements) < 3: return - table_info = self._extract_decoder_table(stmts, class_props) + table_info = self._extract_decoder_table(statements, class_props) if not table_info: return @@ -139,12 +139,12 @@ def visit(node, parent): simple_traverse(self.ast, visit) - def _resolve_aliases(self, decoders): + def _resolve_aliases(self, decoders: dict) -> None: """Find identifier aliases (X = Y) where Y is a decoder class, and register X too.""" decoder_classes = {cls for cls, _ in decoders} new_entries = {} - def visit(node, parent): + def visit(node: dict, parent: dict) -> None: if node.get('type') != 'AssignmentExpression': return if node.get('operator') != '=': @@ -164,22 +164,23 @@ def visit(node, parent): simple_traverse(self.ast, visit) decoders.update(new_entries) - def _extract_decoder_table(self, stmts, class_props): + def _extract_decoder_table(self, statements: list, class_props: dict) -> tuple | None: """Extract the lookup table and offset from decoder method body.""" table_class_var = None table_prop = None - for stmt in stmts: - if stmt.get('type') != 'VariableDeclaration': + for statement in statements: + if statement.get('type') != 'VariableDeclaration': continue - for decl in stmt.get('declarations', []): - init = decl.get('init') + for declaration in statement.get('declarations', []): + init = declaration.get('init') obj_name, prop_name = get_member_names(init) - if obj_name and prop_name: - decl_id = decl.get('id') - if decl_id and decl_id.get('type') == 'Identifier': - table_class_var = obj_name - table_prop = prop_name + if not obj_name or not prop_name: + continue + declaration_id = declaration.get('id') + if declaration_id and declaration_id.get('type') == 'Identifier': + table_class_var = obj_name + table_prop = prop_name if not table_class_var: return None @@ -193,14 +194,14 @@ def _extract_decoder_table(self, stmts, class_props): if not resolved: return None - offset = self._find_offset(stmts) + offset = self._find_offset(statements) return resolved, offset - def _find_offset(self, stmts): + def _find_offset(self, statements: list) -> int: """Find the subtraction offset in the decoder loop (e.g., - 48).""" offset = 48 - def scan(node, parent): + def scan(node: dict, parent: dict) -> None: nonlocal offset if not isinstance(node, dict): return @@ -211,17 +212,17 @@ def scan(node, parent): if isinstance(val, (int, float)) and val > 0: offset = int(val) - for stmt in stmts: - if stmt.get('type') == 'ForStatement': - simple_traverse(stmt, scan) + for statement in statements: + if statement.get('type') == 'ForStatement': + simple_traverse(statement, scan) return offset - def _find_enclosing_class_var(self, method_node): + def _find_enclosing_class_var(self, method_node: dict) -> str | None: """Find the variable name of the class containing this method.""" result = [None] - def _check_class_body(class_expr, var_name): + def _check_class_body(class_expr: dict, var_name: str) -> bool: body = class_expr.get('body') if body and body.get('type') == 'ClassBody': for member in body.get('body', []): @@ -230,15 +231,15 @@ def _check_class_body(class_expr, var_name): return True return False - def scan(node, parent): + def scan(node: dict, parent: dict) -> None: if result[0]: return if node.get('type') == 'VariableDeclarator': init = node.get('init') if init and init.get('type') == 'ClassExpression': - decl_id = node.get('id') - if decl_id and decl_id.get('type') == 'Identifier': - _check_class_body(init, decl_id['name']) + declaration_id = node.get('id') + if declaration_id and declaration_id.get('type') == 'Identifier': + _check_class_body(init, declaration_id['name']) elif node.get('type') == 'AssignmentExpression': right = node.get('right') if right and right.get('type') == 'ClassExpression': @@ -249,7 +250,7 @@ def scan(node, parent): simple_traverse(self.ast, scan) return result[0] - def _decode_call(self, lookup_table, offset, args): + def _decode_call(self, lookup_table: list, offset: int, args: list) -> str | None: """Statically evaluate a decoder call: decode([0x4f, 0x3a, ...]).""" if len(args) != 1: return None @@ -258,10 +259,10 @@ def _decode_call(self, lookup_table, offset, args): return None elements = arg.get('elements', []) result = '' - for el in elements: - if not is_numeric_literal(el): + for element in elements: + if not is_numeric_literal(element): return None - idx = int(el['value']) - offset + idx = int(element['value']) - offset if idx < 0 or idx >= len(lookup_table): return None entry = lookup_table[idx] @@ -270,34 +271,34 @@ def _decode_call(self, lookup_table, offset, args): result += entry[0] return result - def _resolve_calls(self, decoders): + def _resolve_calls(self, decoders: dict) -> None: """Replace all decoder calls with their decoded string literals.""" - decoded_constants = {} + decoded_constants: dict = {} - def enter(node, parent, key, index): + def enter(node: dict, parent: dict, key: str, index: int | None) -> dict | None: if node.get('type') != 'CallExpression': - return + return None callee = node.get('callee') obj_name, method_name = get_member_names(callee) if not obj_name: - return + return None decoder_key = (obj_name, method_name) if decoder_key not in decoders: - return + return None lookup_table, offset = decoders[decoder_key] decoded = self._decode_call(lookup_table, offset, node.get('arguments', [])) if decoded is None: - return + return None replacement = make_literal(decoded) # Track the assignment target so we can inline the constant later if parent and parent.get('type') == 'AssignmentExpression' and key == 'right': - lobj, lprop = get_member_names(parent.get('left')) - if lobj and lprop: - decoded_constants[(lobj, lprop)] = decoded + left_object, left_property = get_member_names(parent.get('left')) + if left_object and left_property: + decoded_constants[(left_object, left_property)] = decoded self.set_changed() return replacement @@ -307,23 +308,23 @@ def enter(node, parent, key, index): if decoded_constants: self._inline_decoded_constants(decoded_constants) - def _inline_decoded_constants(self, decoded_constants): + def _inline_decoded_constants(self, decoded_constants: dict) -> None: """Replace references like _0x279589["propName"] with the decoded string.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict, key: str, index: int | None) -> dict | None: if node.get('type') != 'MemberExpression': - return + return None # Skip assignment targets if parent and parent.get('type') == 'AssignmentExpression' and key == 'left': - return + return None obj_name, prop_name = get_member_names(node) if not obj_name: - return + return None lookup_key = (obj_name, prop_name) if lookup_key not in decoded_constants: - return + return None decoded = decoded_constants[lookup_key] self.set_changed() diff --git a/pyjsclear/transforms/cleanup.py b/pyjsclear/transforms/cleanup.py index e79dd17..7294c74 100644 --- a/pyjsclear/transforms/cleanup.py +++ b/pyjsclear/transforms/cleanup.py @@ -22,9 +22,9 @@ class EmptyIfRemover(Transform): - ``if (expr) {} else { body }`` → ``if (!expr) { body }`` """ - def execute(self): + def execute(self) -> bool: - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') != 'IfStatement': return consequent = node.get('consequent') @@ -52,14 +52,14 @@ def enter(node, parent, key, index): return self.has_changed() @staticmethod - def _is_empty_block(node): + def _is_empty_block(node: object) -> bool: """Check if a node is an empty block statement ``{}``.""" if not isinstance(node, dict): return False if node.get('type') != 'BlockStatement': return False body = node.get('body') - return not body or len(body) == 0 + return not body class TrailingReturnRemover(Transform): @@ -71,20 +71,24 @@ class TrailingReturnRemover(Transform): _FUNC_TYPES = frozenset({'FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'}) - def execute(self): + def execute(self) -> bool: - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') not in self._FUNC_TYPES: return body = node.get('body') if not isinstance(body, dict) or body.get('type') != 'BlockStatement': return - stmts = body.get('body') - if not stmts or not isinstance(stmts, list): + statements = body.get('body') + if not statements or not isinstance(statements, list): return - last = stmts[-1] - if isinstance(last, dict) and last.get('type') == 'ReturnStatement' and last.get('argument') is None: - stmts.pop() + last_statement = statements[-1] + if ( + isinstance(last_statement, dict) + and last_statement.get('type') == 'ReturnStatement' + and last_statement.get('argument') is None + ): + statements.pop() self.set_changed() traverse(self.ast, {'enter': enter}) @@ -94,9 +98,9 @@ def enter(node, parent, key, index): class OptionalCatchBinding(Transform): """Remove unused catch clause parameters (ES2019 optional catch binding).""" - def execute(self): + def execute(self) -> bool: - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'CatchClause': return param = node.get('param') @@ -114,32 +118,33 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) return self.has_changed() - def _is_name_used(self, body, name): + def _is_name_used(self, body: dict, name: str) -> bool: """Check if an identifier name is used anywhere in the subtree.""" - found = [False] + found = False - def cb(node, parent): - if found[0]: + def callback(node: dict, parent: dict | None) -> None: + nonlocal found + if found: return if is_identifier(node) and node.get('name') == name: - found[0] = True + found = True - simple_traverse(body, cb) - return found[0] + simple_traverse(body, callback) + return found class ReturnUndefinedCleanup(Transform): """Simplify `return undefined;` to `return;`.""" - def execute(self): + def execute(self) -> bool: - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'ReturnStatement': return - arg = node.get('argument') - if not arg: + argument = node.get('argument') + if not argument: return - if is_identifier(arg) and arg.get('name') == 'undefined': + if is_identifier(argument) and argument.get('name') == 'undefined': node['argument'] = None self.set_changed() @@ -156,28 +161,28 @@ class LetToConst(Transform): - The binding has no assignments after declaration """ - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) - safe_declarators = set() + safe_declarators: set[int] = set() self._collect_let_const_candidates(scope_tree, safe_declarators) if not safe_declarators: return False - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'VariableDeclaration': return if node.get('kind') != 'let': return - decls = node.get('declarations', []) - if len(decls) == 1 and id(decls[0]) in safe_declarators: + declarations = node.get('declarations', []) + if len(declarations) == 1 and id(declarations[0]) in safe_declarators: node['kind'] = 'const' self.set_changed() traverse(self.ast, {'enter': enter}) return self.has_changed() - def _collect_let_const_candidates(self, scope, safe_declarators): + def _collect_let_const_candidates(self, scope: object, safe_declarators: set[int]) -> None: """Find let bindings that are never reassigned and have initializers.""" for name, binding in scope.bindings.items(): if binding.kind != 'let': @@ -206,19 +211,19 @@ class VarToConst(Transform): but const is block-scoped """ - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) - safe_declarators = set() + safe_declarators: set[int] = set() self._collect_const_candidates(scope_tree, safe_declarators, in_function=True) if not safe_declarators: return False # Track which BlockStatements are direct function bodies - func_body_ids = set() + func_body_ids: set[int] = set() self._collect_func_bodies(self.ast, func_body_ids) - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'VariableDeclaration': return if node.get('kind') != 'var': @@ -234,26 +239,26 @@ def enter(node, parent, key, index): return # Inside a nested block — unsafe else: return - decls = node.get('declarations', []) - if len(decls) == 1 and id(decls[0]) in safe_declarators: + declarations = node.get('declarations', []) + if len(declarations) == 1 and id(declarations[0]) in safe_declarators: node['kind'] = 'const' self.set_changed() traverse(self.ast, {'enter': enter}) return self.has_changed() - def _collect_func_bodies(self, ast, func_body_ids): + def _collect_func_bodies(self, ast: dict, func_body_ids: set[int]) -> None: """Collect ids of BlockStatements that are direct function bodies.""" - def cb(node, parent): + def callback(node: dict, parent: dict | None) -> None: if node.get('type') in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'): body = node.get('body') if body and body.get('type') == 'BlockStatement': func_body_ids.add(id(body)) - simple_traverse(ast, cb) + simple_traverse(ast, callback) - def _collect_const_candidates(self, scope, safe_declarators, in_function=False): + def _collect_const_candidates(self, scope: object, safe_declarators: set[int], in_function: bool = False) -> None: """Find var bindings that are never reassigned and have initializers.""" if in_function: for name, binding in scope.bindings.items(): diff --git a/pyjsclear/transforms/constant_prop.py b/pyjsclear/transforms/constant_prop.py index d218ecf..943135d 100644 --- a/pyjsclear/transforms/constant_prop.py +++ b/pyjsclear/transforms/constant_prop.py @@ -1,5 +1,7 @@ """Constant propagation — replace references to constant variables with their literal values.""" +from ..scope import Binding +from ..scope import Scope from ..scope import build_scope_tree from ..traverser import REMOVE from ..traverser import SKIP @@ -9,16 +11,16 @@ from .base import Transform -def _should_skip_reference(ref_parent, ref_key): +def _should_skip_reference(reference_parent: dict | None, reference_key: str | None) -> bool: """Return True if this reference should not be replaced with its literal value.""" - if not ref_parent: + if not reference_parent: return True - match ref_parent.get('type'): - case 'AssignmentExpression' if ref_key == 'left': + match reference_parent.get('type'): + case 'AssignmentExpression' if reference_key == 'left': return True case 'UpdateExpression': return True - case 'VariableDeclarator' if ref_key == 'id': + case 'VariableDeclarator' if reference_key == 'id': return True return False @@ -28,7 +30,7 @@ class ConstantProp(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: scope_tree, node_scope = build_scope_tree(self.ast) replacements = dict(self._iter_constant_bindings(scope_tree)) @@ -39,7 +41,9 @@ def execute(self): self._remove_fully_propagated(replacements, bindings_replaced) return self.has_changed() - def _iter_constant_bindings(self, scope): + def _iter_constant_bindings( + self, scope: Scope + ) -> list[tuple[int, tuple[Binding, dict]]]: """Yield (binding_id, (binding, literal)) for constant bindings with literal values.""" for name, binding in scope.bindings.items(): if not binding.is_constant: @@ -47,56 +51,58 @@ def _iter_constant_bindings(self, scope): node = binding.node if not isinstance(node, dict) or node.get('type') != 'VariableDeclarator': continue - init_val = node.get('init') - if not init_val or not is_literal(init_val): + init_value = node.get('init') + if not init_value or not is_literal(init_value): continue - yield id(binding), (binding, init_val) + yield id(binding), (binding, init_value) for child in scope.children: yield from self._iter_constant_bindings(child) - def _replace_references(self, replacements): + def _replace_references(self, replacements: dict[int, tuple[Binding, dict]]) -> set[int]: """Replace all qualifying references with their literal values.""" bindings_replaced = set() - for bind_id, (binding, literal) in replacements.items(): - for ref_node, ref_parent, ref_key, ref_index in binding.references: - if _should_skip_reference(ref_parent, ref_key): + for binding_id, (binding, literal) in replacements.items(): + for reference_node, reference_parent, reference_key, reference_index in binding.references: + if _should_skip_reference(reference_parent, reference_key): continue new_node = deep_copy(literal) - if ref_index is not None: - ref_parent[ref_key][ref_index] = new_node + if reference_index is not None: + reference_parent[reference_key][reference_index] = new_node else: - ref_parent[ref_key] = new_node + reference_parent[reference_key] = new_node self.set_changed() - bindings_replaced.add(bind_id) + bindings_replaced.add(binding_id) return bindings_replaced - def _remove_fully_propagated(self, replacements, bindings_replaced): + def _remove_fully_propagated( + self, replacements: dict[int, tuple[Binding, dict]], bindings_replaced: set[int] + ) -> None: """Remove declarations whose bindings were fully propagated.""" - for bind_id in bindings_replaced: - binding = replacements[bind_id][0] + for binding_id in bindings_replaced: + binding = replacements[binding_id][0] if binding.assignments: continue - decl_node = binding.node - if not isinstance(decl_node, dict): + declarator_node = binding.node + if not isinstance(declarator_node, dict): continue - if decl_node.get('type') != 'VariableDeclarator': + if declarator_node.get('type') != 'VariableDeclarator': continue - self._remove_declarator(decl_node) + self._remove_declarator(declarator_node) - def _remove_declarator(self, declarator_node): + def _remove_declarator(self, declarator_node: dict) -> None: """Remove a VariableDeclarator from its parent VariableDeclaration.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None): if node.get('type') != 'VariableDeclaration': return - decls = node.get('declarations', []) - for i, declaration in enumerate(decls): + declarations = node.get('declarations', []) + for i, declaration in enumerate(declarations): if declaration is not declarator_node: continue - decls.pop(i) + declarations.pop(i) self.set_changed() - if not decls: + if not declarations: return REMOVE return SKIP diff --git a/pyjsclear/transforms/control_flow.py b/pyjsclear/transforms/control_flow.py index 06646a2..aecfaa6 100644 --- a/pyjsclear/transforms/control_flow.py +++ b/pyjsclear/transforms/control_flow.py @@ -7,10 +7,7 @@ And reconstructs the linear statement sequence. """ -from ..utils.ast_helpers import get_child_keys -from ..utils.ast_helpers import is_identifier -from ..utils.ast_helpers import is_literal -from ..utils.ast_helpers import is_string_literal +from ..utils.ast_helpers import get_child_keys, is_identifier, is_literal, is_string_literal from .base import Transform @@ -19,11 +16,11 @@ class ControlFlowRecoverer(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: self._recover_in_bodies(self.ast) return self.has_changed() - def _recover_in_bodies(self, root): + def _recover_in_bodies(self, root: dict) -> None: """Walk through the AST looking for bodies containing CFF patterns.""" stack = [root] visited = set() @@ -47,7 +44,7 @@ def _recover_in_bodies(self, root): self._queue_children(node, stack) @staticmethod - def _queue_children(node, stack): + def _queue_children(node: dict, stack: list) -> None: """Add all child nodes to the traversal stack.""" for key in get_child_keys(node): child = node.get(key) @@ -60,140 +57,140 @@ def _queue_children(node, stack): elif isinstance(child, dict) and 'type' in child: stack.append(child) - def _try_recover_body(self, parent_node, body_key, body): + def _try_recover_body(self, parent_node: dict, body_key: str, body: list) -> None: """Try to find and recover CFF patterns in a body array.""" - i = 0 - while i < len(body): - statement = body[i] + index = 0 + while index < len(body): + statement = body[index] if not isinstance(statement, dict): - i += 1 + index += 1 continue - if self._try_recover_variable_pattern(body, i, statement): + if self._try_recover_variable_pattern(body, index, statement): continue - if self._try_recover_expression_pattern(body, i, statement): + if self._try_recover_expression_pattern(body, index, statement): continue - i += 1 + index += 1 - def _try_recover_variable_pattern(self, body, i, stmt): + def _try_recover_variable_pattern(self, body: list, index: int, statement: dict) -> bool: """Try Pattern 1: VariableDeclaration with split + loop. Returns True if recovered.""" - if stmt.get('type') != 'VariableDeclaration': + if statement.get('type') != 'VariableDeclaration': return False - state_info = self._find_state_array_in_decl(stmt) + state_info = self._find_state_array_in_decl(statement) if not state_info: return False states, state_var, counter_var = state_info - next_idx = i + 1 - if next_idx >= len(body): + next_index = index + 1 + if next_index >= len(body): return False - recovered = self._try_recover_from_loop(body[next_idx], states, state_var, counter_var) + recovered = self._try_recover_from_loop(body[next_index], states, state_var, counter_var) if recovered is None: return False - body[i : next_idx + 1] = recovered + body[index : next_index + 1] = recovered self.set_changed() return True - def _try_recover_expression_pattern(self, body, i, stmt): + def _try_recover_expression_pattern(self, body: list, index: int, statement: dict) -> bool: """Try Pattern 2: ExpressionStatement with split assignment + loop.""" - if stmt.get('type') != 'ExpressionStatement': + if statement.get('type') != 'ExpressionStatement': return False - expr = stmt.get('expression') - if not expr or expr.get('type') != 'AssignmentExpression': + expression = statement.get('expression') + if not expression or expression.get('type') != 'AssignmentExpression': return False - state_info = self._find_state_from_assignment(expr) + state_info = self._find_state_from_assignment(expression) if not state_info: return False states, state_var = state_info - next_idx = i + 1 + next_index = index + 1 counter_var = None - if next_idx < len(body): - counter_variable = self._find_counter_init(body[next_idx]) + if next_index < len(body): + counter_variable = self._find_counter_init(body[next_index]) if counter_variable is not None: counter_var = counter_variable - next_idx += 1 - if next_idx >= len(body): + next_index += 1 + if next_index >= len(body): return False - recovered = self._try_recover_from_loop(body[next_idx], states, state_var, counter_var or '_index') + recovered = self._try_recover_from_loop(body[next_index], states, state_var, counter_var or '_index') if recovered is None: return False - body[i : next_idx + 1] = recovered + body[index : next_index + 1] = recovered self.set_changed() return True - def _find_state_array_in_decl(self, decl): + def _find_state_array_in_decl(self, declaration: dict) -> tuple | None: """Find "X".split("|") pattern in a VariableDeclaration.""" - for declaration in decl.get('declarations', []): - initializer = declaration.get('init') + for declarator in declaration.get('declarations', []): + initializer = declarator.get('init') if not initializer or not self._is_split_call(initializer): continue states = self._extract_split_states(initializer) if not states: continue - if declaration.get('id', {}).get('type') != 'Identifier': + if declarator.get('id', {}).get('type') != 'Identifier': continue - state_var = declaration['id']['name'] - counter_var = self._find_counter_in_declaration(decl, exclude=declaration) + state_var = declarator['id']['name'] + counter_var = self._find_counter_in_declaration(declaration, exclude=declarator) return states, state_var, counter_var return None - def _find_counter_in_declaration(self, decl, exclude): + def _find_counter_in_declaration(self, declaration: dict, exclude: dict) -> str | None: """Find a numeric-initialized counter variable in a declaration, skipping *exclude*.""" - for declaration in decl.get('declarations', []): - if declaration is exclude: + for declarator in declaration.get('declarations', []): + if declarator is exclude: continue - if declaration.get('id', {}).get('type') != 'Identifier': + if declarator.get('id', {}).get('type') != 'Identifier': continue - initializer = declaration.get('init') + initializer = declarator.get('init') if ( initializer and initializer.get('type') == 'Literal' and isinstance(initializer.get('value'), (int, float)) ): - return declaration['id']['name'] + return declarator['id']['name'] return None - def _find_state_from_assignment(self, expr): + def _find_state_from_assignment(self, expression: dict) -> tuple | None: """Find state array from assignment expression.""" - if expr.get('type') != 'AssignmentExpression': + if expression.get('type') != 'AssignmentExpression': return None - if not is_identifier(expr.get('left')): + if not is_identifier(expression.get('left')): return None - right = expr.get('right') + right = expression.get('right') if self._is_split_call(right): states = self._extract_split_states(right) if states: - return states, expr['left']['name'] + return states, expression['left']['name'] return None - def _find_counter_init(self, statement): + def _find_counter_init(self, statement: dict) -> str | None: """Find counter variable initialization.""" if not isinstance(statement, dict): return None match statement.get('type'): case 'VariableDeclaration': - for declaration in statement.get('declarations', []): - if declaration.get('id', {}).get('type') == 'Identifier': - initializer = declaration.get('init') + for declarator in statement.get('declarations', []): + if declarator.get('id', {}).get('type') == 'Identifier': + initializer = declarator.get('init') if ( initializer and initializer.get('type') == 'Literal' and isinstance(initializer.get('value'), (int, float)) ): - return declaration['id']['name'] + return declarator['id']['name'] case 'ExpressionStatement': - expr = statement.get('expression') + expression = statement.get('expression') if ( - expr - and expr.get('type') == 'AssignmentExpression' - and is_identifier(expr.get('left')) - and is_literal(expr.get('right')) - and isinstance(expr['right'].get('value'), (int, float)) + expression + and expression.get('type') == 'AssignmentExpression' + and is_identifier(expression.get('left')) + and is_literal(expression.get('right')) + and isinstance(expression['right'].get('value'), (int, float)) ): - return expr['left']['name'] + return expression['left']['name'] return None - def _is_split_call(self, node): + def _is_split_call(self, node: dict) -> bool: """Check if node is "X".split("|").""" if not isinstance(node, dict): return False @@ -210,36 +207,36 @@ def _is_split_call(self, node): is_string_literal(property_expression) and property_expression.get('value') == 'split' ): return False - args = node.get('arguments', []) - if len(args) != 1 or not is_string_literal(args[0]): + arguments = node.get('arguments', []) + if len(arguments) != 1 or not is_string_literal(arguments[0]): return False return True - def _extract_split_states(self, node): + def _extract_split_states(self, node: dict) -> list: """Extract states from "1|0|3|2|4".split("|").""" callee = node['callee'] string = callee['object']['value'] separator = node['arguments'][0]['value'] return string.split(separator) - def _try_recover_from_loop(self, loop, states, state_var, counter_var): + def _try_recover_from_loop( + self, loop: dict, states: list, state_var: str, counter_var: str | None + ) -> list | None: """Try to recover statements from a for/while loop with switch dispatcher.""" if not isinstance(loop, dict): return None - loop_type = loop.get('type', '') - switch_body = None initial_value = 0 + switch_body = None - if loop_type == 'ForStatement': - # for(var _i = 0; ...) { switch(_array[_i++]) { ... } break; } - initial_value = self._extract_for_init_value(loop.get('init')) - switch_body = self._extract_switch_from_loop_body(loop.get('body')) - - elif loop_type == 'WhileStatement': - test = loop.get('test') - if self._is_truthy(test): + match loop.get('type', ''): + case 'ForStatement': + # for(var _i = 0; ...) { switch(_array[_i++]) { ... } break; } + initial_value = self._extract_for_init_value(loop.get('init')) switch_body = self._extract_switch_from_loop_body(loop.get('body')) + case 'WhileStatement': + if self._is_truthy(loop.get('test')): + switch_body = self._extract_switch_from_loop_body(loop.get('body')) if switch_body is None: return None @@ -248,20 +245,20 @@ def _try_recover_from_loop(self, loop, states, state_var, counter_var): return self._reconstruct_statements(cases_map, states, initial_value) @staticmethod - def _extract_for_init_value(initializer): + def _extract_for_init_value(initializer: dict | None) -> int: """Extract the initial counter value from a for-loop init clause.""" if not initializer: return 0 if initializer.get('type') == 'VariableDeclaration': - for declaration in initializer.get('declarations', []): - if declaration.get('init') and declaration['init'].get('type') == 'Literal': - return int(declaration['init'].get('value', 0)) + for declarator in initializer.get('declarations', []): + if declarator.get('init') and declarator['init'].get('type') == 'Literal': + return int(declarator['init'].get('value', 0)) elif initializer.get('type') == 'AssignmentExpression' and is_literal(initializer.get('right')): return int(initializer['right'].get('value', 0)) return 0 @staticmethod - def _build_case_map(cases): + def _build_case_map(cases: list) -> dict: """Build map from case test value to (filtered statements, original statements).""" cases_map = {} for case in cases: @@ -282,11 +279,11 @@ def _build_case_map(cases): return cases_map @staticmethod - def _reconstruct_statements(cases_map, states, initial_value): + def _reconstruct_statements(cases_map: dict, states: list, initial_value: int) -> list | None: """Reconstruct linear statement sequence from case map and state order.""" recovered = [] - for idx in range(initial_value, len(states)): - state = states[idx] + for index in range(initial_value, len(states)): + state = states[index] if state not in cases_map: break statements, original = cases_map[state] @@ -296,20 +293,20 @@ def _reconstruct_statements(cases_map, states, initial_value): break return recovered or None - def _extract_switch_from_loop_body(self, body): + def _extract_switch_from_loop_body(self, body: dict | None) -> dict | None: """Extract SwitchStatement from loop body.""" if not isinstance(body, dict): return None if body.get('type') == 'BlockStatement': - stmts = body.get('body', []) - for stmt in stmts: - if stmt.get('type') == 'SwitchStatement': - return stmt + statements = body.get('body', []) + for statement in statements: + if statement.get('type') == 'SwitchStatement': + return statement elif body.get('type') == 'SwitchStatement': return body return None - def _is_truthy(self, node): + def _is_truthy(self, node: dict | None) -> bool: """Check if a test expression is always truthy.""" if not isinstance(node, dict): return False @@ -317,14 +314,14 @@ def _is_truthy(self, node): return bool(node.get('value')) # !0 = true, !![] = true if node.get('type') == 'UnaryExpression' and node.get('operator') == '!': - arg = node.get('argument') - if arg and arg.get('type') == 'Literal' and arg.get('value') == 0: + argument = node.get('argument') + if argument and argument.get('type') == 'Literal' and argument.get('value') == 0: return True - if arg and arg.get('type') == 'ArrayExpression': + if argument and argument.get('type') == 'ArrayExpression': return False # ![] = false, but !![] = true - if arg and arg.get('type') == 'UnaryExpression' and arg.get('operator') == '!': + if argument and argument.get('type') == 'UnaryExpression' and argument.get('operator') == '!': # !!something - inner = arg.get('argument') + inner = argument.get('argument') if inner and inner.get('type') == 'ArrayExpression': return True return False diff --git a/pyjsclear/transforms/dead_branch.py b/pyjsclear/transforms/dead_branch.py index ec7a9b0..7c222ea 100644 --- a/pyjsclear/transforms/dead_branch.py +++ b/pyjsclear/transforms/dead_branch.py @@ -5,7 +5,7 @@ from .base import Transform -def _is_truthy_literal(node): +def _is_truthy_literal(node: dict) -> bool | None: """Check if node is a literal that is truthy in JS. Returns None if unknown.""" if not isinstance(node, dict): return None @@ -34,15 +34,15 @@ def _is_truthy_literal(node): case 'LogicalExpression': left = _is_truthy_literal(node.get('left')) right = _is_truthy_literal(node.get('right')) - op = node.get('operator') - if op == '&&': + operator = node.get('operator') + if operator == '&&': # falsy && anything → falsy if left is False: return False # truthy && right → right (if right is known) if left is True and right is not None: return right - elif op == '||': + elif operator == '||': # truthy || anything → truthy if left is True: return True @@ -52,7 +52,7 @@ def _is_truthy_literal(node): return None -def _unwrap_block(node): +def _unwrap_block(node: dict) -> dict: """Unwrap a single-statement block to its contents.""" if isinstance(node, dict) and node.get('type') == 'BlockStatement': body = node.get('body', []) @@ -64,8 +64,8 @@ def _unwrap_block(node): class DeadBranchRemover(Transform): """Remove dead branches from if statements and ternary expressions.""" - def execute(self): - def enter(node, parent, key, index): + def execute(self) -> bool: + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: node_type = node.get('type', '') if node_type == 'IfStatement': @@ -75,8 +75,8 @@ def enter(node, parent, key, index): self.set_changed() if truthy: return node.get('consequent') - alt = node.get('alternate') - return alt if alt else REMOVE + alternate_branch = node.get('alternate') + return alternate_branch if alternate_branch else REMOVE if node_type == 'ConditionalExpression': truthy = _is_truthy_literal(node.get('test')) diff --git a/pyjsclear/transforms/dead_class_props.py b/pyjsclear/transforms/dead_class_props.py index 205135f..1c328fe 100644 --- a/pyjsclear/transforms/dead_class_props.py +++ b/pyjsclear/transforms/dead_class_props.py @@ -26,73 +26,73 @@ class DeadClassPropRemover(Transform): """Remove dead property assignments on class variables.""" - def execute(self): + def execute(self) -> bool: # Step 1: Find class variable names, aliases, and class-id-to-name mapping # in a single traversal - class_vars = set() - class_aliases = {} # inner_name -> outer_name - reverse_aliases = {} # outer_name -> set of inner_names - class_node_to_name = {} # id(ClassExpression node) -> outer name + class_vars: set[str] = set() + class_aliases: dict[str, str] = {} # inner_name -> outer_name + reverse_aliases: dict[str, set[str]] = {} # outer_name -> set of inner_names + class_node_to_name: dict[int, str] = {} # id(ClassExpression node) -> outer name - def find_classes(node, parent): + def find_classes(node: dict, parent: dict | None) -> None: if node.get('type') == 'VariableDeclarator': init = node.get('init') if init and init.get('type') == 'ClassExpression': decl_id = node.get('id') if decl_id and is_identifier(decl_id): - outer = decl_id['name'] - class_vars.add(outer) - class_node_to_name[id(init)] = outer + outer_name = decl_id['name'] + class_vars.add(outer_name) + class_node_to_name[id(init)] = outer_name class_id = init.get('id') if class_id and is_identifier(class_id): - inner = class_id['name'] - if inner != outer: - class_vars.add(inner) - class_aliases[inner] = outer - reverse_aliases.setdefault(outer, set()).add(inner) + inner_name = class_id['name'] + if inner_name != outer_name: + class_vars.add(inner_name) + class_aliases[inner_name] = outer_name + reverse_aliases.setdefault(outer_name, set()).add(inner_name) elif node.get('type') == 'AssignmentExpression': right = node.get('right') if right and right.get('type') == 'ClassExpression': left = node.get('left') if left and is_identifier(left): - outer = left['name'] - class_vars.add(outer) - class_node_to_name[id(right)] = outer + outer_name = left['name'] + class_vars.add(outer_name) + class_node_to_name[id(right)] = outer_name class_id = right.get('id') if class_id and is_identifier(class_id): - inner = class_id['name'] - if inner != outer: - class_vars.add(inner) - class_aliases[inner] = outer - reverse_aliases.setdefault(outer, set()).add(inner) + inner_name = class_id['name'] + if inner_name != outer_name: + class_vars.add(inner_name) + class_aliases[inner_name] = outer_name + reverse_aliases.setdefault(outer_name, set()).add(inner_name) simple_traverse(self.ast, find_classes) if not class_vars: return False - def _normalize(obj_name): + def _normalize(obj_name: str) -> str: """Resolve class aliases to their canonical (outer) name.""" return class_aliases.get(obj_name, obj_name) - def _has_standalone(name): + def _has_standalone(name: str) -> bool: if standalone_refs.get(name, 0) > 0: return True canonical = class_aliases.get(name, name) if standalone_refs.get(canonical, 0) > 0: return True - for inner in reverse_aliases.get(name, ()): - if standalone_refs.get(inner, 0) > 0: + for inner_name in reverse_aliases.get(name, ()): + if standalone_refs.get(inner_name, 0) > 0: return True return False # Step 2: Classify identifier references and collect this.prop reads # in a single traversal - member_refs = {v: 0 for v in class_vars} - standalone_refs = {v: 0 for v in class_vars} - this_reads = set() + member_refs: dict[str, int] = {var: 0 for var in class_vars} + standalone_refs: dict[str, int] = {var: 0 for var in class_vars} + this_reads: set[tuple[str, str]] = set() - def classify_and_collect(node, parent): + def classify_and_collect(node: dict, parent: dict | None) -> None: # Collect this.prop reads inside class bodies if node.get('type') == 'ClassExpression': class_name = class_node_to_name.get(id(node)) @@ -133,43 +133,46 @@ def classify_and_collect(node, parent): else: standalone_refs[name] = standalone_refs.get(name, 0) + 1 - def _collect_this_reads_in_class(class_node, class_name): - def _inner(n, p): - if n.get('type') != 'MemberExpression': + def _collect_this_reads_in_class(class_node: dict, class_name: str) -> None: + def _visit(node: dict, parent: dict | None) -> None: + if node.get('type') != 'MemberExpression': return - obj = n.get('object') + obj = node.get('object') if not obj or obj.get('type') != 'ThisExpression': return - prop = n.get('property') + prop = node.get('property') if not prop: return - if n.get('computed'): + if node.get('computed'): if is_string_literal(prop): this_reads.add((class_name, prop['value'])) elif is_identifier(prop): this_reads.add((class_name, prop['name'])) - simple_traverse(class_node.get('body', {}), _inner) + simple_traverse(class_node.get('body', {}), _visit) simple_traverse(self.ast, classify_and_collect) # Classes with `this.prop` reads use their own properties — not fully dead - classes_with_this = set() + classes_with_this: set[str] = set() for name, prop in this_reads: classes_with_this.add(name) classes_with_this.add(class_aliases.get(name, name)) # Classes that never escape (only used as X.prop) — all their props are dead - fully_dead_classes = {v for v in class_vars if not _has_standalone(v) and v not in classes_with_this} + fully_dead_classes = { + var for var in class_vars + if not _has_standalone(var) and var not in classes_with_this + } # Classes that have escaped — skip individual prop dead-code analysis - escaped_classes = {_normalize(v) for v in class_vars if _has_standalone(v)} + escaped_classes = {_normalize(var) for var in class_vars if _has_standalone(var)} # Step 3: For non-fully-dead classes, find individually dead properties - writes = set() - reads = set() + writes: set[tuple[str, str]] = set() + reads: set[tuple[str, str]] = set() - def count_prop_refs(node, parent): + def count_prop_refs(node: dict, parent: dict | None) -> None: if node.get('type') != 'MemberExpression': return obj_name, prop_name = get_member_names(node) @@ -192,16 +195,16 @@ def count_prop_refs(node, parent): reads |= {(_normalize(name), prop) for name, prop in this_reads} # Dead props: written but never read, OR belonging to fully dead classes - dead_props = set() + dead_props: set[tuple[str, str]] = set() for pair in writes: if pair not in reads: dead_props.add(pair) # Collect all props of fully dead classes in a single traversal if fully_dead_classes: - fully_dead_canonical = {_normalize(v) for v in fully_dead_classes} | fully_dead_classes + fully_dead_canonical = {_normalize(var) for var in fully_dead_classes} | fully_dead_classes - def collect_all_dead(node, parent): + def collect_all_dead(node: dict, parent: dict | None) -> None: if node.get('type') != 'AssignmentExpression' or node.get('operator') != '=': return obj_name, prop_name = get_member_names(node.get('left')) @@ -214,11 +217,11 @@ def collect_all_dead(node, parent): return False # Step 4: Remove dead assignment expressions - def _is_dead(obj_name, prop_name): + def _is_dead(obj_name: str, prop_name: str) -> bool: canonical = _normalize(obj_name) return (canonical, prop_name) in dead_props or (obj_name, prop_name) in dead_props - def remove_dead_stmts(node, parent, key, index): + def remove_dead_stmts(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') != 'ExpressionStatement': return expr = node.get('expression') @@ -232,13 +235,13 @@ def remove_dead_stmts(node, parent, key, index): if expr.get('type') == 'SequenceExpression': exprs = expr.get('expressions', []) remaining = [] - for e in exprs: - if e.get('type') == 'AssignmentExpression' and e.get('operator') == '=': - obj_name, prop_name = get_member_names(e.get('left')) + for expression in exprs: + if expression.get('type') == 'AssignmentExpression' and expression.get('operator') == '=': + obj_name, prop_name = get_member_names(expression.get('left')) if obj_name and _is_dead(obj_name, prop_name): self.set_changed() continue - remaining.append(e) + remaining.append(expression) if not remaining: return REMOVE if len(remaining) < len(exprs): diff --git a/pyjsclear/transforms/dead_expressions.py b/pyjsclear/transforms/dead_expressions.py index 0d7d17d..15df3d6 100644 --- a/pyjsclear/transforms/dead_expressions.py +++ b/pyjsclear/transforms/dead_expressions.py @@ -12,14 +12,14 @@ class DeadExpressionRemover(Transform): like `(0, fn())`, and other numeric literal statements. """ - def execute(self): - def enter(node, parent, key, index): + def execute(self) -> bool: + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') != 'ExpressionStatement': return - expr = node.get('expression') - if not isinstance(expr, dict) or expr.get('type') != 'Literal': + expression = node.get('expression') + if not isinstance(expression, dict) or expression.get('type') != 'Literal': return - value = expr.get('value') + value = expression.get('value') # Only remove numeric literals (not strings/booleans/null/regex) if isinstance(value, (int, float)) and not isinstance(value, bool): self.set_changed() diff --git a/pyjsclear/transforms/dead_object_props.py b/pyjsclear/transforms/dead_object_props.py index 7e97beb..a6acbc0 100644 --- a/pyjsclear/transforms/dead_object_props.py +++ b/pyjsclear/transforms/dead_object_props.py @@ -46,61 +46,60 @@ class DeadObjectPropRemover(Transform): """Remove object property assignments where the property is never read.""" - def execute(self): + def execute(self) -> bool: # Phase 1: Find all obj.PROP = value statements and all obj.PROP reads. # Also track which objects "escape" (are assigned to external refs, passed as # function arguments, or returned) — their properties may be read externally. - writes = {} # (obj_name, prop_name) -> count - reads = set() # set of (obj_name, prop_name) - escaped = set() # set of obj_name that escape + writes: dict[tuple[str, str], int] = {} # (obj_name, prop_name) -> count + reads: set[tuple[str, str]] = set() # set of (obj_name, prop_name) + escaped: set[str] = set() # set of obj_name that escape # Phase 0: Collect locally declared variable names (var/let/const). # Only properties on locally declared objects are candidates for removal. - local_vars = set() + local_vars: set[str] = set() - def collect_locals(node, parent): + def collect_locals(node: dict, parent: dict | None) -> None: if not isinstance(node, dict): return - t = node.get('type') - if t == 'VariableDeclarator': - vid = node.get('id') - if vid and is_identifier(vid): - local_vars.add(vid['name']) + node_type = node.get('type') + if node_type == 'VariableDeclarator': + variable_id = node.get('id') + if variable_id and is_identifier(variable_id): + local_vars.add(variable_id['name']) # Function/arrow params are externally provided — mark as escaped - if t in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'): + if node_type in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'): for param in node.get('params', []): if is_identifier(param): escaped.add(param['name']) simple_traverse(self.ast, collect_locals) - def collect(node, parent): + def collect(node: dict, parent: dict | None) -> None: if not isinstance(node, dict): return - t = node.get('type') + node_type = node.get('type') # Track identifiers that escape - if t == 'Identifier' and parent: + if node_type == 'Identifier' and parent: name = node.get('name', '') if name in _GLOBAL_OBJECTS: escaped.add(name) - pt = parent.get('type') + parent_type = parent.get('type') # RHS of assignment to a member (e.g., r.exports = obj) - if pt == 'AssignmentExpression' and node is parent.get('right'): + if parent_type == 'AssignmentExpression' and node is parent.get('right'): left = parent.get('left') if left and left.get('type') == 'MemberExpression': escaped.add(name) # Function/method argument - if pt == 'CallExpression' or pt == 'NewExpression': - args = parent.get('arguments', []) - if node in args: + if parent_type in ('CallExpression', 'NewExpression'): + if node in parent.get('arguments', []): escaped.add(name) # Return value - if pt == 'ReturnStatement': + if parent_type == 'ReturnStatement': escaped.add(name) # Track member access patterns - if t != 'MemberExpression': + if node_type != 'MemberExpression': return if node.get('computed'): return @@ -125,7 +124,7 @@ def collect(node, parent): return False # Phase 2: Remove dead assignment statements - def remove_dead(node, parent, key, index): + def remove_dead(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') != 'ExpressionStatement': return expr = node.get('expression') @@ -139,12 +138,13 @@ def remove_dead(node, parent, key, index): if not obj or not is_identifier(obj) or not prop or not is_identifier(prop): return pair = (obj['name'], prop['name']) - if pair in dead_props: - # Only remove if the RHS is side-effect-free - rhs = expr.get('right') - if is_side_effect_free(rhs): - self.set_changed() - return REMOVE + if pair not in dead_props: + return + # Only remove if the RHS is side-effect-free + rhs = expr.get('right') + if is_side_effect_free(rhs): + self.set_changed() + return REMOVE traverse(self.ast, {'enter': remove_dead}) return self.has_changed() diff --git a/pyjsclear/transforms/else_if_flatten.py b/pyjsclear/transforms/else_if_flatten.py index 8559627..8a3fb27 100644 --- a/pyjsclear/transforms/else_if_flatten.py +++ b/pyjsclear/transforms/else_if_flatten.py @@ -11,22 +11,22 @@ class ElseIfFlattener(Transform): where each else block wraps a single if statement. """ - def execute(self): - def enter(node, parent, key, index): - if node.get('type') != 'IfStatement': - return - alt = node.get('alternate') - if not alt or alt.get('type') != 'BlockStatement': - return - body = alt.get('body', []) - if len(body) != 1: - return - inner = body[0] - if inner.get('type') != 'IfStatement': - return - # Flatten: replace the block with the inner if - node['alternate'] = inner - self.set_changed() + def _enter_node(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None: + if node.get('type') != 'IfStatement': + return + alternate_node = node.get('alternate') + if not alternate_node or alternate_node.get('type') != 'BlockStatement': + return + body = alternate_node.get('body', []) + if len(body) != 1: + return + inner_if = body[0] + if inner_if.get('type') != 'IfStatement': + return + # Flatten: replace the block with the inner if + node['alternate'] = inner_if + self.set_changed() - traverse(self.ast, {'enter': enter}) + def execute(self) -> bool: + traverse(self.ast, {'enter': self._enter_node}) return self.has_changed() diff --git a/pyjsclear/transforms/enum_resolver.py b/pyjsclear/transforms/enum_resolver.py index 6c9ee94..9555043 100644 --- a/pyjsclear/transforms/enum_resolver.py +++ b/pyjsclear/transforms/enum_resolver.py @@ -21,10 +21,10 @@ class EnumResolver(Transform): """Replace TypeScript enum member accesses with their numeric values.""" - def execute(self): + def execute(self) -> bool: # Phase 1: Detect enum IIFEs and extract member values. # Pattern: (function(param) { param[param.NAME = VALUE] = "NAME"; ... })(E || (E = {})) - enum_members = {} # (enum_name, member_name) -> numeric value + enum_members: dict[tuple[str, str], int | float] = {} # (enum_name, member_name) -> numeric value def find_enums(node, parent): if node.get('type') != 'CallExpression': @@ -91,7 +91,7 @@ def resolve(node, parent, key, index): traverse(self.ast, {'enter': resolve}) return self.has_changed() - def _extract_enum_name(self, arg): + def _extract_enum_name(self, argument_node: dict) -> str | None: """Extract the enum name from the IIFE argument pattern. Handles: @@ -99,21 +99,21 @@ def _extract_enum_name(self, arg): - E = X.Y || (X.Y = {}) (export-assigned variant) """ # Simple case: just an identifier - if is_identifier(arg): - return arg['name'] + if is_identifier(argument_node): + return argument_node['name'] # Assignment wrapper: E = X.Y || (X.Y = {}) - if arg.get('type') == 'AssignmentExpression' and arg.get('operator') == '=': - assign_left = arg.get('left') + if argument_node.get('type') == 'AssignmentExpression' and argument_node.get('operator') == '=': + assign_left = argument_node.get('left') if is_identifier(assign_left): - inner = arg.get('right') + inner = argument_node.get('right') if inner and inner.get('type') == 'LogicalExpression': return assign_left['name'] return None # Logical OR pattern: E || (E = {}) - if arg.get('type') != 'LogicalExpression' or arg.get('operator') != '||': + if argument_node.get('type') != 'LogicalExpression' or argument_node.get('operator') != '||': return None - left = arg.get('left') - right = arg.get('right') + left = argument_node.get('left') + right = argument_node.get('right') if not is_identifier(left): return None name = left['name'] @@ -124,15 +124,15 @@ def _extract_enum_name(self, arg): return name return None - def _extract_enum_assignment(self, expr, param_name): + def _extract_enum_assignment(self, expression: dict | None, param_name: str) -> tuple[str | None, int | float | None]: """Extract (member_name, value) from: param[param.NAME = VALUE] = "NAME". Returns (member_name, numeric_value) or (None, None). """ - if not expr or expr.get('type') != 'AssignmentExpression': + if not expression or expression.get('type') != 'AssignmentExpression': return None, None # The outer assignment: param[...] = "NAME" - left = expr.get('left') + left = expression.get('left') if not left or left.get('type') != 'MemberExpression' or not left.get('computed'): return None, None obj = left.get('object') @@ -159,7 +159,7 @@ def _extract_enum_assignment(self, expr, param_name): return member_name, value @staticmethod - def _get_numeric_value(node): + def _get_numeric_value(node: dict | None) -> int | float | None: """Extract a numeric value from a literal or unary minus expression.""" if not node: return None diff --git a/pyjsclear/transforms/eval_unpack.py b/pyjsclear/transforms/eval_unpack.py index 7f2baa1..9acf547 100644 --- a/pyjsclear/transforms/eval_unpack.py +++ b/pyjsclear/transforms/eval_unpack.py @@ -27,21 +27,21 @@ _EVAL_RE = re.compile(r'^eval\s*\(', re.MULTILINE) -def is_eval_packed(code): +def is_eval_packed(code: str) -> bool: """Check if code uses eval packing.""" return bool(_PACKER_RE.search(code) or _PACKER_RE2.search(code) or _EVAL_RE.search(code.lstrip())) -def _dean_edwards_unpack(packed, radix, count, keywords): +def _dean_edwards_unpack(packed: str, radix: int, count: int, keywords: list[str]) -> str: """Pure Python implementation of Dean Edwards unpacker.""" # Build the replacement function - def base_encode(c): - prefix = '' if c < radix else base_encode(int(c / radix)) - c = c % radix - if c > 35: - return prefix + chr(c + 29) - return prefix + ('0123456789abcdefghijklmnopqrstuvwxyz'[c] if c < 36 else chr(c + 29)) + def base_encode(value: int) -> str: + prefix = '' if value < radix else base_encode(int(value / radix)) + remainder = value % radix + if remainder > 35: + return prefix + chr(remainder + 29) + return prefix + ('0123456789abcdefghijklmnopqrstuvwxyz'[remainder] if remainder < 36 else chr(remainder + 29)) # Build dictionary lookup = {} @@ -51,34 +51,35 @@ def base_encode(c): lookup[key] = keywords[count] if count < len(keywords) and keywords[count] else key # Replace tokens in packed string - def replacer(match): - token = match.group(0) + def replacer(token_match: re.Match) -> str: + token = token_match.group(0) return lookup.get(token, token) return re.sub(r'\b\w+\b', replacer, packed) -def eval_unpack(code): +def eval_unpack(code: str) -> str | None: """Unpack eval-packed JavaScript. Returns unpacked code or None.""" return _try_dean_edwards(code) -def _try_dean_edwards(code): +def _try_dean_edwards(code: str) -> str | None: """Try to unpack Dean Edwards packer format.""" for pattern in [_PACKER_RE, _PACKER_RE2]: - m = pattern.search(code) - if m: - packed = m.group(1) - radix = int(m.group(2)) - count = int(m.group(3)) - keywords_str = m.group(4) - keywords = keywords_str.split('|') - - # Unescape the packed string - packed = packed.replace("\\'", "'").replace('\\\\', '\\') - - try: - return _dean_edwards_unpack(packed, radix, count, keywords) - except Exception: - continue + pattern_match = pattern.search(code) + if not pattern_match: + continue + + packed = pattern_match.group(1) + radix = int(pattern_match.group(2)) + count = int(pattern_match.group(3)) + keywords = pattern_match.group(4).split('|') + + # Unescape the packed string + packed = packed.replace("\\'", "'").replace('\\\\', '\\') + + try: + return _dean_edwards_unpack(packed, radix, count, keywords) + except Exception: + continue return None diff --git a/pyjsclear/transforms/expression_simplifier.py b/pyjsclear/transforms/expression_simplifier.py index 073a6f5..743ce97 100644 --- a/pyjsclear/transforms/expression_simplifier.py +++ b/pyjsclear/transforms/expression_simplifier.py @@ -1,6 +1,7 @@ """Evaluate static unary/binary expressions to literals.""" import math +from typing import Any from ..traverser import traverse from ..utils.ast_helpers import is_literal @@ -41,7 +42,7 @@ class ExpressionSimplifier(Transform): """Simplify constant unary/binary expressions to literals.""" - def execute(self): + def execute(self) -> bool: self._simplify_unary_binary() self._simplify_conditionals() self._simplify_awaits() @@ -49,32 +50,38 @@ def execute(self): self._simplify_method_calls() return self.has_changed() - def _simplify_unary_binary(self): + def _simplify_unary_binary(self) -> None: """Fold constant unary and binary expressions.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: match node.get('type', ''): case 'UnaryExpression': result = self._simplify_unary(node) case 'BinaryExpression': result = self._simplify_binary(node) case _: - return + return None if result is not None: self.set_changed() return result + return None traverse(self.ast, {'enter': enter}) - def _simplify_conditionals(self): + def _simplify_conditionals(self) -> None: """Convert test ? false : true → !test.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: if node.get('type') != 'ConditionalExpression': - return - cons = node.get('consequent') - alt = node.get('alternate') - if is_literal(cons) and cons.get('value') is False and is_literal(alt) and alt.get('value') is True: + return None + consequent = node.get('consequent') + alternate = node.get('alternate') + if ( + is_literal(consequent) + and consequent.get('value') is False + and is_literal(alternate) + and alternate.get('value') is True + ): self.set_changed() return { 'type': 'UnaryExpression', @@ -82,48 +89,49 @@ def enter(node, parent, key, index): 'prefix': True, 'argument': node['test'], } + return None traverse(self.ast, {'enter': enter}) - def _simplify_awaits(self): + def _simplify_awaits(self) -> None: """Simplify await (0x0, expr) → await expr.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'AwaitExpression': return - arg = node.get('argument') - if not isinstance(arg, dict) or arg.get('type') != 'SequenceExpression': + argument = node.get('argument') + if not isinstance(argument, dict) or argument.get('type') != 'SequenceExpression': return - exprs = arg.get('expressions', []) - if len(exprs) <= 1: + expressions = argument.get('expressions', []) + if len(expressions) <= 1: return - node['argument'] = exprs[-1] + node['argument'] = expressions[-1] self.set_changed() traverse(self.ast, {'enter': enter}) - def _simplify_comma_calls(self): + def _simplify_comma_calls(self) -> None: """Simplify (0, expr)(args) → expr(args).""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'CallExpression': return callee = node.get('callee') if not isinstance(callee, dict) or callee.get('type') != 'SequenceExpression': return - exprs = callee.get('expressions', []) - if len(exprs) < 2: + expressions = callee.get('expressions', []) + if len(expressions) < 2: return # Only simplify when the leading expressions are side-effect-free literals - for expr in exprs[:-1]: - if not isinstance(expr, dict) or expr.get('type') != 'Literal': + for expression in expressions[:-1]: + if not isinstance(expression, dict) or expression.get('type') != 'Literal': return - node['callee'] = exprs[-1] + node['callee'] = expressions[-1] self.set_changed() traverse(self.ast, {'enter': enter}) - def _simplify_unary(self, node): + def _simplify_unary(self, node: dict) -> dict | None: operator = node.get('operator', '') if operator not in _RESOLVABLE_UNARY: return None @@ -137,8 +145,8 @@ def _simplify_unary(self, node): argument = node.get('argument') argument = self._simplify_expr(argument) - value, ok = self._get_resolvable_value(argument) - if not ok: + value, resolved = self._get_resolvable_value(argument) + if not resolved: return None try: @@ -147,7 +155,7 @@ def _simplify_unary(self, node): return None return self._value_to_node(result) - def _simplify_binary(self, node): + def _simplify_binary(self, node: dict) -> dict | None: operator = node.get('operator', '') if operator not in _RESOLVABLE_BINARY: return None @@ -172,7 +180,7 @@ def _simplify_binary(self, node): return None return self._value_to_node(result) - def _simplify_expr(self, node): + def _simplify_expr(self, node: Any) -> Any: if not isinstance(node, dict): return node match node.get('type', ''): @@ -184,7 +192,7 @@ def _simplify_expr(self, node): return result if result is not None else node return node - def _is_negative_numeric(self, node): + def _is_negative_numeric(self, node: Any) -> bool: return ( isinstance(node, dict) and node.get('type') == 'UnaryExpression' @@ -193,7 +201,7 @@ def _is_negative_numeric(self, node): and isinstance(node['argument'].get('value'), (int, float)) ) - def _get_resolvable_value(self, node): + def _get_resolvable_value(self, node: Any) -> tuple[Any, bool]: if not isinstance(node, dict): return None, False match node.get('type', ''): @@ -204,9 +212,9 @@ def _get_resolvable_value(self, node): # Literal with value None is JS null, not undefined return (_JS_NULL if value is None else value), True case 'UnaryExpression' if node.get('operator') == '-': - arg = node.get('argument') - if is_literal(arg) and isinstance(arg.get('value'), (int, float)): - return -arg['value'], True + argument = node.get('argument') + if is_literal(argument) and isinstance(argument.get('value'), (int, float)): + return -argument['value'], True case 'Identifier' if node.get('name') == 'undefined': return None, True case 'ArrayExpression' if len(node.get('elements', [])) == 0: @@ -215,7 +223,7 @@ def _get_resolvable_value(self, node): return {}, True return None, False - def _apply_unary(self, operator, value): + def _apply_unary(self, operator: str, value: Any) -> Any: match operator: case '-': return -self._js_to_number(value) @@ -224,16 +232,16 @@ def _apply_unary(self, operator, value): case '!': return not self._js_truthy(value) case '~': - n = self._js_to_number(value) - if isinstance(n, float) and math.isnan(n): + number = self._js_to_number(value) + if isinstance(number, float) and math.isnan(number): return -1 # ~NaN → -1 - return ~int(n) + return ~int(number) case 'typeof': return self._js_typeof(value) case 'void': return None # JS undefined - def _apply_binary(self, operator, left, right): + def _apply_binary(self, operator: str, left: Any, right: Any) -> Any: match operator: case '+': if isinstance(left, str) or isinstance(right, str): @@ -244,15 +252,15 @@ def _apply_binary(self, operator, left, right): case '*': return self._js_to_number(left) * self._js_to_number(right) case '/': - result = self._js_to_number(right) - if result == 0: + divisor = self._js_to_number(right) + if divisor == 0: raise ValueError('division by zero') - return self._js_to_number(left) / result + return self._js_to_number(left) / divisor case '%': - result = self._js_to_number(right) - if result == 0: + modulus = self._js_to_number(right) + if modulus == 0: raise ValueError('mod by zero') - return self._js_to_number(left) % result + return self._js_to_number(left) % modulus case '**': return self._js_to_number(left) ** self._js_to_number(right) case '|': @@ -267,14 +275,14 @@ def _apply_binary(self, operator, left, right): return self._js_to_int32(left) >> (self._js_to_int32(right) & 31) case '>>>': left_operand = self._js_to_int32(left) & 0xFFFFFFFF - result = self._js_to_int32(right) & 31 - return left_operand >> result + shift = self._js_to_int32(right) & 31 + return left_operand >> shift case '==' | '!=': - eq = self._js_abstract_eq(left, right) - return eq if operator == '==' else not eq + equal = self._js_abstract_eq(left, right) + return equal if operator == '==' else not equal case '===' | '!==': - eq = self._js_strict_eq(left, right) - return eq if operator == '===' else not eq + equal = self._js_strict_eq(left, right) + return equal if operator == '===' else not equal case '<': return self._js_compare(left, right) < 0 case '<=': @@ -286,7 +294,7 @@ def _apply_binary(self, operator, left, right): case _: raise ValueError(f'Unknown operator: {operator}') - def _js_abstract_eq(self, left, right): + def _js_abstract_eq(self, left: Any, right: Any) -> bool: """JS == (null and undefined are equal to each other only).""" if (left is None or left is _JS_NULL) and (right is None or right is _JS_NULL): return True @@ -294,7 +302,7 @@ def _js_abstract_eq(self, left, right): return False return left == right - def _js_strict_eq(self, left, right): + def _js_strict_eq(self, left: Any, right: Any) -> bool: """JS === (null !== undefined).""" if left is _JS_NULL: return right is _JS_NULL @@ -302,7 +310,7 @@ def _js_strict_eq(self, left, right): return False return left == right and type(left) == type(right) - def _js_truthy(self, value): + def _js_truthy(self, value: Any) -> bool: if value is None or value is _JS_NULL: return False match value: @@ -317,7 +325,7 @@ def _js_truthy(self, value): case _: return bool(value) - def _js_typeof(self, value): + def _js_typeof(self, value: Any) -> str: if value is _JS_NULL: return 'object' # typeof null === 'object' in JS if value is None: @@ -334,11 +342,11 @@ def _js_typeof(self, value): case _: return 'undefined' - def _js_to_int32(self, value): + def _js_to_int32(self, value: Any) -> int: """Coerce to 32-bit integer (for bitwise ops).""" return int(self._js_to_number(value)) - def _js_to_number(self, value): + def _js_to_number(self, value: Any) -> int | float: if value is _JS_NULL: return 0 # Number(null) → 0 if value is None: @@ -362,7 +370,7 @@ def _js_to_number(self, value): case _: return 0 - def _js_to_string(self, value): + def _js_to_string(self, value: Any) -> str: if value is _JS_NULL: return 'null' if value is None: @@ -381,7 +389,7 @@ def _js_to_string(self, value): case _: return str(value) - def _js_compare(self, left, right): + def _js_compare(self, left: Any, right: Any) -> int | float: # JS compares strings lexicographically, not numerically if isinstance(left, str) and isinstance(right, str): if left < right: @@ -403,7 +411,7 @@ def _js_compare(self, left, right): return 1 return 0 - def _value_to_node(self, value): + def _value_to_node(self, value: Any) -> dict | None: if value is _JS_NULL: return make_literal(None) # null literal if value is None: @@ -426,7 +434,7 @@ def _value_to_node(self, value): return make_literal(value) return None - def _simplify_method_calls(self): + def _simplify_method_calls(self) -> None: """Statically evaluate simple method calls on literals. Handles: @@ -434,27 +442,27 @@ def _simplify_method_calls(self): (N).toString() → "N" """ - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: if node.get('type') != 'CallExpression': - return + return None callee = node.get('callee') if not isinstance(callee, dict) or callee.get('type') != 'MemberExpression': - return + return None prop = callee.get('property') if not prop: - return + return None method_name = prop.get('name') if prop.get('type') == 'Identifier' else None if not method_name: - return + return None # (N).toString() → "N" if method_name == 'toString' and len(node.get('arguments', [])) == 0: obj = callee.get('object') if is_numeric_literal(obj): val = obj['value'] - s = str(int(val)) if isinstance(val, float) and val == int(val) else str(val) + string_value = str(int(val)) if isinstance(val, float) and val == int(val) else str(val) self.set_changed() - return make_literal(s) + return make_literal(string_value) # Buffer.from([...nums...]).toString(encoding) → string literal if method_name == 'toString' and len(node.get('arguments', [])) <= 1: @@ -463,9 +471,11 @@ def enter(node, parent, key, index): self.set_changed() return make_literal(result) + return None + traverse(self.ast, {'enter': enter}) - def _try_eval_buffer_from_tostring(self, obj, args): + def _try_eval_buffer_from_tostring(self, obj: Any, arguments: list) -> str | None: """Try to evaluate Buffer.from([...nums...]).toString(encoding).""" if not isinstance(obj, dict) or obj.get('type') != 'CallExpression': return None @@ -473,11 +483,19 @@ def _try_eval_buffer_from_tostring(self, obj, args): if not isinstance(callee, dict) or callee.get('type') != 'MemberExpression': return None # Check for Buffer.from - buf_obj = callee.get('object') - buf_prop = callee.get('property') - if not (buf_obj and buf_obj.get('type') == 'Identifier' and buf_obj.get('name') == 'Buffer'): + buffer_object = callee.get('object') + buffer_property = callee.get('property') + if not ( + buffer_object + and buffer_object.get('type') == 'Identifier' + and buffer_object.get('name') == 'Buffer' + ): return None - if not (buf_prop and buf_prop.get('type') == 'Identifier' and buf_prop.get('name') == 'from'): + if not ( + buffer_property + and buffer_property.get('type') == 'Identifier' + and buffer_property.get('name') == 'from' + ): return None # First arg must be an array of numbers call_args = obj.get('arguments', []) @@ -485,17 +503,17 @@ def _try_eval_buffer_from_tostring(self, obj, args): return None elements = call_args[0].get('elements', []) byte_values = [] - for el in elements: - if not is_numeric_literal(el): + for element in elements: + if not is_numeric_literal(element): return None - val = el['value'] + val = element['value'] if not isinstance(val, (int, float)) or val != int(val) or val < 0 or val > 255: return None byte_values.append(int(val)) # Determine encoding for toString encoding = 'utf8' - if args and is_literal(args[0]) and isinstance(args[0].get('value'), str): - encoding = args[0]['value'] + if arguments and is_literal(arguments[0]) and isinstance(arguments[0].get('value'), str): + encoding = arguments[0]['value'] try: data = bytes(byte_values) if encoding in ('utf8', 'utf-8'): diff --git a/pyjsclear/transforms/global_alias.py b/pyjsclear/transforms/global_alias.py index ffab53b..1c644e2 100644 --- a/pyjsclear/transforms/global_alias.py +++ b/pyjsclear/transforms/global_alias.py @@ -54,64 +54,75 @@ class GlobalAliasInliner(Transform): of mangled variable names is extremely unlikely. """ - def execute(self): - aliases = {} - - # Phase 1: Find `var X = GLOBAL` patterns - def find_aliases(node, parent, key, index): - if node.get('type') != 'VariableDeclarator': - return - decl_id = node.get('id') - init = node.get('init') - if not is_identifier(decl_id) or not is_identifier(init): - return - if init['name'] in _WELL_KNOWN_GLOBALS: - aliases[decl_id['name']] = init['name'] - - traverse(self.ast, {'enter': find_aliases}) - - if not aliases: + def _find_var_aliases(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None: + """Collect `var alias = GLOBAL` patterns into self._aliases.""" + if node.get('type') != 'VariableDeclarator': + return + declaration_id = node.get('id') + initializer = node.get('init') + if not is_identifier(declaration_id) or not is_identifier(initializer): + return + if initializer['name'] in _WELL_KNOWN_GLOBALS: + self._aliases[declaration_id['name']] = initializer['name'] + + def _find_assignment_aliases(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None: + """Collect `alias = GLOBAL` assignment patterns into self._aliases.""" + if node.get('type') != 'AssignmentExpression': + return + if node.get('operator') != '=': + return + left_node = node.get('left') + right_node = node.get('right') + if not is_identifier(left_node) or not is_identifier(right_node): + return + if right_node['name'] in _WELL_KNOWN_GLOBALS: + self._aliases[left_node['name']] = right_node['name'] + + def _is_non_reference_position(self, parent: dict | None, key: str | None) -> bool: + """Return True if the identifier is in a non-reference (definition/key) position.""" + if not parent: return False + parent_type = parent.get('type') + # Non-computed property name + if parent_type == 'MemberExpression' and key == 'property' and not parent.get('computed'): + return True + # Variable declaration target + if parent_type == 'VariableDeclarator' and key == 'id': + return True + # Assignment left-hand side + if parent_type == 'AssignmentExpression' and key == 'left': + return True + # Function/method name + if parent_type in ('FunctionDeclaration', 'FunctionExpression') and key == 'id': + return True + # Non-computed property key + if parent_type == 'Property' and key == 'key' and not parent.get('computed'): + return True + return False + + def _replace_alias_refs(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: + """Replace aliased identifier references with the global name.""" + if not is_identifier(node): + return None + if self._is_non_reference_position(parent, key): + return None + name = node.get('name') + if name in self._aliases: + self.set_changed() + return make_identifier(self._aliases[name]) + return None + + def execute(self) -> bool: + self._aliases: dict[str, str] = {} + + # Phase 1: collect `var X = GLOBAL` and `X = GLOBAL` patterns + traverse(self.ast, {'enter': self._find_var_aliases}) + + if not self._aliases: + return False + + traverse(self.ast, {'enter': self._find_assignment_aliases}) - # Also find assignment aliases: X = GLOBAL (not just var X = GLOBAL) - def find_assignment_aliases(node, parent, key, index): - if node.get('type') != 'AssignmentExpression': - return - if node.get('operator') != '=': - return - left = node.get('left') - right = node.get('right') - if not is_identifier(left) or not is_identifier(right): - return - if right['name'] in _WELL_KNOWN_GLOBALS: - aliases[left['name']] = right['name'] - - traverse(self.ast, {'enter': find_assignment_aliases}) - - # Phase 2: Replace all references - def replace_refs(node, parent, key, index): - if not is_identifier(node): - return - # Skip non-computed property names - if parent and parent.get('type') == 'MemberExpression' and key == 'property' and not parent.get('computed'): - return - # Skip declaration targets - if parent and parent.get('type') == 'VariableDeclarator' and key == 'id': - return - # Skip assignment left-hand sides - if parent and parent.get('type') == 'AssignmentExpression' and key == 'left': - return - # Skip function/method names - if parent and parent.get('type') in ('FunctionDeclaration', 'FunctionExpression') and key == 'id': - return - # Skip property keys - if parent and parent.get('type') == 'Property' and key == 'key' and not parent.get('computed'): - return - - name = node.get('name') - if name in aliases: - self.set_changed() - return make_identifier(aliases[name]) - - traverse(self.ast, {'enter': replace_refs}) + # Phase 2: replace all alias references + traverse(self.ast, {'enter': self._replace_alias_refs}) return self.has_changed() diff --git a/pyjsclear/transforms/hex_escapes.py b/pyjsclear/transforms/hex_escapes.py index 6af0c2b..01b6965 100644 --- a/pyjsclear/transforms/hex_escapes.py +++ b/pyjsclear/transforms/hex_escapes.py @@ -9,17 +9,15 @@ class HexEscapes(Transform): """Pre-AST regex pass to decode hex escape sequences.""" - def execute(self): - # This works on the raw source before/after AST - # But since we operate on AST, we decode hex in string literal raw values + def execute(self) -> bool: + # Decode hex/unicode escapes in string literal raw values (value already decoded by parser) - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'Literal' or not isinstance(node.get('value'), str): return raw_string = node.get('raw', '') if '\\x' not in raw_string and '\\u' not in raw_string: return - # The value is already decoded by parser, just fix raw value = node['value'] new_raw = ( '"' @@ -38,7 +36,7 @@ def enter(node, parent, key, index): return self.has_changed() -def decode_hex_escapes_source(code): +def decode_hex_escapes_source(code: str) -> str: """Decode hex escapes in source code string (pre-parse pass). Only decodes hex escapes that produce printable characters (0x20-0x7e), @@ -47,22 +45,21 @@ def decode_hex_escapes_source(code): are left as \\xHH to avoid breaking the parser. """ - def replace_in_string(match_result): + def replace_in_string(match_result: re.Match) -> str: quote = match_result.group(1) content = match_result.group(2) # Decode hex escapes, but skip backslash, both quote chars, # and control chars. Quote chars are left for AST-level handling # which normalizes to double quotes like Babel. - def replace_hex_in_context(hex_match): - value = int(hex_match.group(1), 16) - if 0x20 <= value <= 0x7E and value not in (0x22, 0x27, 0x5C): - return chr(value) + def replace_hex_in_context(hex_match: re.Match) -> str: + char_value = int(hex_match.group(1), 16) + if 0x20 <= char_value <= 0x7E and char_value not in (0x22, 0x27, 0x5C): + return chr(char_value) return hex_match.group(0) decoded = re.sub(r'\\x([0-9a-fA-F]{2})', replace_hex_in_context, content) return quote + decoded + quote # Match string literals and decode hex escapes within them - result = re.sub(r"""(['"])((?:(?!\1|\\).|\\.)*?)\1""", replace_in_string, code) - return result + return re.sub(r"""(['"])((?:(?!\1|\\).|\\.)*?)\1""", replace_in_string, code) diff --git a/pyjsclear/transforms/hex_numerics.py b/pyjsclear/transforms/hex_numerics.py index 2eb5256..97461b0 100644 --- a/pyjsclear/transforms/hex_numerics.py +++ b/pyjsclear/transforms/hex_numerics.py @@ -7,9 +7,9 @@ class HexNumerics(Transform): """Convert hex numeric literals to decimal representation.""" - def execute(self): + def execute(self) -> bool: - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'Literal': return value = node.get('value') diff --git a/pyjsclear/transforms/jj_decode.py b/pyjsclear/transforms/jj_decode.py index 99b3924..148f5b7 100644 --- a/pyjsclear/transforms/jj_decode.py +++ b/pyjsclear/transforms/jj_decode.py @@ -18,7 +18,7 @@ # --------------------------------------------------------------------------- -def is_jj_encoded(code): +def is_jj_encoded(code: str) -> bool: """Return True if *code* looks like JJEncoded JavaScript. Checks for the ``VARNAME=~[]`` initialisation pattern that begins every @@ -37,84 +37,91 @@ def is_jj_encoded(code): _OBJECT_STR = '[object Object]' +# Single-char JS escape sequences +_SINGLE_CHAR_ESCAPES: dict[str, str] = { + 'n': '\n', 'r': '\r', 't': '\t', + '\\': '\\', "'": "'", '"': '"', + '/': '/', 'b': '\b', 'f': '\f', +} -def _split_at_depth_zero(text, delimiter): + +def _split_at_depth_zero(text: str, delimiter: str) -> list[str]: """Split *text* on *delimiter* only when bracket/paren depth is 0 and not inside a string literal. All logic is iterative.""" parts = [] current = [] depth = 0 - i = 0 + index = 0 in_string = None - while i < len(text): - ch = text[i] + while index < len(text): + character = text[index] if in_string is not None: - current.append(ch) - if ch == '\\' and i + 1 < len(text): - i += 1 - current.append(text[i]) - elif ch == in_string: + current.append(character) + if character == '\\' and index + 1 < len(text): + index += 1 + current.append(text[index]) + elif character == in_string: in_string = None - i += 1 + index += 1 continue - if ch in ('"', "'"): - in_string = ch - current.append(ch) - i += 1 + if character in ('"', "'"): + in_string = character + current.append(character) + index += 1 continue - if ch in ('(', '[', '{'): + if character in ('(', '[', '{'): depth += 1 - current.append(ch) - i += 1 + current.append(character) + index += 1 continue - if ch in (')', ']', '}'): + if character in (')', ']', '}'): depth -= 1 - current.append(ch) - i += 1 + current.append(character) + index += 1 continue - if depth == 0 and text[i:i + len(delimiter)] == delimiter: + if depth == 0 and text[index:index + len(delimiter)] == delimiter: parts.append(''.join(current)) current = [] - i += len(delimiter) + index += len(delimiter) continue - current.append(ch) - i += 1 + current.append(character) + index += 1 parts.append(''.join(current)) return parts -def _find_matching_close(text, start, open_ch, close_ch): - """Return index of *close_ch* matching *open_ch* at *start*. +def _find_matching_close(text: str, start: int, open_character: str, close_character: str) -> int: + """Return index of *close_character* matching *open_character* at *start*. Iterative, respects strings.""" depth = 0 in_string = None - i = start - while i < len(text): - ch = text[i] + index = start + while index < len(text): + character = text[index] if in_string is not None: - if ch == '\\' and i + 1 < len(text): - i += 2 + if character == '\\' and index + 1 < len(text): + index += 2 continue - if ch == in_string: + if character == in_string: in_string = None - i += 1 + index += 1 continue - if ch in ('"', "'"): - in_string = ch - elif ch == open_ch: + if character in ('"', "'"): + in_string = character + elif character == open_character: depth += 1 - elif ch == close_ch: + elif character == close_character: depth -= 1 if depth == 0: - return i - i += 1 + return index + index += 1 return -1 @@ -123,7 +130,7 @@ def _find_matching_close(text, start, open_ch, close_ch): # --------------------------------------------------------------------------- -def _parse_symbol_table(stmt, varname): +def _parse_symbol_table(stmt: str, varname: str) -> dict[str, int | str] | None: """Parse the ``$={___:++$, ...}`` statement and return a dict mapping property names to their resolved values (ints or chars).""" prefix = varname + '=' @@ -137,7 +144,7 @@ def _parse_symbol_table(stmt, varname): else: return None - table = {} + table: dict[str, int | str] = {} counter = -1 # ~[] = -1 entries = _split_at_depth_zero(body, ',') @@ -145,11 +152,11 @@ def _parse_symbol_table(stmt, varname): entry = entry.strip() if not entry: continue - colon_idx = entry.find(':') - if colon_idx == -1: + colon_index = entry.find(':') + if colon_index == -1: continue - key = entry[:colon_idx].strip() - value_expr = entry[colon_idx + 1:].strip() + key = entry[:colon_index].strip() + value_expr = entry[colon_index + 1:].strip() if value_expr.startswith('++'): counter += 1 @@ -163,16 +170,16 @@ def _parse_symbol_table(stmt, varname): # Walk backwards to find matching [ depth = 0 bracket_start = -1 - j = bracket_end - while j >= 0: - if value_expr[j] == ']': + scan = bracket_end + while scan >= 0: + if value_expr[scan] == ']': depth += 1 - elif value_expr[j] == '[': + elif value_expr[scan] == '[': depth -= 1 if depth == 0: - bracket_start = j + bracket_start = scan break - j -= 1 + scan -= 1 if bracket_start <= 0: continue @@ -192,7 +199,7 @@ def _parse_symbol_table(stmt, varname): return table -def _eval_coercion(expr, varname): +def _eval_coercion(expr: str, varname: str) -> str | None: """Evaluate a coercion expression to a string. Handles: (![]+"") -> "false", (!""+"") -> "true", @@ -212,7 +219,7 @@ def _eval_coercion(expr, varname): if expr == '![]': return 'false' - if expr == '!""' or expr == "!''": + if expr in ('!""', "!''"): return 'true' if expr == '{}': return _OBJECT_STR @@ -220,7 +227,7 @@ def _eval_coercion(expr, varname): if expr == varname + '[' + varname + ']': return 'undefined' # (!VARNAME) where VARNAME is object -> false - if expr == '!' + varname or expr == '(!' + varname + ')': + if expr in ('!' + varname, '(!' + varname + ')'): return 'false' # General X[X] pattern if re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\[\1\]$', expr): @@ -239,14 +246,14 @@ def _eval_coercion(expr, varname): _MAX_EVAL_DEPTH = 100 -def _eval_expr(expr, table, varname, _depth=0): +def _eval_expr(expr: str, table: dict[str, int | str], varname: str, depth: int = 0) -> str | None: """Evaluate a JJEncode expression to a string value. Handles symbol-table references, string literals, coercion expressions with indexing, sub-assignments, and concatenation. Returns the resolved string or None. """ - if _depth > _MAX_EVAL_DEPTH: + if depth > _MAX_EVAL_DEPTH: return None expr = expr.strip() @@ -281,15 +288,15 @@ def _eval_expr(expr, table, varname, _depth=0): else: break - val = _eval_inner(inner, table, varname, _depth + 1) + val = _eval_inner(inner, table, varname, depth + 1) if not rest: return val # Check for [index] after the paren if rest.startswith('[') and rest.endswith(']'): if val is None: return None - idx_expr = rest[1:-1].strip() - idx = _resolve_int(idx_expr, table, varname) + index_expr = rest[1:-1].strip() + idx = _resolve_int(index_expr, table, varname) if isinstance(val, str) and idx is not None and 0 <= idx < len(val): return val[idx] return None @@ -309,8 +316,8 @@ def _eval_expr(expr, table, varname, _depth=0): key = expr[len(prefix):bracket_pos] str_val = table.get(key) if isinstance(str_val, str) and expr.endswith(']'): - idx_expr = expr[bracket_pos + 1:-1] - idx = _resolve_int(idx_expr, table, varname) + index_expr = expr[bracket_pos + 1:-1] + idx = _resolve_int(index_expr, table, varname) if idx is not None and 0 <= idx < len(str_val): return str_val[idx] return None @@ -326,20 +333,20 @@ def _eval_expr(expr, table, varname, _depth=0): tokens = _split_at_depth_zero(expr, '+') if len(tokens) > 1: parts = [] - for t in tokens: - v = _eval_expr(t, table, varname, _depth + 1) - if v is None: + for token in tokens: + token_val = _eval_expr(token, table, varname, depth + 1) + if token_val is None: return None - parts.append(v) + parts.append(token_val) return ''.join(parts) return None -def _eval_inner(inner, table, varname, _depth=0): +def _eval_inner(inner: str, table: dict[str, int | str], varname: str, depth: int = 0) -> str | None: """Evaluate the inside of a parenthesised expression. Handles sub-assignments and simple expressions.""" - if _depth > _MAX_EVAL_DEPTH: + if depth > _MAX_EVAL_DEPTH: return None prefix = varname + '.' @@ -350,7 +357,7 @@ def _eval_inner(inner, table, varname, _depth=0): if eq_pos is not None: key = inner[len(prefix):eq_pos] rhs = inner[eq_pos + 1:] - val = _eval_expr(rhs, table, varname, _depth + 1) + val = _eval_expr(rhs, table, varname, depth + 1) if val is not None: table[key] = val return val @@ -361,41 +368,41 @@ def _eval_inner(inner, table, varname, _depth=0): return coercion_str # Just a nested expression - return _eval_expr(inner, table, varname, _depth + 1) + return _eval_expr(inner, table, varname, depth + 1) -def _find_top_level_eq(expr): +def _find_top_level_eq(expr: str) -> int | None: """Find the position of the first ``=`` at depth 0 that is not ``==``.""" depth = 0 in_string = None - i = 0 - while i < len(expr): - ch = expr[i] + index = 0 + while index < len(expr): + character = expr[index] if in_string is not None: - if ch == '\\' and i + 1 < len(expr): - i += 2 + if character == '\\' and index + 1 < len(expr): + index += 2 continue - if ch == in_string: + if character == in_string: in_string = None - i += 1 + index += 1 continue - if ch in ('"', "'"): - in_string = ch - elif ch in ('(', '[', '{'): + if character in ('"', "'"): + in_string = character + elif character in ('(', '[', '{'): depth += 1 - elif ch in (')', ']', '}'): + elif character in (')', ']', '}'): depth -= 1 - elif ch == '=' and depth == 0: + elif character == '=' and depth == 0: # Check not == - if i + 1 < len(expr) and expr[i + 1] == '=': - i += 2 + if index + 1 < len(expr) and expr[index + 1] == '=': + index += 2 continue - return i - i += 1 + return index + index += 1 return None -def _eval_coercion_indexed(expr, table, varname): +def _eval_coercion_indexed(expr: str, table: dict[str, int | str], varname: str) -> str | None: """Handle ``(![]+"")[$._$_]`` — coercion string indexed by a symbol table reference.""" if not expr.endswith(']'): @@ -404,16 +411,16 @@ def _eval_coercion_indexed(expr, table, varname): bracket_end = len(expr) - 1 depth = 0 bracket_start = -1 - j = bracket_end - while j >= 0: - if expr[j] == ']': + scan = bracket_end + while scan >= 0: + if expr[scan] == ']': depth += 1 - elif expr[j] == '[': + elif expr[scan] == '[': depth -= 1 if depth == 0: - bracket_start = j + bracket_start = scan break - j -= 1 + scan -= 1 if bracket_start <= 0: return None @@ -434,7 +441,7 @@ def _eval_coercion_indexed(expr, table, varname): return '' -def _resolve_int(expr, table, varname): +def _resolve_int(expr: str, table: dict[str, int | str], varname: str) -> int | None: """Resolve an expression to an integer.""" expr = expr.strip() prefix = varname + '.' @@ -455,7 +462,7 @@ def _resolve_int(expr, table, varname): # --------------------------------------------------------------------------- -def _parse_augment_statement(stmt, table, varname): +def _parse_augment_statement(stmt: str, table: dict[str, int | str], varname: str) -> None: """Parse statements that build multi-character strings like "constructor" and "return" by concatenation, and store intermediate single-char sub-assignments into the table.""" @@ -492,51 +499,51 @@ def _parse_augment_statement(stmt, table, varname): # --------------------------------------------------------------------------- -def _decode_js_string_literal(s): +def _decode_js_string_literal(content: str) -> str: """Decode escapes in a JS string literal content (between quotes). Only handles \\\\ -> \\, \\\" -> \", \\' -> ', and leaves everything else (like \\1, \\x, \\u) as-is for later processing.""" result = [] - i = 0 - while i < len(s): - if s[i] == '\\' and i + 1 < len(s): - nch = s[i + 1] - if nch in ('"', "'", '\\'): - result.append(nch) - i += 2 + index = 0 + while index < len(content): + if content[index] == '\\' and index + 1 < len(content): + next_character = content[index + 1] + if next_character in ('"', "'", '\\'): + result.append(next_character) + index += 2 continue - result.append(s[i]) - i += 1 + result.append(content[index]) + index += 1 return ''.join(result) -def _decode_escapes(s): +def _decode_escapes(text: str) -> str: """Decode octal (\\NNN), hex (\\xNN), unicode (\\uNNNN) escape sequences in a single left-to-right pass. Also handles standard single-char escapes.""" result = [] - i = 0 - while i < len(s): - if s[i] == '\\' and i + 1 < len(s): - nch = s[i + 1] + index = 0 + while index < len(text): + if text[index] == '\\' and index + 1 < len(text): + next_character = text[index + 1] # Unicode escape \uNNNN - if nch == 'u' and i + 5 < len(s): - hex_str = s[i + 2:i + 6] + if next_character == 'u' and index + 5 < len(text): + hex_str = text[index + 2:index + 6] try: result.append(chr(int(hex_str, 16))) - i += 6 + index += 6 continue except ValueError: pass # Hex escape \xNN - if nch == 'x' and i + 3 < len(s): - hex_str = s[i + 2:i + 4] + if next_character == 'x' and index + 3 < len(text): + hex_str = text[index + 2:index + 4] try: result.append(chr(int(hex_str, 16))) - i += 4 + index += 4 continue except ValueError: pass @@ -544,35 +551,30 @@ def _decode_escapes(s): # Octal escape: JS allows \0-\377 (max value 255). # First digit 0-3: up to 3 total digits (\000-\377). # First digit 4-7: up to 2 total digits (\40-\77). - if '0' <= nch <= '7': - max_digits = 3 if nch <= '3' else 2 + if '0' <= next_character <= '7': + max_digits = 3 if next_character <= '3' else 2 octal = '' - j = i + 1 - while j < len(s) and j < i + 1 + max_digits and '0' <= s[j] <= '7': - octal += s[j] - j += 1 + scan = index + 1 + while scan < len(text) and scan < index + 1 + max_digits and '0' <= text[scan] <= '7': + octal += text[scan] + scan += 1 result.append(chr(int(octal, 8))) - i = j + index = scan continue # Standard single-char escapes - _esc = { - 'n': '\n', 'r': '\r', 't': '\t', - '\\': '\\', "'": "'", '"': '"', - '/': '/', 'b': '\b', 'f': '\f', - } - if nch in _esc: - result.append(_esc[nch]) - i += 2 + if next_character in _SINGLE_CHAR_ESCAPES: + result.append(_SINGLE_CHAR_ESCAPES[next_character]) + index += 2 continue # Unknown escape — keep literal - result.append(nch) - i += 2 + result.append(next_character) + index += 2 continue - result.append(s[i]) - i += 1 + result.append(text[index]) + index += 1 return ''.join(result) @@ -582,7 +584,7 @@ def _decode_escapes(s): # --------------------------------------------------------------------------- -def _extract_payload_expression(stmt, varname): +def _extract_payload_expression(stmt: str, varname: str) -> str | None: """Extract the inner concatenation expression from the payload statement ``$.$($.$(EXPR)())()``.""" # Find VARNAME.$(VARNAME.$( @@ -596,28 +598,28 @@ def _extract_payload_expression(stmt, varname): # Find matching ) for the inner $.$( depth = 1 in_string = None - i = start - while i < len(stmt): - ch = stmt[i] + index = start + while index < len(stmt): + character = stmt[index] if in_string is not None: - if ch == '\\' and i + 1 < len(stmt): - i += 2 + if character == '\\' and index + 1 < len(stmt): + index += 2 continue - if ch == in_string: + if character == in_string: in_string = None - i += 1 + index += 1 continue - if ch in ('"', "'"): - in_string = ch - i += 1 + if character in ('"', "'"): + in_string = character + index += 1 continue - if ch == '(': + if character == '(': depth += 1 - elif ch == ')': + elif character == ')': depth -= 1 if depth == 0: - return stmt[start:i] - i += 1 + return stmt[start:index] + index += 1 return None @@ -627,7 +629,7 @@ def _extract_payload_expression(stmt, varname): # --------------------------------------------------------------------------- -def jj_decode(code): +def jj_decode(code: str) -> str | None: """Decode JJEncoded JavaScript. Returns the decoded string, or ``None`` on any failure.""" try: @@ -637,16 +639,16 @@ def jj_decode(code): return None -def _jj_decode_inner(code): +def _jj_decode_inner(code: str) -> str | None: if not code or not code.strip(): return None stripped = code.strip() - m = re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\s*=\s*~\s*\[\s*\]', stripped) - if not m: + match = re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\s*=\s*~\s*\[\s*\]', stripped) + if not match: return None - varname = m.group(1) + varname = match.group(1) # Find the JJEncode line jj_line = None @@ -661,7 +663,7 @@ def _jj_decode_inner(code): # Split into semicolon-delimited statements at depth 0 stmts = _split_at_depth_zero(jj_line, ';') - stmts = [s.strip() for s in stmts if s.strip()] + stmts = [statement.strip() for statement in stmts if statement.strip()] if len(stmts) < 5: return None diff --git a/pyjsclear/transforms/jsfuck_decode.py b/pyjsclear/transforms/jsfuck_decode.py index dac5f49..6dbe547 100644 --- a/pyjsclear/transforms/jsfuck_decode.py +++ b/pyjsclear/transforms/jsfuck_decode.py @@ -6,7 +6,20 @@ string passed to Function(). """ -def is_jsfuck(code): +from enum import StrEnum + + +class _JSType(StrEnum): + ARRAY = 'array' + BOOL = 'bool' + NUMBER = 'number' + STRING = 'string' + UNDEFINED = 'undefined' + OBJECT = 'object' + FUNCTION = 'function' + + +def is_jsfuck(code: str) -> bool: """Check if code is JSFuck-encoded. JSFuck code consists only of []()!+ characters (with optional whitespace/semicolons). @@ -18,7 +31,7 @@ def is_jsfuck(code): # Only count the six JSFuck operator characters — whitespace and # semicolons are not distinctive and inflate the ratio on minified JS. jsfuck_chars = set('[]()!+') - jsfuck_count = sum(1 for c in stripped if c in jsfuck_chars) + jsfuck_count = sum(1 for character in stripped if character in jsfuck_chars) return jsfuck_count / len(stripped) > 0.95 @@ -32,45 +45,45 @@ class _JSValue: __slots__ = ('val', 'type') - def __init__(self, val, typ): + def __init__(self, val: object, js_type: _JSType | str) -> None: self.val = val - self.type = typ # 'array', 'bool', 'number', 'string', 'undefined', 'object', 'function' + self.type = js_type # -- coercion helpers --------------------------------------------------- - def to_number(self): + def to_number(self) -> int | float: match self.type: - case 'number': + case _JSType.NUMBER: return self.val - case 'bool': + case _JSType.BOOL: return 1 if self.val else 0 - case 'string': - s = self.val.strip() - if s == '': + case _JSType.STRING: + stripped = self.val.strip() + if stripped == '': return 0 try: - return int(s) + return int(stripped) except ValueError: try: - return float(s) + return float(stripped) except ValueError: return float('nan') - case 'array': + case _JSType.ARRAY: if len(self.val) == 0: return 0 if len(self.val) == 1: return _JSValue(self.val[0], _guess_type(self.val[0])).to_number() return float('nan') - case 'undefined': + case _JSType.UNDEFINED: return float('nan') case _: return float('nan') - def to_string(self): + def to_string(self) -> str: match self.type: - case 'string': + case _JSType.STRING: return self.val - case 'number': + case _JSType.NUMBER: if isinstance(self.val, float): if self.val != self.val: # NaN return 'NaN' @@ -82,9 +95,9 @@ def to_string(self): return str(int(self.val)) return str(self.val) return str(self.val) - case 'bool': + case _JSType.BOOL: return 'true' if self.val else 'false' - case 'array': + case _JSType.ARRAY: parts = [] for item in self.val: if item is None: @@ -94,139 +107,143 @@ def to_string(self): else: parts.append(_JSValue(item, _guess_type(item)).to_string()) return ','.join(parts) - case 'undefined': + case _JSType.UNDEFINED: return 'undefined' - case 'object': + case _JSType.OBJECT: return '[object Object]' case _: return str(self.val) - def to_bool(self): + def to_bool(self) -> bool: match self.type: - case 'bool': + case _JSType.BOOL: return self.val - case 'number': + case _JSType.NUMBER: return self.val != 0 and self.val == self.val # 0 and NaN are falsy - case 'string': + case _JSType.STRING: return len(self.val) > 0 - case 'array': + case _JSType.ARRAY: return True # arrays are always truthy in JS - case 'undefined': + case _JSType.UNDEFINED: return False - case 'object': + case _JSType.OBJECT: return True case _: return bool(self.val) - def get_property(self, key): + def get_property(self, key: '_JSValue') -> '_JSValue': """Property access: self[key].""" - key_str = key.to_string() if isinstance(key, _JSValue) else str(key) - - if self.type == 'string': - # String indexing - try: - idx = int(key_str) - if 0 <= idx < len(self.val): - return _JSValue(self.val[idx], 'string') - except (ValueError, IndexError): - pass - # String properties - if key_str == 'length': - return _JSValue(len(self.val), 'number') - if key_str == 'constructor': - return _STRING_CONSTRUCTOR - # String.prototype methods - return _get_string_method(self, key_str) - - if self.type == 'array': - try: - idx = int(key_str) - if 0 <= idx < len(self.val): - item = self.val[idx] - if isinstance(item, _JSValue): - return item - return _JSValue(item, _guess_type(item)) - except (ValueError, IndexError): - pass - if key_str == 'length': - return _JSValue(len(self.val), 'number') - if key_str == 'constructor': - return _ARRAY_CONSTRUCTOR - # Array methods that JSFuck commonly accesses - if key_str in ( - 'flat', - 'fill', - 'find', - 'filter', - 'entries', - 'concat', - 'join', - 'sort', - 'reverse', - 'slice', - 'map', - 'forEach', - 'reduce', - 'some', - 'every', - 'indexOf', - 'includes', - 'keys', - 'values', - 'at', - 'pop', - 'push', - 'shift', - 'unshift', - 'splice', - 'toString', - 'valueOf', - ): - return _JSValue(key_str, 'function') - - if self.type == 'number': - if key_str == 'constructor': - return _NUMBER_CONSTRUCTOR - if key_str == 'toString': - return _JSValue('toString', 'function') - return _JSValue(None, 'undefined') - - if self.type == 'bool': - if key_str == 'constructor': - return _BOOLEAN_CONSTRUCTOR - return _JSValue(None, 'undefined') - - if self.type == 'function': - if key_str == 'constructor': - return _FUNCTION_CONSTRUCTOR - return _JSValue(None, 'undefined') - - if self.type == 'object': - if key_str == 'constructor': - return _OBJECT_CONSTRUCTOR - return _JSValue(None, 'undefined') - - return _JSValue(None, 'undefined') - - def __repr__(self): + key_string = key.to_string() if isinstance(key, _JSValue) else str(key) + + match self.type: + case _JSType.STRING: + return _get_string_property(self, key_string) + case _JSType.ARRAY: + return _get_array_property(self, key_string) + case _JSType.NUMBER: + if key_string == 'constructor': + return _NUMBER_CONSTRUCTOR + if key_string == 'toString': + return _JSValue('toString', _JSType.FUNCTION) + return _JSValue(None, _JSType.UNDEFINED) + case _JSType.BOOL: + if key_string == 'constructor': + return _BOOLEAN_CONSTRUCTOR + return _JSValue(None, _JSType.UNDEFINED) + case _JSType.FUNCTION: + if key_string == 'constructor': + return _FUNCTION_CONSTRUCTOR + return _JSValue(None, _JSType.UNDEFINED) + case _JSType.OBJECT: + if key_string == 'constructor': + return _OBJECT_CONSTRUCTOR + return _JSValue(None, _JSType.UNDEFINED) + case _: + return _JSValue(None, _JSType.UNDEFINED) + + def __repr__(self) -> str: return f'_JSValue({self.val!r}, {self.type!r})' -def _guess_type(val): - if isinstance(val, bool): - return 'bool' - if isinstance(val, (int, float)): - return 'number' - if isinstance(val, str): - return 'string' - if isinstance(val, list): - return 'array' - if val is None: - return 'undefined' - return 'object' +def _get_string_property(string_value: '_JSValue', key_string: str) -> '_JSValue': + """Return the result of property access on a string value.""" + try: + index = int(key_string) + if 0 <= index < len(string_value.val): + return _JSValue(string_value.val[index], _JSType.STRING) + except (ValueError, IndexError): + pass + if key_string == 'length': + return _JSValue(len(string_value.val), _JSType.NUMBER) + if key_string == 'constructor': + return _STRING_CONSTRUCTOR + return _get_string_method(key_string) + + +def _get_array_property(array_value: '_JSValue', key_string: str) -> '_JSValue': + """Return the result of property access on an array value.""" + try: + index = int(key_string) + if 0 <= index < len(array_value.val): + item = array_value.val[index] + if isinstance(item, _JSValue): + return item + return _JSValue(item, _guess_type(item)) + except (ValueError, IndexError): + pass + if key_string == 'length': + return _JSValue(len(array_value.val), _JSType.NUMBER) + if key_string == 'constructor': + return _ARRAY_CONSTRUCTOR + # Array methods that JSFuck commonly accesses + if key_string in ( + 'flat', + 'fill', + 'find', + 'filter', + 'entries', + 'concat', + 'join', + 'sort', + 'reverse', + 'slice', + 'map', + 'forEach', + 'reduce', + 'some', + 'every', + 'indexOf', + 'includes', + 'keys', + 'values', + 'at', + 'pop', + 'push', + 'shift', + 'unshift', + 'splice', + 'toString', + 'valueOf', + ): + return _JSValue(key_string, _JSType.FUNCTION) + return _JSValue(None, _JSType.UNDEFINED) + + +def _guess_type(value: object) -> _JSType: + if isinstance(value, bool): + return _JSType.BOOL + if isinstance(value, (int, float)): + return _JSType.NUMBER + if isinstance(value, str): + return _JSType.STRING + if isinstance(value, list): + return _JSType.ARRAY + if value is None: + return _JSType.UNDEFINED + return _JSType.OBJECT -def _get_string_method(string_val, method_name): +def _get_string_method(method_name: str) -> '_JSValue': """Return a callable _JSValue wrapping a string method.""" if method_name in ( 'italics', @@ -265,17 +282,17 @@ def _get_string_method(string_val, method_name): 'normalize', 'flat', ): - return _JSValue(method_name, 'function') - return _JSValue(None, 'undefined') + return _JSValue(method_name, _JSType.FUNCTION) + return _JSValue(None, _JSType.UNDEFINED) # Sentinel constructors for property chain resolution -_STRING_CONSTRUCTOR = _JSValue('String', 'function') -_NUMBER_CONSTRUCTOR = _JSValue('Number', 'function') -_BOOLEAN_CONSTRUCTOR = _JSValue('Boolean', 'function') -_ARRAY_CONSTRUCTOR = _JSValue('Array', 'function') -_OBJECT_CONSTRUCTOR = _JSValue('Object', 'function') -_FUNCTION_CONSTRUCTOR = _JSValue('Function', 'function') +_STRING_CONSTRUCTOR = _JSValue('String', _JSType.FUNCTION) +_NUMBER_CONSTRUCTOR = _JSValue('Number', _JSType.FUNCTION) +_BOOLEAN_CONSTRUCTOR = _JSValue('Boolean', _JSType.FUNCTION) +_ARRAY_CONSTRUCTOR = _JSValue('Array', _JSType.FUNCTION) +_OBJECT_CONSTRUCTOR = _JSValue('Object', _JSType.FUNCTION) +_FUNCTION_CONSTRUCTOR = _JSValue('Function', _JSType.FUNCTION) # Known constructor-of-constructor chain results _CONSTRUCTOR_MAP = { @@ -293,12 +310,12 @@ def _get_string_method(string_val, method_name): # --------------------------------------------------------------------------- -def _tokenize(code): +def _tokenize(code: str) -> list[str]: """Tokenize JSFuck code into a list of characters/tokens.""" tokens = [] - for ch in code: - if ch in '[]()!+': - tokens.append(ch) + for character in code: + if character in '[]()!+': + tokens.append(character) # Skip whitespace, semicolons return tokens @@ -334,166 +351,166 @@ class _Parser: arbitrarily deep nesting never overflows the Python call stack. """ - def __init__(self, tokens): + def __init__(self, tokens: list[str]) -> None: self.tokens = tokens self.pos = 0 - self.captured = None # Result from Function(body)() + self.captured: str | None = None # Result from Function(body)() - def peek(self): + def peek(self) -> str | None: if self.pos < len(self.tokens): return self.tokens[self.pos] return None - def consume(self, expected=None): + def consume(self, expected: str | None = None) -> str: if self.pos >= len(self.tokens): raise _ParseError('Unexpected end of input') - tok = self.tokens[self.pos] - if expected is not None and tok != expected: - raise _ParseError(f'Expected {expected!r}, got {tok!r}') + token = self.tokens[self.pos] + if expected is not None and token != expected: + raise _ParseError(f'Expected {expected!r}, got {token!r}') self.pos += 1 - return tok + return token # ------------------------------------------------------------------ - def parse(self): + def parse(self) -> _JSValue: """Parse and evaluate the full expression (iterative).""" - val_stack = [] - cont = [(_K_DONE,)] + value_stack: list[_JSValue] = [] + continuation: list[tuple] = [(_K_DONE,)] state = _S_EXPR while True: if state == _S_EXPR: # expression = unary ('+' unary)* - cont.append((_K_EXPR_LOOP,)) + continuation.append((_K_EXPR_LOOP,)) state = _S_UNARY elif state == _S_UNARY: # Collect prefix operators, then parse postfix - ops = [] + operators = [] while self.peek() in ('!', '+'): - ops.append(self.consume()) - cont.append((_K_UNARY_APPLY, ops)) + operators.append(self.consume()) + continuation.append((_K_UNARY_APPLY, operators)) state = _S_POSTFIX elif state == _S_POSTFIX: # Parse primary, then handle postfix [ ] and ( ) - cont.append((_K_POSTFIX_LOOP, None)) # receiver=None + continuation.append((_K_POSTFIX_LOOP, None)) # receiver=None state = _S_PRIMARY elif state == _S_PRIMARY: - tok = self.peek() - if tok == '(': + token = self.peek() + if token == '(': self.consume('(') - cont.append((_K_PAREN_CLOSE,)) + continuation.append((_K_PAREN_CLOSE,)) state = _S_EXPR - elif tok == '[': + elif token == '[': self.consume('[') if self.peek() == ']': self.consume(']') - val_stack.append(_JSValue([], 'array')) + value_stack.append(_JSValue([], _JSType.ARRAY)) state = _S_RESUME else: - cont.append((_K_ARRAY_ELEM, [])) + continuation.append((_K_ARRAY_ELEM, [])) state = _S_EXPR else: raise _ParseError( - f'Unexpected token: {tok!r} at pos {self.pos}') + f'Unexpected token: {token!r} at pos {self.pos}') elif state == _S_RESUME: - k = cont.pop() - ktype = k[0] + continuation_frame = continuation.pop() + continuation_type = continuation_frame[0] - if ktype == _K_DONE: - return val_stack.pop() + if continuation_type == _K_DONE: + return value_stack.pop() - elif ktype == _K_PAREN_CLOSE: + elif continuation_type == _K_PAREN_CLOSE: self.consume(')') state = _S_RESUME - elif ktype == _K_ARRAY_ELEM: - elements = k[1] - elements.append(val_stack.pop()) + elif continuation_type == _K_ARRAY_ELEM: + elements = continuation_frame[1] + elements.append(value_stack.pop()) if self.peek() not in (']', None): - cont.append((_K_ARRAY_ELEM, elements)) + continuation.append((_K_ARRAY_ELEM, elements)) state = _S_EXPR else: self.consume(']') - val_stack.append(_JSValue(elements, 'array')) + value_stack.append(_JSValue(elements, _JSType.ARRAY)) state = _S_RESUME - elif ktype == _K_POSTFIX_LOOP: - receiver = k[1] - val = val_stack[-1] + elif continuation_type == _K_POSTFIX_LOOP: + receiver = continuation_frame[1] + current_value = value_stack[-1] if self.peek() == '[': self.consume('[') - val_stack.pop() - cont.append((_K_POSTFIX_BRACKET, val)) + value_stack.pop() + continuation.append((_K_POSTFIX_BRACKET, current_value)) state = _S_EXPR elif self.peek() == '(': self.consume('(') if self.peek() == ')': self.consume(')') - val_stack.pop() - result = self._call(val, [], receiver) - val_stack.append(result) - cont.append((_K_POSTFIX_LOOP, None)) + value_stack.pop() + result = self._call(current_value, [], receiver) + value_stack.append(result) + continuation.append((_K_POSTFIX_LOOP, None)) state = _S_RESUME else: - val_stack.pop() - cont.append((_K_POSTFIX_ARGDONE, val, receiver)) + value_stack.pop() + continuation.append((_K_POSTFIX_ARGDONE, current_value, receiver)) state = _S_EXPR else: # No more postfix ops state = _S_RESUME - elif ktype == _K_POSTFIX_BRACKET: - parent_val = k[1] - key = val_stack.pop() + elif continuation_type == _K_POSTFIX_BRACKET: + parent_value = continuation_frame[1] + key = value_stack.pop() self.consume(']') - val_stack.append(parent_val.get_property(key)) - cont.append((_K_POSTFIX_LOOP, parent_val)) + value_stack.append(parent_value.get_property(key)) + continuation.append((_K_POSTFIX_LOOP, parent_value)) state = _S_RESUME - elif ktype == _K_POSTFIX_ARGDONE: - func = k[1] - receiver = k[2] - arg = val_stack.pop() + elif continuation_type == _K_POSTFIX_ARGDONE: + func = continuation_frame[1] + receiver = continuation_frame[2] + argument = value_stack.pop() self.consume(')') - result = self._call(func, [arg], receiver) - val_stack.append(result) - cont.append((_K_POSTFIX_LOOP, None)) + result = self._call(func, [argument], receiver) + value_stack.append(result) + continuation.append((_K_POSTFIX_LOOP, None)) state = _S_RESUME - elif ktype == _K_UNARY_APPLY: - ops = k[1] - val = val_stack.pop() - for op in reversed(ops): - if op == '!': - val = _JSValue(not val.to_bool(), 'bool') - elif op == '+': - val = _JSValue(val.to_number(), 'number') - val_stack.append(val) + elif continuation_type == _K_UNARY_APPLY: + operators = continuation_frame[1] + current_value = value_stack.pop() + for operator in reversed(operators): + if operator == '!': + current_value = _JSValue(not current_value.to_bool(), _JSType.BOOL) + elif operator == '+': + current_value = _JSValue(current_value.to_number(), _JSType.NUMBER) + value_stack.append(current_value) state = _S_RESUME - elif ktype == _K_EXPR_LOOP: + elif continuation_type == _K_EXPR_LOOP: if self.peek() == '+': self.consume('+') - left = val_stack.pop() - cont.append((_K_EXPR_ADD, left)) + left = value_stack.pop() + continuation.append((_K_EXPR_ADD, left)) state = _S_UNARY else: state = _S_RESUME - elif ktype == _K_EXPR_ADD: - left = k[1] - right = val_stack.pop() - val_stack.append(_js_add(left, right)) - cont.append((_K_EXPR_LOOP,)) + elif continuation_type == _K_EXPR_ADD: + left = continuation_frame[1] + right = value_stack.pop() + value_stack.append(_js_add(left, right)) + continuation.append((_K_EXPR_LOOP,)) state = _S_RESUME # ------------------------------------------------------------------ - def _call(self, func, args, receiver=None): + def _call(self, func: _JSValue, args: list[_JSValue], receiver: _JSValue | None = None) -> _JSValue: """Handle function call semantics. Only single-argument calls are supported (e.g. Function(body), @@ -501,53 +518,53 @@ def _call(self, func, args, receiver=None): emits multi-argument calls. """ # Function constructor: Function(body) returns a new function - if func.type == 'function' and func.val == 'Function': + if func.type == _JSType.FUNCTION and func.val == 'Function': if args: body = args[-1].to_string() - return _JSValue(('__function_body__', body), 'function') + return _JSValue(('__function_body__', body), _JSType.FUNCTION) # Calling a function created by Function(body) - if func.type == 'function' and isinstance(func.val, tuple): + if func.type == _JSType.FUNCTION and isinstance(func.val, tuple): if func.val[0] == '__function_body__': self.captured = func.val[1] - return _JSValue(None, 'undefined') + return _JSValue(None, _JSType.UNDEFINED) # Constructor property access — e.g., []["flat"]["constructor"] - if func.type == 'function' and isinstance(func.val, str): + if func.type == _JSType.FUNCTION and isinstance(func.val, str): name = func.val if name in _CONSTRUCTOR_MAP: if args: - return _JSValue(args[0].to_string(), 'string') - return _JSValue('', 'string') + return _JSValue(args[0].to_string(), _JSType.STRING) + return _JSValue('', _JSType.STRING) if name == 'italics': - return _JSValue('', 'string') + return _JSValue('', _JSType.STRING) if name == 'fontcolor': - return _JSValue('', 'string') + return _JSValue('', _JSType.STRING) # toString with radix — e.g., (10)["toString"](36) → "a" if name == 'toString' and args and receiver is not None: radix = args[0].to_number() if isinstance(radix, (int, float)) and radix == int(radix): radix = int(radix) - if 2 <= radix <= 36 and receiver.type == 'number': + if 2 <= radix <= 36 and receiver.type == _JSType.NUMBER: num = receiver.to_number() if isinstance(num, (int, float)) and num == int(num): - return _JSValue(_int_to_base(int(num), radix), 'string') + return _JSValue(_int_to_base(int(num), radix), _JSType.STRING) - return _JSValue(None, 'undefined') + return _JSValue(None, _JSType.UNDEFINED) -def _js_add(left, right): +def _js_add(left: _JSValue, right: _JSValue) -> _JSValue: """JS + operator with type coercion.""" - if left.type == 'string' or right.type == 'string': - return _JSValue(left.to_string() + right.to_string(), 'string') - if left.type in ('array', 'object') or right.type in ('array', 'object'): - return _JSValue(left.to_string() + right.to_string(), 'string') - return _JSValue(left.to_number() + right.to_number(), 'number') + if left.type == _JSType.STRING or right.type == _JSType.STRING: + return _JSValue(left.to_string() + right.to_string(), _JSType.STRING) + if left.type in (_JSType.ARRAY, _JSType.OBJECT) or right.type in (_JSType.ARRAY, _JSType.OBJECT): + return _JSValue(left.to_string() + right.to_string(), _JSType.STRING) + return _JSValue(left.to_number() + right.to_number(), _JSType.NUMBER) -def _int_to_base(num, base): +def _int_to_base(num: int, base: int) -> str: """Convert integer to string in given base (2-36), matching JS behavior.""" if num == 0: return '0' @@ -572,7 +589,7 @@ class _ParseError(Exception): # --------------------------------------------------------------------------- -def jsfuck_decode(code): +def jsfuck_decode(code: str) -> str | None: """Decode JSFuck-encoded JavaScript. Returns decoded string or None.""" if not code or not code.strip(): return None diff --git a/pyjsclear/transforms/logical_to_if.py b/pyjsclear/transforms/logical_to_if.py index df895f3..b471470 100644 --- a/pyjsclear/transforms/logical_to_if.py +++ b/pyjsclear/transforms/logical_to_if.py @@ -13,65 +13,65 @@ from .base import Transform -def _negate(expr): +def _negate(expression: dict) -> dict: """Wrap an expression in a logical NOT.""" return { 'type': 'UnaryExpression', 'operator': '!', 'prefix': True, - 'argument': expr, + 'argument': expression, } class LogicalToIf(Transform): """Convert logical/comma expressions in statement position to if-statements.""" - def execute(self): + def execute(self) -> bool: self._transform_bodies(self.ast) return self.has_changed() - def _transform_bodies(self, node): + def _transform_bodies(self, node: dict | list | object) -> None: """Walk all statement arrays and apply transforms.""" if not isinstance(node, dict): return for key, child in node.items(): if isinstance(child, list): if child and isinstance(child[0], dict) and 'type' in child[0]: - self._process_stmt_array(child) + self._process_statement_array(child) for item in child: self._transform_bodies(item) elif isinstance(child, dict) and 'type' in child: self._transform_bodies(child) - def _process_stmt_array(self, stmts): - i = 0 - while i < len(stmts): - stmt = stmts[i] - if not isinstance(stmt, dict): - i += 1 + def _process_statement_array(self, statements: list) -> None: + index = 0 + while index < len(statements): + statement = statements[index] + if not isinstance(statement, dict): + index += 1 continue - replacement = self._try_convert_stmt(stmt) + replacement = self._try_convert_stmt(statement) if replacement is not None: - stmts[i : i + 1] = replacement + statements[index : index + 1] = replacement self.set_changed() - i += len(replacement) + index += len(replacement) continue - i += 1 + index += 1 - def _try_convert_stmt(self, stmt): + def _try_convert_stmt(self, statement: dict) -> list | None: """Try to convert a statement. Returns replacement list or None.""" - match stmt.get('type'): + match statement.get('type'): case 'ExpressionStatement': - return self._handle_expression_stmt(stmt) + return self._handle_expression_stmt(statement) case 'ReturnStatement': - return self._handle_return_stmt(stmt) + return self._handle_return_stmt(statement) return None - def _handle_expression_stmt(self, stmt): + def _handle_expression_stmt(self, statement: dict) -> list | None: """Handle ExpressionStatement with logical or conditional.""" - expression = stmt.get('expression') + expression = statement.get('expression') if not isinstance(expression, dict): return None match expression.get('type'): @@ -81,9 +81,9 @@ def _handle_expression_stmt(self, stmt): return self._ternary_to_if(expression) return None - def _handle_return_stmt(self, stmt): + def _handle_return_stmt(self, statement: dict) -> list | None: """Handle ReturnStatement with sequence or logical expressions.""" - argument = stmt.get('argument') + argument = statement.get('argument') if not isinstance(argument, dict): return None @@ -97,49 +97,49 @@ def _handle_return_stmt(self, stmt): return None - def _split_return_sequence(self, seq): + def _split_return_sequence(self, sequence: dict) -> list | None: """Split return (a, b, c) into a; b; return c.""" - exprs = seq.get('expressions', []) - if len(exprs) <= 1: + expressions = sequence.get('expressions', []) + if len(expressions) <= 1: return None - new_stmts = [] - for expression in exprs[:-1]: + new_statements = [] + for expression in expressions[:-1]: if isinstance(expression, dict) and expression.get('type') == 'LogicalExpression': converted = self._logical_to_if(expression) if converted: - new_stmts.extend(converted) + new_statements.extend(converted) continue - new_stmts.append(make_expression_statement(expression)) - new_stmts.append({'type': 'ReturnStatement', 'argument': exprs[-1]}) - return new_stmts + new_statements.append(make_expression_statement(expression)) + new_statements.append({'type': 'ReturnStatement', 'argument': expressions[-1]}) + return new_statements - def _split_return_logical(self, logical): + def _split_return_logical(self, logical: dict) -> list | None: """Split return a || (b(), c) into if (!a) { b(); } return c.""" right = logical.get('right') if not (isinstance(right, dict) and right.get('type') == 'SequenceExpression'): return None - exprs = right.get('expressions', []) - if len(exprs) <= 1: + expressions = right.get('expressions', []) + if len(expressions) <= 1: return None test = logical.get('left') if logical.get('operator') == '||': test = _negate(test) - body_stmts = [make_expression_statement(e) for e in exprs[:-1]] - if_stmt = { + body_statements = [make_expression_statement(expression) for expression in expressions[:-1]] + if_statement = { 'type': 'IfStatement', 'test': test, - 'consequent': make_block_statement(body_stmts), + 'consequent': make_block_statement(body_statements), 'alternate': None, } - ret = {'type': 'ReturnStatement', 'argument': exprs[-1]} - return [if_stmt, ret] + return_statement = {'type': 'ReturnStatement', 'argument': expressions[-1]} + return [if_statement, return_statement] - def _logical_to_if(self, expr): + def _logical_to_if(self, expression: dict) -> list | None: """Convert a LogicalExpression to if-statement(s). Returns list of stmts or None.""" - left = expr.get('left') - match expr.get('operator'): + left = expression.get('left') + match expression.get('operator'): case '&&': test = left case '||': @@ -147,27 +147,27 @@ def _logical_to_if(self, expr): case _: return None - body_stmts = self._expr_to_stmts(expr.get('right')) - if_stmt = { + body_statements = self._expr_to_stmts(expression.get('right')) + if_statement = { 'type': 'IfStatement', 'test': test, - 'consequent': make_block_statement(body_stmts), + 'consequent': make_block_statement(body_statements), 'alternate': None, } - return [if_stmt] + return [if_statement] - def _ternary_to_if(self, expr): + def _ternary_to_if(self, expression: dict) -> list: """Convert a ConditionalExpression to if-else. Returns list of stmts or None.""" - if_stmt = { + if_statement = { 'type': 'IfStatement', - 'test': expr.get('test'), - 'consequent': make_block_statement(self._expr_to_stmts(expr.get('consequent'))), - 'alternate': make_block_statement(self._expr_to_stmts(expr.get('alternate'))), + 'test': expression.get('test'), + 'consequent': make_block_statement(self._expr_to_stmts(expression.get('consequent'))), + 'alternate': make_block_statement(self._expr_to_stmts(expression.get('alternate'))), } - return [if_stmt] + return [if_statement] - def _expr_to_stmts(self, expr): + def _expr_to_stmts(self, expression: dict | None) -> list: """Convert an expression to a list of statements.""" - if isinstance(expr, dict) and expr.get('type') == 'SequenceExpression': - return [make_expression_statement(e) for e in expr.get('expressions', [])] - return [make_expression_statement(expr)] + if isinstance(expression, dict) and expression.get('type') == 'SequenceExpression': + return [make_expression_statement(e) for e in expression.get('expressions', [])] + return [make_expression_statement(expression)] diff --git a/pyjsclear/transforms/member_chain_resolver.py b/pyjsclear/transforms/member_chain_resolver.py index 206cd9f..dd610c4 100644 --- a/pyjsclear/transforms/member_chain_resolver.py +++ b/pyjsclear/transforms/member_chain_resolver.py @@ -22,7 +22,7 @@ from .base import Transform -def _is_constant_expr(node): +def _is_constant_expr(node: dict) -> bool: """Check if a node is a constant expression safe to inline.""" if not isinstance(node, dict): return False @@ -36,32 +36,46 @@ def _is_constant_expr(node): return False +def _get_property_name(member_expr: dict, property_key: str) -> str | None: + """Extract the string name of a member expression's property.""" + prop = member_expr.get(property_key) + if not prop: + return None + if member_expr.get('computed'): + if not is_string_literal(prop): + return None + return prop['value'] + if is_identifier(prop): + return prop['name'] + return None + + class MemberChainResolver(Transform): """Resolve multi-level member chains (A.B.C) to literal values.""" - def execute(self): - # Maps: (class_name, prop_name) → AST node (constant expression) - class_constants = {} - # Maps: prop_name → class_name (from X.prop = ClassIdentifier assignments) - prop_to_class = {} + def execute(self) -> bool: + # Maps: (class_name, property_name) → AST node (constant expression) + class_constants: dict[tuple[str, str], dict] = {} + # Maps: property_name → class_name (from X.prop = ClassIdentifier assignments) + property_to_class: dict[str, str] = {} # Phase 1: Collect X.prop = constant_expr and X.prop = Identifier assignments - def collect(node, parent): + def collect(node: dict, parent: dict | None) -> None: if node.get('type') != 'AssignmentExpression': return if node.get('operator') != '=': return left = node.get('left') right = node.get('right') - obj_name, prop_name = get_member_names(left) - if not obj_name: + object_name, property_name = get_member_names(left) + if not object_name: return if _is_constant_expr(right): - class_constants[(obj_name, prop_name)] = right + class_constants[(object_name, property_name)] = right elif is_identifier(right): - # X.prop = SomeClass — record prop_name → SomeClass - prop_to_class[prop_name] = right['name'] + # X.prop = SomeClass — record property_name → SomeClass + property_to_class[property_name] = right['name'] simple_traverse(self.ast, collect) @@ -69,9 +83,9 @@ def collect(node, parent): return False # Phase 1b: Invalidate constants that are reassigned through alias chains. - # Pattern: A.B.C = expr where B → class_name via prop_to_class + # Pattern: A.B.C = expr where B → class_name via property_to_class # means (class_name, C) is NOT a true constant. - def invalidate_chain_assignments(node, parent): + def invalidate_chain_assignments(node: dict, parent: dict | None) -> None: if node.get('type') != 'AssignmentExpression': return left = node.get('left') @@ -80,34 +94,16 @@ def invalidate_chain_assignments(node, parent): inner = left.get('object') if not inner or inner.get('type') != 'MemberExpression': return - # Get C (outer property) - outer_prop = left.get('property') - if not outer_prop: - return - if left.get('computed'): - if not is_string_literal(outer_prop): - return - c_name = outer_prop['value'] - elif is_identifier(outer_prop): - c_name = outer_prop['name'] - else: - return - # Get B (middle property) - inner_prop = inner.get('property') - if not inner_prop: + outer_property_name = _get_property_name(left, 'property') + if outer_property_name is None: return - if inner.get('computed'): - if not is_string_literal(inner_prop): - return - b_name = inner_prop['value'] - elif is_identifier(inner_prop): - b_name = inner_prop['name'] - else: + middle_property_name = _get_property_name(inner, 'property') + if middle_property_name is None: return # If B resolves to a class, invalidate (class, C) - class_name = prop_to_class.get(b_name) - if class_name and (class_name, c_name) in class_constants: - del class_constants[(class_name, c_name)] + class_name = property_to_class.get(middle_property_name) + if class_name and (class_name, outer_property_name) in class_constants: + del class_constants[(class_name, outer_property_name)] simple_traverse(self.ast, invalidate_chain_assignments) @@ -116,56 +112,38 @@ def invalidate_chain_assignments(node, parent): # Phase 2: Replace A.B.C member chains where B resolves to a class # and (class, C) maps to a constant expression - def resolve(node, parent, key, index): + def resolve(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: if node.get('type') != 'MemberExpression': - return + return None # Skip assignment targets if parent and parent.get('type') == 'AssignmentExpression' and key == 'left': - return + return None - # Get C (the outer property) - outer_prop = node.get('property') - if not outer_prop: - return - if node.get('computed'): - if not is_string_literal(outer_prop): - return - c_name = outer_prop['value'] - elif is_identifier(outer_prop): - c_name = outer_prop['name'] - else: - return + outer_property_name = _get_property_name(node, 'property') + if outer_property_name is None: + return None # Get the inner member expression (A.B) inner = node.get('object') if not inner or inner.get('type') != 'MemberExpression': - return + return None - # Get B (the middle property) - inner_prop = inner.get('property') - if not inner_prop: - return - if inner.get('computed'): - if not is_string_literal(inner_prop): - return - b_name = inner_prop['value'] - elif is_identifier(inner_prop): - b_name = inner_prop['name'] - else: - return + middle_property_name = _get_property_name(inner, 'property') + if middle_property_name is None: + return None # Resolve B → class_name - class_name = prop_to_class.get(b_name) + class_name = property_to_class.get(middle_property_name) if not class_name: - return + return None # Resolve (class_name, C) → constant expression - const_node = class_constants.get((class_name, c_name)) - if const_node is None: - return + constant_node = class_constants.get((class_name, outer_property_name)) + if constant_node is None: + return None self.set_changed() - return deep_copy(const_node) + return deep_copy(constant_node) traverse(self.ast, {'enter': resolve}) return self.has_changed() diff --git a/pyjsclear/transforms/noop_calls.py b/pyjsclear/transforms/noop_calls.py index 7177442..a54e6ef 100644 --- a/pyjsclear/transforms/noop_calls.py +++ b/pyjsclear/transforms/noop_calls.py @@ -7,6 +7,8 @@ obj.methodName('...'); // removed """ +from typing import Any + from ..traverser import REMOVE from ..traverser import simple_traverse from ..traverser import traverse @@ -17,11 +19,11 @@ class NoopCallRemover(Transform): """Remove expression-statement calls to no-op methods.""" - def execute(self): + def execute(self) -> bool: # Phase 1: Find no-op methods (empty body or just 'return;') - noop_methods = set() + noop_methods: set[str] = set() - def find_noops(node, parent): + def find_noops(node: dict, parent: dict | None) -> None: if node.get('type') != 'MethodDefinition': return if node.get('kind') not in ('method', None): @@ -29,21 +31,21 @@ def find_noops(node, parent): key = node.get('key') if not key or not is_identifier(key): return - fn = node.get('value') - if not fn or fn.get('type') != 'FunctionExpression': + function_expression = node.get('value') + if not function_expression or function_expression.get('type') != 'FunctionExpression': return - # Check if async — async no-op still returns a promise, skip - if fn.get('async'): + # Async no-op still returns a promise, skip + if function_expression.get('async'): return - body = fn.get('body') + body = function_expression.get('body') if not body or body.get('type') != 'BlockStatement': return - stmts = body.get('body', []) - if len(stmts) == 0: + statements = body.get('body', []) + if not statements: noop_methods.add(key['name']) - elif len(stmts) == 1: - stmt = stmts[0] - if stmt.get('type') == 'ReturnStatement' and stmt.get('argument') is None: + elif len(statements) == 1: + statement = statements[0] + if statement.get('type') == 'ReturnStatement' and statement.get('argument') is None: noop_methods.add(key['name']) simple_traverse(self.ast, find_noops) @@ -52,13 +54,13 @@ def find_noops(node, parent): return False # Phase 2: Remove ExpressionStatement calls to no-op methods - def remove_calls(node, parent, key, index): + def remove_calls(node: dict, parent: dict | None, key: str | None, index: int | None) -> Any: if node.get('type') != 'ExpressionStatement': return - expr = node.get('expression') - if not expr or expr.get('type') not in ('CallExpression', 'AwaitExpression'): + expression = node.get('expression') + if not expression or expression.get('type') not in ('CallExpression', 'AwaitExpression'): return - call = expr + call = expression if call.get('type') == 'AwaitExpression': call = call.get('argument') if not call or call.get('type') != 'CallExpression': @@ -66,10 +68,10 @@ def remove_calls(node, parent, key, index): callee = call.get('callee') if not callee or callee.get('type') != 'MemberExpression': return - prop = callee.get('property') - if not prop or not is_identifier(prop): + property_node = callee.get('property') + if not property_node or not is_identifier(property_node): return - if prop['name'] in noop_methods: + if property_node['name'] in noop_methods: self.set_changed() return REMOVE diff --git a/pyjsclear/transforms/nullish_coalescing.py b/pyjsclear/transforms/nullish_coalescing.py index f53d58a..91edcb9 100644 --- a/pyjsclear/transforms/nullish_coalescing.py +++ b/pyjsclear/transforms/nullish_coalescing.py @@ -17,36 +17,47 @@ class NullishCoalescing(Transform): """Convert nullish check patterns to ?? operator.""" - def execute(self): - def enter(node, parent, key, index): - if node.get('type') != 'ConditionalExpression': - return - test = node.get('test') - if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '&&': - return - - left_cmp = test.get('left') - right_cmp = test.get('right') - if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict): - return - - # Pattern: X !== null && X !== undefined ? X : default - # Where X might be (_0x = value) or just an identifier - result = self._match_nullish_pattern(left_cmp, right_cmp, node.get('consequent'), node.get('alternate')) - if result: - self.set_changed() - return result - - # Also handle reversed order: X !== undefined && X !== null ? X : default - result = self._match_nullish_pattern(right_cmp, left_cmp, node.get('consequent'), node.get('alternate')) - if result: - self.set_changed() - return result - - traverse(self.ast, {'enter': enter}) + def execute(self) -> bool: + traverse(self.ast, {'enter': self._enter}) return self.has_changed() - def _match_nullish_pattern(self, null_check, undef_check, consequent, alternate): + def _enter(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: + if node.get('type') != 'ConditionalExpression': + return None + test = node.get('test') + if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '&&': + return None + + left_cmp = test.get('left') + right_cmp = test.get('right') + if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict): + return None + + consequent = node.get('consequent') + alternate = node.get('alternate') + + # Pattern: X !== null && X !== undefined ? X : default + # Where X might be (_0x = value) or just an identifier + result = self._match_nullish_pattern(left_cmp, right_cmp, consequent, alternate) + if result: + self.set_changed() + return result + + # Also handle reversed order: X !== undefined && X !== null ? X : default + result = self._match_nullish_pattern(right_cmp, left_cmp, consequent, alternate) + if result: + self.set_changed() + return result + + return None + + def _match_nullish_pattern( + self, + null_check: dict, + undef_check: dict, + consequent: dict | None, + alternate: dict | None, + ) -> dict | None: """Try to match the pattern and return a ?? node if successful.""" # null_check: X !== null (or (tmp = value) !== null) if null_check.get('type') != 'BinaryExpression' or null_check.get('operator') != '!==': diff --git a/pyjsclear/transforms/object_packer.py b/pyjsclear/transforms/object_packer.py index 170f10d..329dd99 100644 --- a/pyjsclear/transforms/object_packer.py +++ b/pyjsclear/transforms/object_packer.py @@ -12,11 +12,11 @@ class ObjectPacker(Transform): """Pack sequential property assignments into object initializers.""" - def execute(self): + def execute(self) -> bool: self._process_bodies(self.ast) return self.has_changed() - def _process_bodies(self, node): + def _process_bodies(self, node: dict) -> None: """Recursively find body arrays and try packing.""" if not isinstance(node, dict): return @@ -30,7 +30,7 @@ def _process_bodies(self, node): self._process_bodies(child) @staticmethod - def _find_empty_object_declaration(stmt): + def _find_empty_object_declaration(stmt: dict) -> tuple[str, dict, dict] | None: """Find an empty object literal in a VariableDeclaration. Returns (name, declarator, object_expression) or None. @@ -48,7 +48,7 @@ def _find_empty_object_declaration(stmt): return declaration['id']['name'], declaration, initializer return None - def _try_pack_body(self, body): + def _try_pack_body(self, body: list) -> None: """Find empty object declarations followed by property assignments and pack them.""" i = 0 while i < len(body): @@ -62,7 +62,7 @@ def _try_pack_body(self, body): i += 1 continue - obj_name, obj_decl, obj_expr = found + object_name, obj_decl, obj_expr = found # Collect consecutive property assignments assignments = [] @@ -78,18 +78,16 @@ def _try_pack_body(self, body): if not left or left.get('type') != 'MemberExpression': break object_reference = left.get('object') - if not is_identifier(object_reference) or object_reference.get('name') != obj_name: + if not is_identifier(object_reference) or object_reference.get('name') != object_name: break property_node = left.get('property') right = expr.get('right') - # Get property key if property_node is None: break - # Support both computed and non-computed property keys property_key = property_node # Don't pack self-referential assignments (o.x = o.y) - if self._references_name(right, obj_name): + if self._references_name(right, object_name): break computed = left.get('computed', False) @@ -99,7 +97,7 @@ def _try_pack_body(self, body): if assignments: # Pack into the object literal for property_key, value, computed in assignments: - property_node = { + new_property = { 'type': 'Property', 'key': property_key, 'value': value, @@ -108,14 +106,14 @@ def _try_pack_body(self, body): 'shorthand': False, 'computed': computed, } - obj_expr['properties'].append(property_node) + obj_expr['properties'].append(new_property) # Remove the assignment statements - del body[i + 1 : j] + del body[i + 1:j] self.set_changed() i += 1 - def _references_name(self, node, name): + def _references_name(self, node: dict, name: str) -> bool: """Check if a node references a given identifier name.""" if not isinstance(node, dict) or 'type' not in node: return False diff --git a/pyjsclear/transforms/object_simplifier.py b/pyjsclear/transforms/object_simplifier.py index 7c71e18..2e3febf 100644 --- a/pyjsclear/transforms/object_simplifier.py +++ b/pyjsclear/transforms/object_simplifier.py @@ -18,12 +18,12 @@ class ObjectSimplifier(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) self._process_scope(scope_tree) return self.has_changed() - def _process_scope(self, scope): + def _process_scope(self, scope) -> None: for name, binding in list(scope.bindings.items()): if not binding.is_constant: continue @@ -39,21 +39,21 @@ def _process_scope(self, scope): if not self._is_proxy_object(properties): continue - prop_map = {} + property_map = {} for property_node in properties: key = self._get_property_key(property_node) if key is None: continue value = property_node.get('value') if is_literal(value): - prop_map[key] = value + property_map[key] = value elif value and value.get('type') in ( 'FunctionExpression', 'ArrowFunctionExpression', ): - prop_map[key] = value + property_map[key] = value - if not prop_map: + if not property_map: continue if self._has_property_assignment(binding): @@ -68,10 +68,10 @@ def _process_scope(self, scope): member_expression = ref_parent property_name = self._get_member_prop_name(member_expression) - if property_name is None or property_name not in prop_map: + if property_name is None or property_name not in property_map: continue - value = prop_map[property_name] + value = property_map[property_name] if is_literal(value): if self._replace_node(member_expression, deep_copy(value)): self.set_changed() @@ -87,25 +87,25 @@ def _process_scope(self, scope): for child in scope.children: self._process_scope(child) - def _has_property_assignment(self, binding): + def _has_property_assignment(self, binding) -> bool: """Check if any reference to the binding is a property assignment target.""" for reference_node, reference_parent, ref_key, ref_index in binding.references: if not (reference_parent and reference_parent.get('type') == 'MemberExpression' and ref_key == 'object'): continue - me_parent_info = find_parent(self.ast, reference_parent) - if not me_parent_info: + member_expression_parent_info = find_parent(self.ast, reference_parent) + if not member_expression_parent_info: continue - parent, key, _ = me_parent_info + parent, key, _ = member_expression_parent_info if parent and parent.get('type') == 'AssignmentExpression' and key == 'left': return True return False - def _try_inline_function_call(self, member_expression, function_value): + def _try_inline_function_call(self, member_expression, function_value) -> None: """Try to inline a function call at a MemberExpression site.""" - me_parent_info = find_parent(self.ast, member_expression) - if not me_parent_info: + member_expression_parent_info = find_parent(self.ast, member_expression) + if not member_expression_parent_info: return - parent, key, _ = me_parent_info + parent, key, _ = member_expression_parent_info if not (parent and parent.get('type') == 'CallExpression' and key == 'callee'): return replacement = self._inline_func(function_value, parent.get('arguments', [])) @@ -114,24 +114,24 @@ def _try_inline_function_call(self, member_expression, function_value): if self._replace_node(parent, replacement): self.set_changed() - def _is_proxy_object(self, properties): + def _is_proxy_object(self, properties: list) -> bool: """Check if all properties are literals or simple functions.""" - for p in properties: - if p.get('type') != 'Property': + for property_node in properties: + if property_node.get('type') != 'Property': return False - val = p.get('value') - if not val: + value = property_node.get('value') + if not value: return False - if is_literal(val): + if is_literal(value): continue - if val.get('type') in ('FunctionExpression', 'ArrowFunctionExpression'): + if value.get('type') in ('FunctionExpression', 'ArrowFunctionExpression'): continue return False return True - def _get_property_key(self, prop): + def _get_property_key(self, property_node) -> str | None: """Get the string key of a property.""" - key = prop.get('key') + key = property_node.get('key') if not key: return None match key.get('type'): @@ -141,12 +141,12 @@ def _get_property_key(self, prop): return key['value'] return None - def _get_member_prop_name(self, member_expr): + def _get_member_prop_name(self, member_expression) -> str | None: """Get property name from a member expression.""" - prop = member_expr.get('property') + prop = member_expression.get('property') if not prop: return None - if member_expr.get('computed'): + if member_expression.get('computed'): if is_string_literal(prop): return prop['value'] return None @@ -154,7 +154,7 @@ def _get_member_prop_name(self, member_expr): return prop['name'] return None - def _replace_node(self, target, replacement): + def _replace_node(self, target, replacement) -> bool: """Replace target node in the AST. Returns True if replaced.""" result = find_parent(self.ast, target) if result: @@ -166,29 +166,29 @@ def _replace_node(self, target, replacement): return True return False - def _inline_func(self, func, args): + def _inline_func(self, function_node, arguments: list): """Inline a simple function call.""" - body = func.get('body') + body = function_node.get('body') if not body: return None - if func.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement': + if function_node.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement': expr = deep_copy(body) elif body.get('type') == 'BlockStatement': - stmts = body.get('body', []) - if len(stmts) != 1 or stmts[0].get('type') != 'ReturnStatement': + statements = body.get('body', []) + if len(statements) != 1 or statements[0].get('type') != 'ReturnStatement': return None - argument = stmts[0].get('argument') + argument = statements[0].get('argument') if not argument: return None expr = deep_copy(argument) else: return None - params = func.get('params', []) + params = function_node.get('params', []) param_map = {} - for i, parameter in enumerate(params): + for index, parameter in enumerate(params): if parameter.get('type') == 'Identifier': - param_map[parameter['name']] = args[i] if i < len(args) else {'type': 'Identifier', 'name': 'undefined'} + param_map[parameter['name']] = arguments[index] if index < len(arguments) else {'type': 'Identifier', 'name': 'undefined'} replace_identifiers(expr, param_map) return expr diff --git a/pyjsclear/transforms/optional_chaining.py b/pyjsclear/transforms/optional_chaining.py index 43edba5..fcead51 100644 --- a/pyjsclear/transforms/optional_chaining.py +++ b/pyjsclear/transforms/optional_chaining.py @@ -18,19 +18,19 @@ from .base import Transform -def _nodes_match(a, b): +def _nodes_match(node_a: object, node_b: object) -> bool: """Check if two AST nodes are structurally equivalent (shallow).""" - if not isinstance(a, dict) or not isinstance(b, dict): + if not isinstance(node_a, dict) or not isinstance(node_b, dict): return False - if a.get('type') != b.get('type'): + if node_a.get('type') != node_b.get('type'): return False - if a.get('type') == 'Identifier': - return a.get('name') == b.get('name') - if a.get('type') == 'MemberExpression': + if node_a.get('type') == 'Identifier': + return node_a.get('name') == node_b.get('name') + if node_a.get('type') == 'MemberExpression': return ( - _nodes_match(a.get('object'), b.get('object')) - and _nodes_match(a.get('property'), b.get('property')) - and a.get('computed') == b.get('computed') + _nodes_match(node_a.get('object'), node_b.get('object')) + and _nodes_match(node_a.get('property'), node_b.get('property')) + and node_a.get('computed') == node_b.get('computed') ) return False @@ -38,64 +38,65 @@ def _nodes_match(a, b): class OptionalChaining(Transform): """Convert nullish check patterns to ?. operator.""" - def execute(self): - def enter(node, parent, key, index): + def execute(self) -> bool: + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None: if node.get('type') != 'ConditionalExpression': - return + return None test = node.get('test') if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '||': - return + return None alternate = node.get('alternate') consequent = node.get('consequent') # consequent must be undefined/void 0 if not is_undefined(consequent): - return + return None result = self._match_optional_pattern(test.get('left'), test.get('right'), alternate) if result: self.set_changed() return result + return None traverse(self.ast, {'enter': enter}) return self.has_changed() - def _match_optional_pattern(self, left_cmp, right_cmp, alternate): + def _match_optional_pattern(self, left_comparison: object, right_comparison: object, alternate: object) -> dict | None: """Try to match X === null || X === undefined ? undefined : X.prop.""" - if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict): + if not isinstance(left_comparison, dict) or not isinstance(right_comparison, dict): return None - if left_cmp.get('type') != 'BinaryExpression' or left_cmp.get('operator') != '===': + if left_comparison.get('type') != 'BinaryExpression' or left_comparison.get('operator') != '===': return None - if right_cmp.get('type') != 'BinaryExpression' or right_cmp.get('operator') != '===': + if right_comparison.get('type') != 'BinaryExpression' or right_comparison.get('operator') != '===': return None # Figure out which comparison has null and which has undefined - checked_var = None + checked_variable = None # Try: left has null, right has undefined - for null_cmp, undef_cmp in [(left_cmp, right_cmp), (right_cmp, left_cmp)]: - null_left, null_right = null_cmp.get('left'), null_cmp.get('right') - undef_left, undef_right = undef_cmp.get('left'), undef_cmp.get('right') + for null_comparison, undefined_comparison in [(left_comparison, right_comparison), (right_comparison, left_comparison)]: + null_comparison_left, null_comparison_right = null_comparison.get('left'), null_comparison.get('right') + undefined_comparison_left, undefined_comparison_right = undefined_comparison.get('left'), undefined_comparison.get('right') # X === null - if is_null_literal(null_right): - null_checked = null_left - elif is_null_literal(null_left): - null_checked = null_right + if is_null_literal(null_comparison_right): + null_checked = null_comparison_left + elif is_null_literal(null_comparison_left): + null_checked = null_comparison_right else: continue # X === undefined - if is_undefined(undef_right): - undef_checked = undef_left - elif is_undefined(undef_left): - undef_checked = undef_right + if is_undefined(undefined_comparison_right): + undefined_checked = undefined_comparison_left + elif is_undefined(undefined_comparison_left): + undefined_checked = undefined_comparison_right else: continue # Case 1: Simple - both check the same identifier - if _nodes_match(null_checked, undef_checked): - checked_var = null_checked + if _nodes_match(null_checked, undefined_checked): + checked_variable = null_checked break # Case 2: Temp assignment - (_tmp = expr) === null || _tmp === undefined @@ -106,31 +107,31 @@ def _match_optional_pattern(self, left_cmp, right_cmp, alternate): ): tmp_var = null_checked.get('left') value_expr = null_checked.get('right') - if identifiers_match(tmp_var, undef_checked): + if identifiers_match(tmp_var, undefined_checked): # The alternate should use tmp_var as the object - checked_var = tmp_var + checked_variable = tmp_var # We'll replace tmp_var references in alternate with value_expr - return self._build_optional_chain(value_expr, checked_var, alternate) + return self._build_optional_chain(value_expr, checked_variable, alternate) - if checked_var is None: + if checked_variable is None: return None - return self._build_optional_chain(checked_var, checked_var, alternate) + return self._build_optional_chain(checked_variable, checked_variable, alternate) - def _build_optional_chain(self, base_expr, checked_var, alternate): + def _build_optional_chain(self, base_expr: object, checked_variable: object, alternate: object) -> dict | None: """Build an optional chain node: base_expr?.something. base_expr: the actual expression to use as the object - checked_var: the variable that was null-checked (may differ from base_expr for temp assignments) - alternate: the expression that accesses checked_var.prop (or deeper) + checked_variable: the variable that was null-checked (may differ from base_expr for temp assignments) + alternate: the expression that accesses checked_variable.prop (or deeper) """ if not isinstance(alternate, dict): return None - # alternate should be a MemberExpression or CallExpression whose object matches checked_var + # alternate should be a MemberExpression or CallExpression whose object matches checked_variable if alternate.get('type') == 'MemberExpression': obj = alternate.get('object') - if _nodes_match(obj, checked_var): + if _nodes_match(obj, checked_variable): return { 'type': 'MemberExpression', 'object': base_expr, @@ -141,7 +142,7 @@ def _build_optional_chain(self, base_expr, checked_var, alternate): if alternate.get('type') == 'CallExpression': callee = alternate.get('callee') - if _nodes_match(callee, checked_var): + if _nodes_match(callee, checked_variable): return { 'type': 'CallExpression', 'callee': base_expr, diff --git a/pyjsclear/transforms/property_simplifier.py b/pyjsclear/transforms/property_simplifier.py index e5fe9db..e5df400 100644 --- a/pyjsclear/transforms/property_simplifier.py +++ b/pyjsclear/transforms/property_simplifier.py @@ -9,8 +9,8 @@ class PropertySimplifier(Transform): """Simplify obj["prop"] to obj.prop when prop is a valid identifier.""" - def execute(self): - def enter(node, parent, key, index): + def execute(self) -> bool: + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'MemberExpression': return if not node.get('computed'): @@ -29,7 +29,7 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) # Also simplify computed property keys in object literals - def enter_obj(node, parent, key, index): + def enter_obj(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'Property': return key_node = node.get('key') @@ -51,7 +51,7 @@ def enter_obj(node, parent, key, index): # Simplify method definitions with string literal keys: # static ["name"]() → static name() # Also handles cases where parser sets computed=False but key is still a Literal - def enter_method(node, parent, key, index): + def enter_method(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'MethodDefinition': return key_node = node.get('key') diff --git a/pyjsclear/transforms/proxy_functions.py b/pyjsclear/transforms/proxy_functions.py index 5c0a1c0..2732419 100644 --- a/pyjsclear/transforms/proxy_functions.py +++ b/pyjsclear/transforms/proxy_functions.py @@ -28,10 +28,10 @@ def execute(self): scope_tree, node_scope = build_scope_tree(self.ast) # Find proxy functions - proxy_fns = {} # name -> (func_node, scope, binding) - self._find_proxy_functions(scope_tree, proxy_fns) + proxy_functions = {} # name -> (func_node, scope, binding) + self._find_proxy_functions(scope_tree, proxy_functions) - if not proxy_fns: + if not proxy_functions: return False # Collect call sites with depth info @@ -46,33 +46,33 @@ def enter(node, parent, key, index): if not is_identifier(callee): return name = callee.get('name', '') - if name not in proxy_fns: + if name not in proxy_functions: return - call_sites.append((node, parent, key, index, proxy_fns[name], depth_counter[0])) + call_sites.append((node, parent, key, index, proxy_functions[name], depth_counter[0])) traverse(self.ast, {'enter': enter}) # Skip helper functions: many call sites + conditional body = not a true proxy call_counts = {} - for cs in call_sites: - fn_id = id(cs[4][0]) # func_node - call_counts[fn_id] = call_counts.get(fn_id, 0) + 1 + for call_site in call_sites: + function_node_id = id(call_site[4][0]) # func_node + call_counts[function_node_id] = call_counts.get(function_node_id, 0) + 1 def _has_conditional(node): found = [False] - def cb(n, parent): + def check_node(n, parent): if n.get('type') == 'ConditionalExpression': found[0] = True - simple_traverse(node, cb) + simple_traverse(node, check_node) return found[0] - helper_fn_ids = set() - for name, (func_node, _, _) in proxy_fns.items(): + helper_function_ids = set() + for name, (func_node, _, _) in proxy_functions.items(): if call_counts.get(id(func_node), 0) > 3 and _has_conditional(func_node): - helper_fn_ids.add(id(func_node)) - call_sites = [cs for cs in call_sites if id(cs[4][0]) not in helper_fn_ids] + helper_function_ids.add(id(func_node)) + call_sites = [call_site for call_site in call_sites if id(call_site[4][0]) not in helper_function_ids] # Process innermost calls first call_sites.sort(key=lambda x: x[5], reverse=True) @@ -147,18 +147,18 @@ def _is_proxy_function(self, func_node): # Block with single return if body.get('type') == 'BlockStatement': - stmts = body.get('body', []) - if len(stmts) != 1: + statements = body.get('body', []) + if len(statements) != 1: return False - stmt = stmts[0] + stmt = statements[0] if stmt.get('type') != 'ReturnStatement': return False - arg = stmt.get('argument') - if arg is None: + argument = stmt.get('argument') + if argument is None: return True # returns undefined - if not self._is_proxy_value(arg): + if not self._is_proxy_value(argument): return False - return self._count_nodes(arg) <= _MAX_PROXY_BODY_NODES + return self._count_nodes(argument) <= _MAX_PROXY_BODY_NODES return False @@ -167,10 +167,10 @@ def _count_nodes(node): """Count AST nodes in a subtree.""" count = [0] - def cb(n, parent): + def increment_count(n, parent): count[0] += 1 - simple_traverse(node, cb) + simple_traverse(node, increment_count) return count[0] _DISALLOWED_PROXY_TYPES = frozenset( @@ -211,25 +211,25 @@ def _get_replacement(self, func_node, args): if func_node.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement': expr = deep_copy(body) elif body.get('type') == 'BlockStatement': - stmts = body.get('body', []) - if not stmts or stmts[0].get('type') != 'ReturnStatement': + statements = body.get('body', []) + if not statements or statements[0].get('type') != 'ReturnStatement': return None - arg = stmts[0].get('argument') - if arg is None: + argument = statements[0].get('argument') + if argument is None: return {'type': 'Identifier', 'name': 'undefined'} - expr = deep_copy(arg) + expr = deep_copy(argument) else: return None # Build parameter map params = func_node.get('params', []) - param_map = {} - for i, parameter in enumerate(params): + parameter_map = {} + for index, parameter in enumerate(params): if parameter.get('type') == 'Identifier': - if i < len(args): - param_map[parameter['name']] = args[i] + if index < len(args): + parameter_map[parameter['name']] = args[index] else: - param_map[parameter['name']] = {'type': 'Identifier', 'name': 'undefined'} + parameter_map[parameter['name']] = {'type': 'Identifier', 'name': 'undefined'} - replace_identifiers(expr, param_map) + replace_identifiers(expr, parameter_map) return expr diff --git a/pyjsclear/transforms/reassignment.py b/pyjsclear/transforms/reassignment.py index 555a8f5..5352dc8 100644 --- a/pyjsclear/transforms/reassignment.py +++ b/pyjsclear/transforms/reassignment.py @@ -4,12 +4,19 @@ And replaces all references to x with y, then removes x. """ +from __future__ import annotations + +from typing import TYPE_CHECKING + from ..scope import build_scope_tree from ..traverser import REMOVE from ..traverser import traverse from ..utils.ast_helpers import is_identifier from .base import Transform +if TYPE_CHECKING: + from ..scope import Scope + class ReassignmentRemover(Transform): """Remove redundant reassignments like x = y where y is used identically.""" @@ -51,13 +58,13 @@ class ReassignmentRemover(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) self._process_scope(scope_tree) self._inline_assignment_aliases(scope_tree) return self.has_changed() - def _process_scope(self, scope): + def _process_scope(self, scope: Scope) -> None: for name, binding in list(scope.bindings.items()): if not binding.is_constant: continue @@ -69,8 +76,8 @@ def _process_scope(self, scope): continue # Skip destructuring patterns — id must be a simple Identifier - decl_id = node.get('id') - if not decl_id or decl_id.get('type') != 'Identifier': + declaration_id = node.get('id') + if not declaration_id or declaration_id.get('type') != 'Identifier': continue initializer = node.get('init') @@ -80,12 +87,12 @@ def _process_scope(self, scope): target_name = initializer.get('name', '') if target_name == name: continue + target_binding = scope.get_binding(target_name) # Allow inlining if target is a well-known global or a constant binding - if target_binding: - if not target_binding.is_constant: - continue - elif target_name not in self._WELL_KNOWN_GLOBALS: + if target_binding and not target_binding.is_constant: + continue + if not target_binding and target_name not in self._WELL_KNOWN_GLOBALS: continue # Replace all references to `name` with `target_name` @@ -108,7 +115,7 @@ def _process_scope(self, scope): for child in scope.children: self._process_scope(child) - def _inline_assignment_aliases(self, scope_tree): + def _inline_assignment_aliases(self, scope_tree: Scope) -> None: """Inline aliases created by `var x; ... x = y;` patterns. Handles the obfuscator pattern where a variable is declared without @@ -116,17 +123,17 @@ def _inline_assignment_aliases(self, scope_tree): """ self._process_assignment_aliases(scope_tree) - def _remove_assignment_statement(self, assignment_node): + def _remove_assignment_statement(self, assignment_node: dict) -> None: """Remove the ExpressionStatement containing the given assignment expression.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') == 'ExpressionStatement' and node.get('expression') is assignment_node: self.set_changed() return REMOVE traverse(self.ast, {'enter': enter}) - def _process_assignment_aliases(self, scope): + def _process_assignment_aliases(self, scope: Scope) -> None: for name, binding in list(scope.bindings.items()): if binding.is_constant or binding.kind == 'param': continue @@ -167,10 +174,9 @@ def _process_assignment_aliases(self, scope): # The target must be constant or a well-known global target_binding = scope.get_binding(target_name) - if target_binding: - if not target_binding.is_constant: - continue - elif target_name not in self._WELL_KNOWN_GLOBALS: + if target_binding and not target_binding.is_constant: + continue + if not target_binding and target_name not in self._WELL_KNOWN_GLOBALS: continue # Replace all reads of `name` with `target_name` diff --git a/pyjsclear/transforms/require_inliner.py b/pyjsclear/transforms/require_inliner.py index b1c7431..7611c8e 100644 --- a/pyjsclear/transforms/require_inliner.py +++ b/pyjsclear/transforms/require_inliner.py @@ -17,26 +17,24 @@ class RequireInliner(Transform): """Replace require polyfill calls with direct require() calls.""" - def execute(self): - polyfill_names = set() + def execute(self) -> bool: + polyfill_names: set[str] = set() # Phase 1: Detect require polyfill pattern. # We look for: var X = (...)(function(Y) { ... require ... }) # Heuristic: a VariableDeclarator whose init is a CallExpression, # and somewhere in its body there's `typeof require !== "undefined" ? require : ...` - def find_polyfills(node, parent): + def find_polyfills(node: dict, parent: dict | None) -> None: if node.get('type') != 'VariableDeclarator': return - decl_id = node.get('id') + declaration_id = node.get('id') init = node.get('init') - if not is_identifier(decl_id): + if not is_identifier(declaration_id): return if not init or init.get('type') != 'CallExpression': return - - # Check if the init tree contains `typeof require` if self._contains_typeof_require(init): - polyfill_names.add(decl_id['name']) + polyfill_names.add(declaration_id['name']) simple_traverse(self.ast, find_polyfills) @@ -44,7 +42,7 @@ def find_polyfills(node, parent): return False # Phase 2: Replace _0x544bfe(X) with require(X) - def replace_calls(node, parent, key, index): + def replace_calls(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'CallExpression': return callee = node.get('callee') @@ -55,26 +53,24 @@ def replace_calls(node, parent, key, index): args = node.get('arguments', []) if len(args) != 1: return - - # Replace the callee with require node['callee'] = make_identifier('require') self.set_changed() traverse(self.ast, {'enter': replace_calls}) return self.has_changed() - def _contains_typeof_require(self, node): + def _contains_typeof_require(self, node: dict) -> bool: """Check if a subtree contains `typeof require`.""" found = [False] - def scan(n, parent): + def scan(current_node: dict, parent: dict | None) -> None: if found[0]: return - if not isinstance(n, dict): + if not isinstance(current_node, dict): return - if n.get('type') == 'UnaryExpression' and n.get('operator') == 'typeof': - arg = n.get('argument') - if is_identifier(arg) and arg.get('name') == 'require': + if current_node.get('type') == 'UnaryExpression' and current_node.get('operator') == 'typeof': + argument = current_node.get('argument') + if is_identifier(argument) and argument.get('name') == 'require': found[0] = True simple_traverse(node, scan) diff --git a/pyjsclear/transforms/sequence_splitter.py b/pyjsclear/transforms/sequence_splitter.py index 8878210..3478c60 100644 --- a/pyjsclear/transforms/sequence_splitter.py +++ b/pyjsclear/transforms/sequence_splitter.py @@ -14,15 +14,15 @@ class SequenceSplitter(Transform): """Split sequence expressions and normalize control flow bodies.""" - def execute(self): + def execute(self) -> bool: self._normalize_bodies(self.ast) self._split_in_body_arrays(self.ast) return self.has_changed() - def _normalize_bodies(self, ast): + def _normalize_bodies(self, ast: dict) -> None: """Ensure if/while/for bodies are BlockStatements.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: node_type = node.get('type', '') if node_type not in ( 'IfStatement', @@ -42,7 +42,7 @@ def enter(node, parent, key, index): traverse(ast, {'enter': enter}) - def _normalize_if_branches(self, node): + def _normalize_if_branches(self, node: dict) -> None: """Wrap non-block consequent/alternate of IfStatement in BlockStatements.""" consequent = node.get('consequent') if consequent and consequent.get('type') != 'BlockStatement': @@ -53,7 +53,7 @@ def _normalize_if_branches(self, node): node['alternate'] = make_block_statement([alternate]) self.set_changed() - def _split_in_body_arrays(self, node): + def _split_in_body_arrays(self, node: dict) -> None: """Find all arrays that contain statements and split sequences + var decls in them.""" if not isinstance(node, dict): return @@ -62,13 +62,12 @@ def _split_in_body_arrays(self, node): # Check if this looks like a statement array if child and isinstance(child[0], dict) and 'type' in child[0]: self._process_stmt_array(child) - # Recurse into items for item in child: self._split_in_body_arrays(item) elif isinstance(child, dict) and 'type' in child: self._split_in_body_arrays(child) - def _extract_indirect_call_prefixes(self, stmt): + def _extract_indirect_call_prefixes(self, statement: dict) -> list: """Extract dead prefix expressions from (0, fn)(args) patterns. Only extracts from: @@ -79,7 +78,7 @@ def _extract_indirect_call_prefixes(self, stmt): """ prefixes = [] - def extract_from_call(node): + def extract_from_call(node: dict | None) -> None: """If node is a CallExpression with SequenceExpression callee, extract prefixes.""" if not isinstance(node, dict): return @@ -91,45 +90,44 @@ def extract_from_call(node): callee = target.get('callee') if not isinstance(callee, dict) or callee.get('type') != 'SequenceExpression': return - exprs = callee.get('expressions', []) - if len(exprs) <= 1: + expressions = callee.get('expressions', []) + if len(expressions) <= 1: return - prefixes.extend(exprs[:-1]) - target['callee'] = exprs[-1] - - stype = stmt.get('type', '') - if stype == 'ExpressionStatement': - extract_from_call(stmt.get('expression')) - elif stype == 'VariableDeclaration': - for d in stmt.get('declarations', []): - extract_from_call(d.get('init')) - elif stype == 'ReturnStatement': - extract_from_call(stmt.get('argument')) - # Also check assignment expressions in ExpressionStatements: - # x = (0, fn)(args) - if stype == 'ExpressionStatement': - expr = stmt.get('expression') - if isinstance(expr, dict) and expr.get('type') == 'AssignmentExpression': - extract_from_call(expr.get('right')) + prefixes.extend(expressions[:-1]) + target['callee'] = expressions[-1] + + statement_type = statement.get('type', '') + match statement_type: + case 'ExpressionStatement': + extract_from_call(statement.get('expression')) + # Also check assignment: x = (0, fn)(args) + expression = statement.get('expression') + if isinstance(expression, dict) and expression.get('type') == 'AssignmentExpression': + extract_from_call(expression.get('right')) + case 'VariableDeclaration': + for declarator in statement.get('declarations', []): + extract_from_call(declarator.get('init')) + case 'ReturnStatement': + extract_from_call(statement.get('argument')) return prefixes - def _process_stmt_array(self, statements): + def _process_stmt_array(self, statements: list) -> None: """Split sequence expressions and multi-var declarations in a statement array.""" - i = 0 - while i < len(statements): - statement = statements[i] + index = 0 + while index < len(statements): + statement = statements[index] if not isinstance(statement, dict): - i += 1 + index += 1 continue # Extract dead prefix from indirect call patterns: (0, fn)(args) → 0; fn(args); prefixes = self._extract_indirect_call_prefixes(statement) if prefixes: - new_stmts = [make_expression_statement(expression) for expression in prefixes] - new_stmts.append(statement) - statements[i : i + 1] = new_stmts - i += len(new_stmts) + new_statements = [make_expression_statement(expression) for expression in prefixes] + new_statements.append(statement) + statements[index : index + 1] = new_statements + index += len(new_statements) self.set_changed() continue @@ -141,44 +139,44 @@ def _process_stmt_array(self, statements): ): expressions = statement['expression'].get('expressions', []) if len(expressions) > 1: - new_stmts = [make_expression_statement(expression) for expression in expressions] - statements[i : i + 1] = new_stmts - i += len(new_stmts) + new_statements = [make_expression_statement(expression) for expression in expressions] + statements[index : index + 1] = new_statements + index += len(new_statements) self.set_changed() continue # Split multi-declarator VariableDeclaration # (but not inside for-loop init — those aren't in body arrays) if statement.get('type') == 'VariableDeclaration': - decls = statement.get('declarations', []) - if len(decls) > 1: + declarations = statement.get('declarations', []) + if len(declarations) > 1: kind = statement.get('kind', 'var') - new_stmts = [ + new_statements = [ { 'type': 'VariableDeclaration', 'kind': kind, 'declarations': [declaration], } - for declaration in decls + for declaration in declarations ] - statements[i : i + 1] = new_stmts - i += len(new_stmts) + statements[index : index + 1] = new_statements + index += len(new_statements) self.set_changed() continue # Split SequenceExpression in single declarator init - if len(decls) == 1: - split_result = self._try_split_single_declarator_init(statement, decls[0]) + if len(declarations) == 1: + split_result = self._try_split_single_declarator_init(statement, declarations[0]) if split_result: - statements[i : i + 1] = split_result - i += len(split_result) + statements[index : index + 1] = split_result + index += len(split_result) self.set_changed() continue - i += 1 + index += 1 @staticmethod - def _try_split_single_declarator_init(stmt, declarator): + def _try_split_single_declarator_init(statement: dict, declarator: dict) -> list | None: """Split SequenceExpression from a single VariableDeclarator init. Handles both direct sequences and sequences inside AwaitExpression. @@ -190,12 +188,12 @@ def _try_split_single_declarator_init(stmt, declarator): # Direct: const x = (a, b, expr()) → a; b; const x = expr(); if init.get('type') == 'SequenceExpression': - exprs = init.get('expressions', []) - if len(exprs) <= 1: + expressions = init.get('expressions', []) + if len(expressions) <= 1: return None - prefix = [make_expression_statement(e) for e in exprs[:-1]] - declarator['init'] = exprs[-1] - prefix.append(stmt) + prefix = [make_expression_statement(expression) for expression in expressions[:-1]] + declarator['init'] = expressions[-1] + prefix.append(statement) return prefix # Await-wrapped: var x = await (a, b, expr()) → a; b; var x = await expr(); @@ -204,12 +202,12 @@ def _try_split_single_declarator_init(stmt, declarator): and isinstance(init.get('argument'), dict) and init['argument'].get('type') == 'SequenceExpression' ): - exprs = init['argument'].get('expressions', []) - if len(exprs) <= 1: + expressions = init['argument'].get('expressions', []) + if len(expressions) <= 1: return None - prefix = [make_expression_statement(e) for e in exprs[:-1]] - init['argument'] = exprs[-1] - prefix.append(stmt) + prefix = [make_expression_statement(expression) for expression in expressions[:-1]] + init['argument'] = expressions[-1] + prefix.append(statement) return prefix return None diff --git a/pyjsclear/transforms/single_use_vars.py b/pyjsclear/transforms/single_use_vars.py index d8078d4..7d28b17 100644 --- a/pyjsclear/transforms/single_use_vars.py +++ b/pyjsclear/transforms/single_use_vars.py @@ -3,11 +3,11 @@ Targets patterns like: const _0x337161 = require("process"); return _0x337161.env.LOCALAPPDATA; -→ return require("process").env.LOCALAPPDATA; + return require("process").env.LOCALAPPDATA; const _0x27439f = Buffer.from(_0x162d6f); return _0x27439f.toString(); -→ return Buffer.from(_0x162d6f).toString(); + return Buffer.from(_0x162d6f).toString(); Only inlines when: - The variable is constant (no reassignments) @@ -15,6 +15,10 @@ - The init expression is not too large (≤ 15 AST nodes) """ +from __future__ import annotations + +from typing import TYPE_CHECKING + from ..scope import build_scope_tree from ..traverser import REMOVE from ..traverser import find_parent @@ -24,15 +28,18 @@ from ..utils.ast_helpers import is_identifier from .base import Transform +if TYPE_CHECKING: + from ..scope import Scope + -def _count_nodes(node): +def _count_nodes(node: dict) -> int: """Count AST nodes in a subtree.""" count = [0] - def cb(n, parent): + def increment_count(_node: dict, parent: dict | None) -> None: count[0] += 1 - simple_traverse(node, cb) + simple_traverse(node, increment_count) return count[0] @@ -45,7 +52,7 @@ class SingleUseVarInliner(Transform): # Keeps inlined expressions readable; avoids ballooning line length. _MAX_INIT_NODES = 15 - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) inlined = self._process_scope(scope_tree) if not inlined: @@ -53,7 +60,7 @@ def execute(self): self._remove_declarators(inlined) return self.has_changed() - def _process_scope(self, scope): + def _process_scope(self, scope: Scope) -> list[dict]: """Find and inline single-use constant bindings.""" inlined_declarators = [] @@ -115,7 +122,7 @@ def _process_scope(self, scope): return inlined_declarators - def _is_mutated_member_object(self, ref_parent, ref_key): + def _is_mutated_member_object(self, ref_parent: dict | None, ref_key: str | None) -> bool: """Check if ref is the object of a member expression that is an assignment target. Catches: obj[x] = val, obj.x = val, obj[x]++, etc. @@ -129,27 +136,27 @@ def _is_mutated_member_object(self, ref_parent, ref_key): parent_info = find_parent(self.ast, ref_parent) if not parent_info: return False - grandparent, gp_key, _ = parent_info - if grandparent.get('type') == 'AssignmentExpression' and gp_key == 'left': + grandparent, grandparent_key, _ = parent_info + if grandparent.get('type') == 'AssignmentExpression' and grandparent_key == 'left': return True if grandparent.get('type') == 'UpdateExpression': return True return False - def _remove_declarators(self, declarator_nodes): + def _remove_declarators(self, declarator_nodes: list[dict]) -> None: """Remove inlined VariableDeclarators from their parent declarations.""" - declarator_ids = {id(d) for d in declarator_nodes} + declarator_ids = {id(declarator) for declarator in declarator_nodes} - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: if node.get('type') != 'VariableDeclaration': return - decls = node.get('declarations', []) - original_len = len(decls) - decls[:] = [d for d in decls if id(d) not in declarator_ids] - if len(decls) == original_len: + declarations = node.get('declarations', []) + original_length = len(declarations) + declarations[:] = [declarator for declarator in declarations if id(declarator) not in declarator_ids] + if len(declarations) == original_length: return # No match — continue traversing children self.set_changed() - if not decls: + if not declarations: return REMOVE traverse(self.ast, {'enter': enter}) diff --git a/pyjsclear/transforms/string_revealer.py b/pyjsclear/transforms/string_revealer.py index 3070f8f..7d9c99b 100644 --- a/pyjsclear/transforms/string_revealer.py +++ b/pyjsclear/transforms/string_revealer.py @@ -2,6 +2,7 @@ import math import re +from typing import Any from ..generator import generate from ..scope import build_scope_tree @@ -24,14 +25,14 @@ _RC4_REGEX = re.compile(r"""fromCharCode.{0,30}\^""") -def _eval_numeric(node): +def _eval_numeric(node: Any) -> int | float | None: """Evaluate an AST node to a numeric value if it's a constant expression.""" if not isinstance(node, dict): return None match node.get('type', ''): case 'Literal': - val = node.get('value') - return val if isinstance(val, (int, float)) else None + value = node.get('value') + return value if isinstance(value, (int, float)) else None case 'UnaryExpression': arg = _eval_numeric(node.get('argument')) if arg is None: @@ -51,18 +52,18 @@ def _eval_numeric(node): return None -def _js_parse_int(s): +def _js_parse_int(string: str) -> float: """Mimic JavaScript's parseInt: extract leading integer from string.""" - if not isinstance(s, str): + if not isinstance(string, str): return float('nan') - s = s.strip() - m = re.match(r'^[+-]?\d+', s) - if m: - return int(m.group()) + string = string.strip() + match = re.match(r'^[+-]?\d+', string) + if match: + return int(match.group()) return float('nan') -def _apply_arith(operator, left, right): +def _apply_arith(operator: str, left: int | float, right: int | float) -> int | float | None: """Apply a binary arithmetic operator.""" match operator: case '+': @@ -79,7 +80,7 @@ def _apply_arith(operator, left, right): return None -def _collect_object_literals(ast): +def _collect_object_literals(ast: dict) -> dict[tuple[str, str], int | float | str]: """Collect simple object literal assignments: var o = {a: 0x1b1, b: 'str'}. Returns a dict mapping (object_name, property_name) -> value (int or str). @@ -117,7 +118,7 @@ def visitor(node, parent): return result -def _resolve_arg_value(arg, object_literals): +def _resolve_arg_value(arg: dict, object_literals: dict) -> int | float | None: """Try to resolve a call argument to a numeric value. Handles numeric literals, string hex literals, and member expressions @@ -150,7 +151,7 @@ def _resolve_arg_value(arg, object_literals): return None -def _resolve_string_arg(arg, object_literals): +def _resolve_string_arg(arg: dict, object_literals: dict) -> str | None: """Try to resolve a call argument to a string value. Handles string literals and member expressions referencing known object properties. @@ -170,14 +171,21 @@ def _resolve_string_arg(arg, object_literals): class WrapperInfo: """Info about a wrapper function that calls the decoder.""" - def __init__(self, name, param_index, wrapper_offset, func_node, key_param_index=None): + def __init__( + self, + name: str, + param_index: int, + wrapper_offset: int, + func_node: dict, + key_param_index: int | None = None, + ) -> None: self.name = name self.param_index = param_index self.wrapper_offset = wrapper_offset self.func_node = func_node self.key_param_index = key_param_index - def get_effective_index(self, call_args): + def get_effective_index(self, call_args: list) -> int | None: """Given call argument values, compute the effective decoder index.""" if self.param_index >= len(call_args): return None @@ -186,7 +194,7 @@ def get_effective_index(self, call_args): return None return int(value) + self.wrapper_offset - def get_key(self, call_args): + def get_key(self, call_args: list) -> str | None: """Get the RC4 key argument if applicable.""" if self.key_param_index is not None and self.key_param_index < len(call_args): return call_args[self.key_param_index] @@ -199,7 +207,7 @@ class StringRevealer(Transform): rebuild_scope = True _rotation_locals = {} - def execute(self): + def execute(self) -> bool: scope_tree, node_scope = build_scope_tree(self.ast) # Strategy 1: Direct string array declarations (var arr = ["a","b","c"]) @@ -220,7 +228,7 @@ def execute(self): # Strategy 2: Obfuscator.io pattern # ================================================================ - def _process_obfuscatorio_pattern(self): + def _process_obfuscatorio_pattern(self) -> None: """Handle obfuscator.io: array func -> decoder func(s) -> wrapper funcs -> rotation.""" body = self.ast.get('body', []) @@ -312,7 +320,7 @@ def _process_obfuscatorio_pattern(self): indices_to_remove.add(array_func_idx) self._remove_body_indices(body, *indices_to_remove) - def _find_string_array_function(self, body): + def _find_string_array_function(self, body: list) -> tuple[str | None, list | None, int | None]: """Find the string array function declaration. Pattern: function X() { var a = ['s1','s2',...]; X = function(){return a;}; return X(); } @@ -334,7 +342,7 @@ def _find_string_array_function(self, body): return None, None, None @staticmethod - def _string_array_from_expression(node): + def _string_array_from_expression(node: dict | None) -> list[str] | None: """Return list of string values if node is an ArrayExpression of all string literals.""" if not node or node.get('type') != 'ArrayExpression': return None @@ -343,7 +351,7 @@ def _string_array_from_expression(node): return None return [e['value'] for e in elements] - def _extract_array_from_statement(self, stmt): + def _extract_array_from_statement(self, stmt: dict) -> list[str] | None: """Extract string array from a variable declaration or assignment.""" if stmt.get('type') == 'VariableDeclaration': for declaration in stmt.get('declarations', []): @@ -356,7 +364,7 @@ def _extract_array_from_statement(self, stmt): return self._string_array_from_expression(expr.get('right')) return None - def _find_all_decoder_functions(self, body, array_func_name): + def _find_all_decoder_functions(self, body: list, array_func_name: str) -> list[tuple[str, int, int, DecoderType]]: """Find all decoder functions that call the array function. Returns list of (func_name, offset, body_index, decoder_type) tuples. @@ -384,7 +392,7 @@ def _find_all_decoder_functions(self, body, array_func_name): return results - def _function_calls(self, func_node, callee_name): + def _function_calls(self, func_node: dict, callee_name: str) -> bool: """Check if a function body contains a call to callee_name.""" found = [False] @@ -401,7 +409,7 @@ def visitor(node, parent): simple_traverse(func_node, visitor) return found[0] - def _extract_decoder_offset(self, func_node): + def _extract_decoder_offset(self, func_node: dict) -> int: """Extract offset from decoder's inner param = param OP EXPR pattern.""" found_offset = [None] @@ -430,7 +438,7 @@ def find_offset(node, parent): simple_traverse(func_node, find_offset) return found_offset[0] if found_offset[0] is not None else 0 - def _create_base_decoder(self, string_array, offset, dtype): + def _create_base_decoder(self, string_array: list[str], offset: int, dtype: DecoderType) -> BasicStringDecoder | Base64StringDecoder | Rc4StringDecoder: """Create the appropriate decoder instance.""" match dtype: case DecoderType.RC4: @@ -440,7 +448,7 @@ def _create_base_decoder(self, string_array, offset, dtype): case _: return BasicStringDecoder(string_array, offset) - def _find_all_wrappers(self, decoder_name): + def _find_all_wrappers(self, decoder_name: str) -> dict[str, 'WrapperInfo']: """Find all wrapper functions throughout the AST that call the decoder. Pattern: function W(p0,..,pN) { return DECODER(p_i OP OFFSET, p_j); } @@ -467,14 +475,14 @@ def visitor(node, parent): simple_traverse(self.ast, visitor) return wrappers - def _analyze_wrapper(self, func_node, decoder_name): + def _analyze_wrapper(self, func_node: dict, decoder_name: str) -> 'WrapperInfo | None': """Check if a FunctionDeclaration is a wrapper. Returns WrapperInfo or None.""" func_name = func_node.get('id', {}).get('name') if not func_name: return None return self._analyze_wrapper_expr(func_name, func_node, decoder_name) - def _analyze_wrapper_expr(self, func_name, func_node, decoder_name): + def _analyze_wrapper_expr(self, func_name: str, func_node: dict, decoder_name: str) -> 'WrapperInfo | None': """Analyze a function node (declaration or expression) as a potential wrapper.""" func_body = func_node.get('body', {}) if func_body.get('type') == 'BlockStatement': @@ -516,7 +524,7 @@ def _analyze_wrapper_expr(self, func_name, func_node, decoder_name): return WrapperInfo(func_name, param_index, wrapper_offset, func_node, key_param_index) - def _extract_wrapper_offset(self, expr, param_names): + def _extract_wrapper_offset(self, expr: dict, param_names: list[str]) -> tuple[int | None, int | None]: """Extract (param_index, offset) from wrapper's first argument to decoder. Handles: p_N, p_N + LIT, p_N - LIT, p_N - -LIT, p_N + -LIT @@ -542,7 +550,7 @@ def _extract_wrapper_offset(self, expr, param_names): offset = int(-right_value) if operator == '-' else int(right_value) return param_idx, offset - def _remove_decoder_aliases(self, decoder_name, aliases): + def _remove_decoder_aliases(self, decoder_name: str, aliases: set[str]) -> None: """Remove variable declarations that are aliases for the decoder. Removes: const _0xABC = _0x22e6; and transitive: const _0xDEF = _0xABC; @@ -577,7 +585,7 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) - def _find_decoder_aliases(self, decoder_name): + def _find_decoder_aliases(self, decoder_name: str) -> set[str]: """Find all variable declarations that are aliases for the decoder. Handles transitive aliases: const a = decoder; const b = a; const c = b; @@ -614,15 +622,15 @@ def visitor(node, parent): def _find_and_execute_rotation( self, - body, - array_func_name, - string_array, - decoder, - wrappers, - decoder_aliases=None, - alias_decoder_map=None, - all_decoders=None, - ): + body: list, + array_func_name: str, + string_array: list[str], + decoder: Any, + wrappers: dict, + decoder_aliases: set[str] | None = None, + alias_decoder_map: dict | None = None, + all_decoders: dict | None = None, + ) -> tuple[int, Any] | None: """Find rotation IIFE and execute it. Returns (body_index, rotation_call_expr_or_none) on success, or None. @@ -669,15 +677,15 @@ def _find_and_execute_rotation( def _try_execute_rotation_call( self, - call_expr, - array_func_name, - string_array, - decoder, - wrappers, - decoder_aliases, - alias_decoder_map=None, - all_decoders=None, - ): + call_expr: dict, + array_func_name: str, + string_array: list[str], + decoder: Any, + wrappers: dict, + decoder_aliases: set[str] | None, + alias_decoder_map: dict | None = None, + all_decoders: dict | None = None, + ) -> bool: """Try to parse and execute a single rotation call expression. Returns True on success.""" callee = call_expr.get('callee') args = call_expr.get('arguments', []) @@ -716,7 +724,7 @@ def _try_execute_rotation_call( return True @staticmethod - def _collect_rotation_locals(iife_func): + def _collect_rotation_locals(iife_func: dict) -> dict[str, dict]: """Collect local object literal assignments from the rotation IIFE. Returns dict: var_name -> {prop_name: value}. @@ -739,21 +747,21 @@ def _collect_rotation_locals(iife_func): if not key or not value: continue if is_identifier(key): - k = key['name'] + prop_name = key['name'] elif is_string_literal(key): - k = key['value'] + prop_name = key['value'] else: continue num = _eval_numeric(value) if num is not None: - obj[k] = int(num) + obj[prop_name] = int(num) elif is_string_literal(value): - obj[k] = value['value'] + obj[prop_name] = value['value'] if obj: result[name_node['name']] = obj return result - def _extract_rotation_expression(self, iife_func): + def _extract_rotation_expression(self, iife_func: dict) -> dict | None: """Extract the arithmetic expression from the try block in the rotation loop.""" func_body = iife_func.get('body', {}).get('body', []) if not func_body: @@ -782,7 +790,7 @@ def _extract_rotation_expression(self, iife_func): return None @staticmethod - def _expression_from_try_block(first_statement): + def _expression_from_try_block(first_statement: dict) -> dict | None: """Extract the init/rhs expression from the first statement in a try block.""" if first_statement.get('type') == 'VariableDeclaration': decls = first_statement.get('declarations', []) @@ -793,7 +801,7 @@ def _expression_from_try_block(first_statement): return expr.get('right') return None - def _parse_rotation_op(self, expr, wrappers, decoder_aliases=None): + def _parse_rotation_op(self, expr: dict, wrappers: dict, decoder_aliases: set[str] | None = None) -> dict | None: """Parse a rotation expression into an operation tree.""" if not isinstance(expr, dict): return None @@ -830,7 +838,7 @@ def _parse_rotation_op(self, expr, wrappers, decoder_aliases=None): return None - def _parse_parseInt_call(self, expr, wrappers, aliases): + def _parse_parseInt_call(self, expr: dict, wrappers: dict, aliases: set[str]) -> dict | None: """Parse parseInt(wrapperOrDecoder(...)) into an operation node.""" callee = expr.get('callee') args = expr.get('arguments', []) @@ -856,20 +864,20 @@ def _parse_parseInt_call(self, expr, wrappers, aliases): return {'op': 'direct_decoder_call', 'alias_name': cname, 'args': arg_values} return None - def _resolve_rotation_arg(self, arg): + def _resolve_rotation_arg(self, arg: dict) -> int | str | None: """Resolve a rotation call argument to a numeric or string value. Handles literals, string hex, and MemberExpression referencing local objects. """ - val = _eval_numeric(arg) - if val is not None: - return int(val) + numeric_value = _eval_numeric(arg) + if numeric_value is not None: + return int(numeric_value) if is_string_literal(arg): - s = arg['value'] + string_value = arg['value'] try: - return int(s, 16) if s.startswith('0x') else int(s) + return int(string_value, 16) if string_value.startswith('0x') else int(string_value) except (ValueError, TypeError): - return s + return string_value # MemberExpression: J.A or J['A'] if arg.get('type') == 'MemberExpression': obj = arg.get('object') @@ -882,7 +890,7 @@ def _resolve_rotation_arg(self, arg): return local_obj.get(prop['value']) return None - def _decode_and_parse_int(self, decoder, idx, key=None): + def _decode_and_parse_int(self, decoder: Any, idx: int | float, key: str | None = None) -> float: """Decode a string and parse it as an integer. Raises on failure.""" decoded = decoder.get_string(int(idx), key) if key is not None else decoder.get_string(int(idx)) if decoded is None: @@ -892,7 +900,7 @@ def _decode_and_parse_int(self, decoder, idx, key=None): raise ValueError('NaN from parseInt') return result - def _apply_rotation_op(self, operation, wrappers, decoder, alias_decoder_map=None): + def _apply_rotation_op(self, operation: dict, wrappers: dict, decoder: Any, alias_decoder_map: dict | None = None) -> int | float: """Evaluate a parsed rotation operation tree.""" match operation['op']: case 'literal': @@ -924,7 +932,7 @@ def _apply_rotation_op(self, operation, wrappers, decoder, alias_decoder_map=Non case _: raise ValueError(f'Unknown op: {operation["op"]}') - def _execute_rotation(self, string_array, operation, wrappers, decoder, stop_value, alias_decoder_map=None): + def _execute_rotation(self, string_array: list[str], operation: dict, wrappers: dict, decoder: Any, stop_value: int, alias_decoder_map: dict | None = None) -> bool: """Rotate array until the expression evaluates to stop_value.""" # Collect all decoders that need cache clearing on each rotation all_decoders = set() @@ -941,14 +949,14 @@ def _execute_rotation(self, string_array, operation, wrappers, decoder, stop_val except Exception: string_array.append(string_array.pop(0)) # Clear decoder caches after rotation since array contents shifted - for d in all_decoders: - if hasattr(d, '_cache'): - d._cache.clear() + for decoder_instance in all_decoders: + if hasattr(decoder_instance, '_cache'): + decoder_instance._cache.clear() return False # ---- Replacement ---- - def _replace_all_wrapper_calls(self, wrappers, decoder, obj_literals=None): + def _replace_all_wrapper_calls(self, wrappers: dict, decoder: Any, obj_literals: dict | None = None) -> bool: """Replace all calls to wrapper functions with decoded string literals.""" if not wrappers: return True @@ -999,7 +1007,7 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) return all_replaced[0] - def _replace_direct_decoder_calls(self, decoder_name, decoder, decoder_aliases=None, obj_literals=None): + def _replace_direct_decoder_calls(self, decoder_name: str, decoder: Any, decoder_aliases: set[str] | None = None, obj_literals: dict | None = None) -> None: """Replace direct calls to the decoder function (and its aliases) with literals.""" names = {decoder_name} if decoder_aliases: @@ -1037,7 +1045,7 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) @staticmethod - def _find_array_expression_in_statement(stmt): + def _find_array_expression_in_statement(stmt: dict) -> dict | None: """Find the first ArrayExpression node in a variable declaration or assignment.""" if stmt.get('type') == 'VariableDeclaration': for declaration in stmt.get('declarations', []): @@ -1052,7 +1060,7 @@ def _find_array_expression_in_statement(stmt): return right return None - def _update_ast_array(self, func_node, rotated_array): + def _update_ast_array(self, func_node: dict, rotated_array: list[str]) -> None: """Update the AST's array function to contain the rotated string array.""" func_body = func_node.get('body', {}).get('body', []) if not func_body: @@ -1061,7 +1069,7 @@ def _update_ast_array(self, func_node, rotated_array): if arr_expr is not None: arr_expr['elements'] = [make_literal(s) for s in rotated_array] - def _remove_body_indices(self, body, *indices): + def _remove_body_indices(self, body: list, *indices: int | None) -> None: """Remove statements at given indices from body.""" for idx in sorted(set(i for i in indices if i is not None), reverse=True): if 0 <= idx < len(body): @@ -1076,7 +1084,7 @@ def _remove_body_indices(self, body, *indices): # var _0xDEC = function(a, b) { a = a - OFFSET; var x = _0xARR[a]; return x; }; # ================================================================ - def _process_var_array_pattern(self): + def _process_var_array_pattern(self) -> None: """Handle var-based string array with simple rotation and decoder.""" body = self.ast.get('body', []) if len(body) < 3: @@ -1122,7 +1130,7 @@ def _process_var_array_pattern(self): indices_to_remove.add(rotation_idx) self._remove_body_indices(body, *indices_to_remove) - def _find_var_string_array(self, body): + def _find_var_string_array(self, body: list) -> tuple[str | None, list[str] | None, int | None]: """Find var _0x... = ['s1', 's2', ...] at top of body.""" for i, stmt in enumerate(body[:3]): if stmt.get('type') != 'VariableDeclaration': @@ -1142,7 +1150,7 @@ def _find_var_string_array(self, body): return name_node['name'], [e['value'] for e in elements], i return None, None, None - def _find_simple_rotation(self, body, array_name): + def _find_simple_rotation(self, body: list, array_name: str) -> tuple[int | None, int | None]: """Find (function(arr, count) { ...push/shift... })(array, N) rotation IIFE.""" for i, stmt in enumerate(body): if stmt.get('type') != 'ExpressionStatement': @@ -1177,7 +1185,7 @@ def _find_simple_rotation(self, body, array_name): return None, None - def _find_var_decoder(self, body, array_name): + def _find_var_decoder(self, body: list, array_name: str) -> tuple[str | None, int | None, int | None]: """Find var _0xDEC = function(a) { a = a - OFFSET; var x = ARR[a]; return x; }.""" for i, stmt in enumerate(body): if stmt.get('type') != 'VariableDeclaration': @@ -1200,7 +1208,7 @@ def _find_var_decoder(self, body, array_name): # Strategy 1: Direct string array declarations # ================================================================ - def _try_replace_array_access(self, ref_parent, ref_key, string_array): + def _try_replace_array_access(self, ref_parent: dict | None, ref_key: str, string_array: list[str]) -> None: """Replace arr[N] member expression with the string literal if valid.""" if not ref_parent or ref_parent.get('type') != 'MemberExpression': return @@ -1215,7 +1223,7 @@ def _try_replace_array_access(self, ref_parent, ref_key, string_array): self._replace_node_in_ast(ref_parent, make_literal(string_array[idx])) self.set_changed() - def _process_direct_arrays(self, scope_tree): + def _process_direct_arrays(self, scope_tree: Any) -> None: """Find direct array declarations and replace indexed accesses.""" for name, binding in list(scope_tree.bindings.items()): node = binding.node @@ -1234,7 +1242,7 @@ def _process_direct_arrays(self, scope_tree): for child in scope_tree.children: self._process_direct_arrays_in_scope(child, name, string_array) - def _process_direct_arrays_in_scope(self, scope, name, string_array): + def _process_direct_arrays_in_scope(self, scope: Any, name: str, string_array: list[str]) -> None: """Process direct array accesses in child scopes.""" binding = scope.get_binding(name) if not binding: @@ -1242,7 +1250,7 @@ def _process_direct_arrays_in_scope(self, scope, name, string_array): for reference_node, reference_parent, reference_key, ref_index in binding.references[:]: self._try_replace_array_access(reference_parent, reference_key, string_array) - def _replace_node_in_ast(self, target, replacement): + def _replace_node_in_ast(self, target: dict, replacement: dict) -> None: """Replace a node in the AST with a replacement.""" result = find_parent(self.ast, target) if result: @@ -1256,6 +1264,6 @@ def _replace_node_in_ast(self, target, replacement): # Strategy 3: Simple static array unpacking # ================================================================ - def _process_static_arrays(self): + def _process_static_arrays(self) -> None: """No-op: static array unpacking is handled by _process_direct_arrays.""" pass diff --git a/pyjsclear/transforms/unreachable_code.py b/pyjsclear/transforms/unreachable_code.py index 9cd32d4..cb5e4a9 100644 --- a/pyjsclear/transforms/unreachable_code.py +++ b/pyjsclear/transforms/unreachable_code.py @@ -11,14 +11,14 @@ class UnreachableCodeRemover(Transform): """Remove statements that follow a terminator (return/throw/break/continue) in a block.""" - def execute(self): - def enter(node, parent, key, index): - t = node.get('type') - if t in ('BlockStatement', 'Program'): + def execute(self) -> bool: + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: + node_type = node.get('type') + if node_type in ('BlockStatement', 'Program'): body = node.get('body') if body and isinstance(body, list): self._truncate_after_terminator(body, node, 'body') - elif t == 'SwitchCase': + elif node_type == 'SwitchCase': consequent = node.get('consequent') if consequent and isinstance(consequent, list): self._truncate_after_terminator(consequent, node, 'consequent') @@ -26,12 +26,13 @@ def enter(node, parent, key, index): traverse(self.ast, {'enter': enter}) return self.has_changed() - def _truncate_after_terminator(self, stmts, node, key): - for i, stmt in enumerate(stmts): - if not isinstance(stmt, dict): + def _truncate_after_terminator(self, statements: list, node: dict, key: str) -> None: + for statement_index, statement in enumerate(statements): + if not isinstance(statement, dict): continue - if stmt.get('type') in _TERMINATORS: - if i + 1 < len(stmts): - self.set_changed() - node[key] = stmts[: i + 1] - return + if statement.get('type') not in _TERMINATORS: + continue + if statement_index + 1 < len(statements): + self.set_changed() + node[key] = statements[: statement_index + 1] + return diff --git a/pyjsclear/transforms/unused_vars.py b/pyjsclear/transforms/unused_vars.py index 64ebbba..830ca89 100644 --- a/pyjsclear/transforms/unused_vars.py +++ b/pyjsclear/transforms/unused_vars.py @@ -33,17 +33,17 @@ class UnusedVariableRemover(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) - declarators_to_remove = set() - functions_to_remove = set() + declarators_to_remove: set[int] = set() + functions_to_remove: set[int] = set() self._collect_unused(scope_tree, declarators_to_remove, functions_to_remove) if not declarators_to_remove and not functions_to_remove: return False self._batch_remove(declarators_to_remove, functions_to_remove) return self.has_changed() - def _collect_unused(self, scope, declarators, functions): + def _collect_unused(self, scope: object, declarators: set[int], functions: set[int]) -> None: skip_global = scope.parent is None for name, binding in scope.bindings.items(): @@ -66,30 +66,31 @@ def _collect_unused(self, scope, declarators, functions): for child in scope.children: self._collect_unused(child, declarators, functions) - def _batch_remove(self, declarators_to_remove, functions_to_remove): + def _batch_remove(self, declarators_to_remove: set[int], functions_to_remove: set[int]) -> None: """Remove all collected unused declarations in a single traversal.""" - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object: node_type = node.get('type') if node_type == 'FunctionDeclaration' and id(node) in functions_to_remove: self.set_changed() return REMOVE if node_type != 'VariableDeclaration': - return - decls = node.get('declarations') - if not decls: - return - new_decls = [d for d in decls if id(d) not in declarators_to_remove] - if len(new_decls) == len(decls): - return + return None + declarations = node.get('declarations') + if not declarations: + return None + filtered_declarations = [declarator for declarator in declarations if id(declarator) not in declarators_to_remove] + if len(filtered_declarations) == len(declarations): + return None self.set_changed() - if not new_decls: + if not filtered_declarations: return REMOVE - node['declarations'] = new_decls + node['declarations'] = filtered_declarations + return None traverse(self.ast, {'enter': enter}) - def _has_side_effects(self, node): + def _has_side_effects(self, node: dict) -> bool: """Conservative check for side effects in an expression.""" if not isinstance(node, dict): return False @@ -98,8 +99,7 @@ def _has_side_effects(self, node): return True if node_type in ('Literal', 'Identifier', 'ThisExpression', 'FunctionExpression', 'ArrowFunctionExpression'): return False - # For all other types (including ArrayExpression, ObjectExpression, - # BinaryExpression, etc.), recurse into children to check for side effects + # Recurse into children (handles ArrayExpression, ObjectExpression, BinaryExpression, etc.) for key in get_child_keys(node): child = node.get(key) if child is None: diff --git a/pyjsclear/transforms/variable_renamer.py b/pyjsclear/transforms/variable_renamer.py index a3ad71d..27584ff 100644 --- a/pyjsclear/transforms/variable_renamer.py +++ b/pyjsclear/transforms/variable_renamer.py @@ -135,38 +135,59 @@ 'FormData': 'form', } +# fs-like methods +_FS_METHODS = { + 'readFileSync', + 'writeFileSync', + 'existsSync', + 'mkdirSync', + 'statSync', + 'readdirSync', + 'unlinkSync', + 'createWriteStream', + 'createReadStream', + 'readFile', + 'writeFile', + 'appendFileSync', +} + +# path-like methods +_PATH_METHODS = {'join', 'resolve', 'basename', 'dirname', 'extname', 'normalize'} + +_ALPHABET = 'abcdefghijklmnopqrstuvwxyz' + -def _name_generator(reserved): +def _name_generator(reserved: set) -> object: """Yield short identifier names, skipping reserved and taken names.""" - for c in 'abcdefghijklmnopqrstuvwxyz': - if c not in reserved: - yield c - for c1 in 'abcdefghijklmnopqrstuvwxyz': - for c2 in 'abcdefghijklmnopqrstuvwxyz': - name = c1 + c2 + for char in _ALPHABET: + if char not in reserved: + yield char + for first_char in _ALPHABET: + for second_char in _ALPHABET: + name = first_char + second_char if name not in reserved: yield name - for c1 in 'abcdefghijklmnopqrstuvwxyz': - for c2 in 'abcdefghijklmnopqrstuvwxyz': - for c3 in 'abcdefghijklmnopqrstuvwxyz': - name = c1 + c2 + c3 + for first_char in _ALPHABET: + for second_char in _ALPHABET: + for third_char in _ALPHABET: + name = first_char + second_char + third_char if name not in reserved: yield name -def _dedupe_name(base, reserved): +def _dedupe_name(base: str, reserved: set) -> str: """Return base or base2, base3, ... until a non-reserved name is found.""" if base not in reserved: return base - n = 2 + counter = 2 while True: - candidate = f'{base}{n}' + candidate = f'{base}{counter}' if candidate not in reserved: return candidate - n += 1 + counter += 1 -def _infer_from_init(init): +def _infer_from_init(init: dict | None) -> str | None: """Infer a variable name from its initializer expression.""" if not isinstance(init, dict) or 'type' not in init: return None @@ -180,11 +201,11 @@ def _infer_from_init(init): if is_identifier(callee) and callee.get('name') == 'require' and len(args) == 1: arg = args[0] if arg.get('type') == 'Literal' and isinstance(arg.get('value'), str): - mod = arg['value'] - if mod in _REQUIRE_NAMES: - return _REQUIRE_NAMES[mod] + module_name = arg['value'] + if module_name in _REQUIRE_NAMES: + return _REQUIRE_NAMES[module_name] # Derive name from module path, sanitized to valid identifier - base = mod.split('/')[-1].split('\\')[-1] + base = module_name.split('/')[-1].split('\\')[-1] base = base.split('.')[0] # strip file extension base = re.sub(r'[^a-zA-Z0-9_]', '_', base) base = re.sub(r'^[0-9]+', '', base) # can't start with digit @@ -200,20 +221,19 @@ def _infer_from_init(init): if is_identifier(obj) and is_identifier(prop): obj_name = obj.get('name') prop_name = prop.get('name') - if obj_name == 'Buffer' and prop_name == 'from': - return 'buf' - if obj_name == 'JSON' and prop_name == 'parse': - return 'data' - if obj_name == 'JSON' and prop_name == 'stringify': - return 'json' - if obj_name == 'Object' and prop_name == 'keys': - return 'keys' - if obj_name == 'Object' and prop_name == 'values': - return 'values' - if obj_name == 'Object' and prop_name == 'entries': - return 'entries' - if obj_name == 'Object' and prop_name == 'getOwnPropertyNames': - return 'keys' + match (obj_name, prop_name): + case ('Buffer', 'from'): + return 'buf' + case ('JSON', 'parse'): + return 'data' + case ('JSON', 'stringify'): + return 'json' + case ('Object', 'keys') | ('Object', 'getOwnPropertyNames'): + return 'keys' + case ('Object', 'values'): + return 'values' + case ('Object', 'entries'): + return 'entries' # new Date() → "date" if init_type == 'NewExpression': @@ -226,34 +246,28 @@ def _infer_from_init(init): if is_identifier(prop): return _CONSTRUCTOR_NAMES.get(prop.get('name')) - # [] → "arr" - if init_type == 'ArrayExpression': - return 'arr' - - # {} → "obj" - if init_type == 'ObjectExpression': - return 'obj' - - # "string" → "str" - if init_type == 'Literal': - val = init.get('value') - if isinstance(val, str): - return 'str' - if isinstance(val, bool): - return 'flag' - - # await expr → infer from the inner expression - if init_type == 'AwaitExpression': - return _infer_from_init(init.get('argument')) + match init_type: + case 'ArrayExpression': + return 'arr' + case 'ObjectExpression': + return 'obj' + case 'Literal': + value = init.get('value') + if isinstance(value, str): + return 'str' + if isinstance(value, bool): + return 'flag' + case 'AwaitExpression': + return _infer_from_init(init.get('argument')) return None -def _infer_from_usage(binding): +def _infer_from_usage(binding: object) -> str | None: """Infer a variable name from how it's used at reference sites.""" # Check what methods are called on this variable methods = set() - for ref_node, ref_parent, ref_key, ref_index in binding.references: + for ref_node, ref_parent, ref_key, _ref_index in binding.references: if not ref_parent: continue # x.method() — ref is object of MemberExpression @@ -262,26 +276,9 @@ def _infer_from_usage(binding): if is_identifier(prop) and not ref_parent.get('computed'): methods.add(prop.get('name')) - # fs-like methods - _FS_METHODS = { - 'readFileSync', - 'writeFileSync', - 'existsSync', - 'mkdirSync', - 'statSync', - 'readdirSync', - 'unlinkSync', - 'createWriteStream', - 'createReadStream', - 'readFile', - 'writeFile', - 'appendFileSync', - } if methods & _FS_METHODS: return 'fs' - # path-like methods - _PATH_METHODS = {'join', 'resolve', 'basename', 'dirname', 'extname', 'normalize'} if methods & _PATH_METHODS and not (methods - _PATH_METHODS - {'sep'}): return 'path' @@ -308,7 +305,7 @@ def _infer_from_usage(binding): return None -def _infer_loop_var(binding): +def _infer_loop_var(binding: object) -> bool | None: """Check if this binding is a for-loop counter.""" node = binding.node if not isinstance(node, dict): @@ -319,40 +316,40 @@ def _infer_loop_var(binding): init = node.get('init') if not init or init.get('type') != 'Literal': return None - val = init.get('value') - if not isinstance(val, (int, float)): + value = init.get('value') + if not isinstance(value, (int, float)): return None # Check if any assignment is an UpdateExpression (i++, i--) if binding.assignments: - for assign in binding.assignments: - if isinstance(assign, dict) and assign.get('type') == 'UpdateExpression': + for assignment in binding.assignments: + if isinstance(assignment, dict) and assignment.get('type') == 'UpdateExpression': return True # Also check references for UpdateExpression parents - for ref_node, ref_parent, ref_key, ref_index in binding.references: + for _ref_node, ref_parent, _ref_key, _ref_index in binding.references: if ref_parent and ref_parent.get('type') == 'UpdateExpression': return True return None -def _collect_pattern_idents(pattern, result): +def _collect_pattern_idents(pattern: dict | None, result: list) -> None: """Collect all Identifier nodes from a destructuring pattern.""" if not isinstance(pattern, dict): return - pat_type = pattern.get('type') - if pat_type == 'Identifier': + pattern_type = pattern.get('type') + if pattern_type == 'Identifier': result.append(pattern) - elif pat_type == 'ArrayPattern': - for elem in pattern.get('elements', []): - if elem: - _collect_pattern_idents(elem, result) - elif pat_type == 'ObjectPattern': + elif pattern_type == 'ArrayPattern': + for element in pattern.get('elements', []): + if element: + _collect_pattern_idents(element, result) + elif pattern_type == 'ObjectPattern': for prop in pattern.get('properties', []): - val = prop.get('value', prop.get('argument')) - if val: - _collect_pattern_idents(val, result) - elif pat_type == 'RestElement': + value = prop.get('value', prop.get('argument')) + if value: + _collect_pattern_idents(value, result) + elif pattern_type == 'RestElement': _collect_pattern_idents(pattern.get('argument'), result) - elif pat_type == 'AssignmentPattern': + elif pattern_type == 'AssignmentPattern': _collect_pattern_idents(pattern.get('left'), result) @@ -361,7 +358,7 @@ class VariableRenamer(Transform): rebuild_scope = True - def execute(self): + def execute(self) -> bool: scope_tree, _ = build_scope_tree(self.ast) # Collect all non-obfuscated names across the entire tree to avoid conflicts @@ -369,18 +366,18 @@ def execute(self): self._collect_reserved(scope_tree, reserved) # Rename bindings scope by scope - gen = _name_generator(reserved) + generator = _name_generator(reserved) # Track loop var counter for i, j, k assignment self._loop_letters = list('ijklmn') self._loop_idx = 0 - self._rename_scope(scope_tree, gen, reserved) + self._rename_scope(scope_tree, generator, reserved) # Fix duplicate names in destructuring patterns (can come from broken obfuscated input) self._fix_destructuring_dupes(reserved) return self.has_changed() - def _collect_reserved(self, scope, reserved): + def _collect_reserved(self, scope: object, reserved: set) -> None: """Collect all non-_0x binding names so we never generate a conflict.""" for name in scope.bindings: if not _OBF_RE.match(name): @@ -388,21 +385,21 @@ def _collect_reserved(self, scope, reserved): for child in scope.children: self._collect_reserved(child, reserved) - def _rename_scope(self, scope, gen, reserved): + def _rename_scope(self, scope: object, generator: object, reserved: set) -> None: """Rename all _0x bindings in this scope and its children.""" for name, binding in list(scope.bindings.items()): if not _OBF_RE.match(name): continue - new_name = self._pick_name(binding, gen, reserved) + new_name = self._pick_name(binding, generator, reserved) reserved.add(new_name) self._apply_rename(binding, new_name) self.set_changed() for child in scope.children: - self._rename_scope(child, gen, reserved) + self._rename_scope(child, generator, reserved) - def _pick_name(self, binding, gen, reserved): + def _pick_name(self, binding: object, generator: object, reserved: set) -> str: """Pick the best name for a binding using heuristics, with fallback.""" # 1. Check if it's a loop counter → i, j, k if _infer_loop_var(binding): @@ -437,9 +434,9 @@ def _pick_name(self, binding, gen, reserved): pass # Fall through to sequential # 5. Fallback: sequential name from generator - return next(gen) + return next(generator) - def _apply_rename(self, binding, new_name): + def _apply_rename(self, binding: object, new_name: str) -> None: """Rename a binding at its declaration site and all reference sites.""" old_name = binding.name @@ -469,14 +466,14 @@ def _apply_rename(self, binding, new_name): arg['name'] = new_name # 2. Rename at all reference sites - for ref_node, ref_parent, ref_key, ref_index in binding.references: + for ref_node, _ref_parent, _ref_key, _ref_index in binding.references: if ref_node.get('type') == 'Identifier' and ref_node.get('name') == old_name: ref_node['name'] = new_name # 3. Update binding.name binding.name = new_name - def _fix_destructuring_dupes(self, reserved): + def _fix_destructuring_dupes(self, reserved: set) -> None: """Fix duplicate identifier names in destructuring patterns. Obfuscators sometimes produce invalid code like `const [a, a, a] = x;`. @@ -487,15 +484,15 @@ def _fix_destructuring_dupes(self, reserved): the last step of the renamer post-pass. """ - def enter(node, parent, key, index): + def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None: if node.get('type') != 'VariableDeclarator': return - pat = node.get('id') - if not pat or pat.get('type') not in ('ArrayPattern', 'ObjectPattern'): + pattern = node.get('id') + if not pattern or pattern.get('type') not in ('ArrayPattern', 'ObjectPattern'): return # Collect all identifier nodes in the pattern idents = [] - _collect_pattern_idents(pat, idents) + _collect_pattern_idents(pattern, idents) seen = {} for ident_node in idents: name = ident_node.get('name') diff --git a/pyjsclear/transforms/xor_string_decode.py b/pyjsclear/transforms/xor_string_decode.py index 0dab5e8..2ddd2ba 100644 --- a/pyjsclear/transforms/xor_string_decode.py +++ b/pyjsclear/transforms/xor_string_decode.py @@ -26,23 +26,23 @@ from .base import Transform -def _extract_numeric_array(node): +def _extract_numeric_array(node: dict | None) -> list[int] | None: """Extract a list of integers from an ArrayExpression node.""" if not node or node.get('type') != 'ArrayExpression': return None elements = node.get('elements', []) result = [] - for el in elements: - if not is_numeric_literal(el): + for element in elements: + if not is_numeric_literal(element): return None - val = el['value'] - if not isinstance(val, (int, float)) or val != int(val) or val < 0 or val > 255: + value = element['value'] + if not isinstance(value, (int, float)) or value != int(value) or value < 0 or value > 255: return None - result.append(int(val)) + result.append(int(value)) return result -def _xor_decode(byte_array, prefix_len=4): +def _xor_decode(byte_array: list[int], prefix_len: int = 4) -> str | None: """Decode XOR-obfuscated byte array: prefix XOR'd against remaining data.""" if len(byte_array) <= prefix_len: return None @@ -56,7 +56,7 @@ def _xor_decode(byte_array, prefix_len=4): return None -def _is_xor_decoder_function(node): +def _is_xor_decoder_function(node: dict | None) -> bool: """Heuristic: check if a function body contains XOR (^=) on array elements and a slice/Buffer.from pattern typical of XOR string decoders.""" if not node: @@ -69,17 +69,17 @@ def _is_xor_decoder_function(node): has_slice = [False] has_tostring = [False] - def scan(n, parent): - if not isinstance(n, dict): + def scan(ast_node: dict, parent: dict) -> None: + if not isinstance(ast_node, dict): return # Look for ^= operator - if n.get('type') == 'AssignmentExpression' and n.get('operator') == '^=': + if ast_node.get('type') == 'AssignmentExpression' and ast_node.get('operator') == '^=': has_xor[0] = True # Look for .slice or .from - if n.get('type') == 'MemberExpression': - prop = n.get('property') - if prop: - name = prop.get('name') or (prop.get('value') if prop.get('type') == 'Literal' else None) + if ast_node.get('type') == 'MemberExpression': + property_node = ast_node.get('property') + if property_node: + name = property_node.get('name') or (property_node.get('value') if property_node.get('type') == 'Literal' else None) if name in ('slice', 'from'): has_slice[0] = True if name in ('toString', 'decode'): @@ -92,11 +92,11 @@ def scan(n, parent): class XorStringDecoder(Transform): """Decode XOR-obfuscated string constants and inline them.""" - def execute(self): + def execute(self) -> bool: # Phase 1: Find XOR decoder functions - decoder_funcs = set() + decoder_funcs: set[str] = set() - def find_decoders(node, parent): + def find_decoders(node: dict, parent: dict) -> None: if node.get('type') not in ('FunctionDeclaration', 'FunctionExpression'): return params = node.get('params', []) @@ -111,9 +111,9 @@ def find_decoders(node, parent): if func_id and is_identifier(func_id): decoder_funcs.add(func_id['name']) elif parent and parent.get('type') == 'VariableDeclarator': - decl_id = parent.get('id') - if decl_id and is_identifier(decl_id): - decoder_funcs.add(decl_id['name']) + declaration_id = parent.get('id') + if declaration_id and is_identifier(declaration_id): + decoder_funcs.add(declaration_id['name']) simple_traverse(self.ast, find_decoders) @@ -121,14 +121,14 @@ def find_decoders(node, parent): return False # Phase 2: Find calls like `var X = decoder([...bytes...])` and decode - decoded_vars = {} # var_name → decoded_string + decoded_vars: dict[str, str] = {} # var_name → decoded_string - def find_calls(node, parent): + def find_calls(node: dict, parent: dict) -> None: if node.get('type') != 'VariableDeclarator': return - decl_id = node.get('id') + declaration_id = node.get('id') init = node.get('init') - if not is_identifier(decl_id) or not init: + if not is_identifier(declaration_id) or not init: return if init.get('type') != 'CallExpression': return @@ -143,7 +143,7 @@ def find_calls(node, parent): return decoded = _xor_decode(byte_array) if decoded is not None: - decoded_vars[decl_id['name']] = decoded + decoded_vars[declaration_id['name']] = decoded simple_traverse(self.ast, find_calls) @@ -152,12 +152,12 @@ def find_calls(node, parent): # Phase 3: Replace computed member accesses obj[_0xVAR] → obj.decoded # and standalone identifier refs with string literals - def replace_refs(node, parent, key, index): + def replace_refs(node: dict, parent: dict, key: str, index: int | None) -> dict | None: # Handle computed member: obj[_0xVAR] → obj.decoded or obj["decoded"] if node.get('type') == 'MemberExpression' and node.get('computed'): - prop = node.get('property') - if is_identifier(prop) and prop['name'] in decoded_vars: - decoded = decoded_vars[prop['name']] + property_node = node.get('property') + if is_identifier(property_node) and property_node['name'] in decoded_vars: + decoded = decoded_vars[property_node['name']] if is_valid_identifier(decoded): node['property'] = make_identifier(decoded) node['computed'] = False @@ -193,32 +193,32 @@ def replace_refs(node, parent, key, index): return self.has_changed() - def _remove_dead_declarations(self, decoded_vars): + def _remove_dead_declarations(self, decoded_vars: dict[str, str]) -> None: """Remove var X = decoder([...]) declarations that are now inlined.""" remaining_refs = {name: 0 for name in decoded_vars} - def count(node, parent): + def count_refs(node: dict, parent: dict) -> None: if is_identifier(node) and node['name'] in remaining_refs: if parent and parent.get('type') == 'VariableDeclarator' and node is parent.get('id'): return remaining_refs[node['name']] = remaining_refs.get(node['name'], 0) + 1 - simple_traverse(self.ast, count) + simple_traverse(self.ast, count_refs) - dead_vars = {name for name, count in remaining_refs.items() if count == 0} + dead_vars = {name for name, ref_count in remaining_refs.items() if ref_count == 0} if not dead_vars: return - def remove_decls(node, parent, key, index): + def remove_decls(node: dict, parent: dict, key: str, index: int | None) -> dict | None: if node.get('type') != 'VariableDeclaration': return decls = node.get('declarations', []) remaining = [] - for d in decls: - did = d.get('id') - if is_identifier(did) and did['name'] in dead_vars: + for declaration in decls: + declaration_id = declaration.get('id') + if is_identifier(declaration_id) and declaration_id['name'] in dead_vars: continue - remaining.append(d) + remaining.append(declaration) if len(remaining) == len(decls): return if not remaining: diff --git a/pyjsclear/traverser.py b/pyjsclear/traverser.py index a3337d6..dbd272d 100644 --- a/pyjsclear/traverser.py +++ b/pyjsclear/traverser.py @@ -1,5 +1,8 @@ """ESTree AST traversal with visitor pattern.""" +from collections.abc import Callable +from typing import Any + from .utils.ast_helpers import _CHILD_KEYS from .utils.ast_helpers import get_child_keys @@ -15,7 +18,7 @@ _isinstance = isinstance -def traverse(node, visitor): +def traverse(node: dict, visitor: dict | object) -> None: """Traverse an ESTree AST calling visitor callbacks. visitor should be a dict or object with optional 'enter' and 'exit' callables. @@ -36,7 +39,7 @@ def traverse(node, visitor): _REMOVE = REMOVE _SKIP = SKIP - def _visit(current_node, parent, key, index): + def _visit(current_node: dict, parent: dict | None, key: str | None, index: int | None) -> Any: node_type = current_node.get('type') if node_type is None: return current_node @@ -72,17 +75,17 @@ def _visit(current_node, parent, key, index): if child is None: continue if _isinstance(child, _list): - i = 0 - while i < len(child): - item = child[i] + child_index = 0 + while child_index < len(child): + item = child[child_index] if _isinstance(item, _dict) and 'type' in item: - result = _visit(item, current_node, child_key, i) + result = _visit(item, current_node, child_key, child_index) if result is _REMOVE: - child.pop(i) + child.pop(child_index) continue elif result is not item: - child[i] = result - i += 1 + child[child_index] = result + child_index += 1 elif _isinstance(child, _dict) and 'type' in child: result = _visit(child, current_node, child_key, None) if result is _REMOVE: @@ -103,13 +106,13 @@ def _visit(current_node, parent, key, index): _visit(node, None, None, None) -def simple_traverse(node, callback): +def simple_traverse(node: dict, callback: Callable) -> None: """Simple traversal that calls callback(node, parent) for every node. No replacement support - just visiting. """ child_keys_map = _CHILD_KEYS - def _visit(current_node, parent): + def _visit(current_node: dict, parent: dict | None) -> None: node_type = current_node.get('type') if node_type is None: return @@ -117,8 +120,8 @@ def _visit(current_node, parent): child_keys = child_keys_map.get(node_type) if child_keys is None: child_keys = get_child_keys(current_node) - for key in child_keys: - child = current_node.get(key) + for child_key in child_keys: + child = current_node.get(child_key) if child is None: continue if _isinstance(child, _list): @@ -131,16 +134,16 @@ def _visit(current_node, parent): _visit(node, None) -def collect_nodes(ast, node_type): +def collect_nodes(ast: dict, node_type: str) -> list[dict]: """Collect all nodes of a given type.""" - result = [] + collected = [] - def cb(node, parent): + def collect_callback(node: dict, parent: dict | None) -> None: if node.get('type') == node_type: - result.append(node) + collected.append(node) - simple_traverse(ast, cb) - return result + simple_traverse(ast, collect_callback) + return collected class _FoundParent(Exception): @@ -148,14 +151,14 @@ class _FoundParent(Exception): __slots__ = ('value',) - def __init__(self, value): + def __init__(self, value: tuple) -> None: self.value = value -def find_parent(ast, target_node): +def find_parent(ast: dict, target_node: dict) -> tuple | None: """Find the parent of a node in the AST. Returns (parent, key, index) or None.""" - def _visit(node): + def _visit(node: dict) -> None: if not isinstance(node, dict) or 'type' not in node: return for child_key in get_child_keys(node): @@ -163,9 +166,9 @@ def _visit(node): if child is None: continue if isinstance(child, list): - for i, item in enumerate(child): + for child_index, item in enumerate(child): if item is target_node: - raise _FoundParent((node, child_key, i)) + raise _FoundParent((node, child_key, child_index)) _visit(item) elif isinstance(child, dict): if child is target_node: @@ -174,12 +177,12 @@ def _visit(node): try: _visit(ast) - except _FoundParent as found: - return found.value + except _FoundParent as found_parent: + return found_parent.value return None -def replace_in_parent(parent, key, index, new_node): +def replace_in_parent(parent: dict, key: str, index: int | None, new_node: dict) -> None: """Replace a node within its parent.""" if index is not None: parent[key][index] = new_node @@ -187,7 +190,7 @@ def replace_in_parent(parent, key, index, new_node): parent[key] = new_node -def remove_from_parent(parent, key, index): +def remove_from_parent(parent: dict, key: str, index: int | None) -> None: """Remove a node from its parent.""" if index is not None: parent[key].pop(index) diff --git a/pyjsclear/utils/ast_helpers.py b/pyjsclear/utils/ast_helpers.py index cb4e2b7..4c1d94b 100644 --- a/pyjsclear/utils/ast_helpers.py +++ b/pyjsclear/utils/ast_helpers.py @@ -4,42 +4,42 @@ import re -def deep_copy(node): +def deep_copy(node: dict) -> dict: """Deep copy an AST node.""" return copy.deepcopy(node) -def is_literal(node): +def is_literal(node: object) -> bool: """Check if node is a Literal.""" return isinstance(node, dict) and node.get('type') == 'Literal' -def is_identifier(node): +def is_identifier(node: object) -> bool: """Check if node is an Identifier.""" return isinstance(node, dict) and node.get('type') == 'Identifier' -def is_string_literal(node): +def is_string_literal(node: object) -> bool: """Check if node is a string Literal.""" return is_literal(node) and isinstance(node.get('value'), str) -def is_numeric_literal(node): +def is_numeric_literal(node: object) -> bool: """Check if node is a numeric Literal.""" return is_literal(node) and isinstance(node.get('value'), (int, float)) -def is_boolean_literal(node): +def is_boolean_literal(node: object) -> bool: """Check if node is a boolean-ish literal (true/false or !0/!1).""" return is_literal(node) and isinstance(node.get('value'), bool) -def is_null_literal(node): +def is_null_literal(node: object) -> bool: """Check if node is null literal.""" return is_literal(node) and node.get('value') is None and node.get('raw') == 'null' -def is_undefined(node): +def is_undefined(node: object) -> bool: """Check if node represents undefined (identifier or ``void 0``).""" if is_identifier(node) and node.get('name') == 'undefined': return True @@ -55,14 +55,14 @@ def is_undefined(node): return False -def get_literal_value(node): +def get_literal_value(node: object) -> tuple: """Extract the value from a literal node. Returns (value, True) or (None, False).""" if not is_literal(node): return None, False return node.get('value'), True -def make_literal(value, raw=None): +def make_literal(value: object, raw: str | None = None) -> dict: """Create a Literal AST node.""" if raw is not None: return {'type': 'Literal', 'value': value, 'raw': raw} @@ -90,22 +90,22 @@ def make_literal(value, raw=None): return {'type': 'Literal', 'value': value, 'raw': raw} -def make_identifier(name): +def make_identifier(name: str) -> dict: """Create an Identifier AST node.""" return {'type': 'Identifier', 'name': name} -def make_expression_statement(expr): +def make_expression_statement(expr: dict) -> dict: """Wrap an expression in an ExpressionStatement.""" return {'type': 'ExpressionStatement', 'expression': expr} -def make_block_statement(body): +def make_block_statement(body: list) -> dict: """Create a BlockStatement.""" return {'type': 'BlockStatement', 'body': body} -def make_var_declaration(name, init=None, kind='var'): +def make_var_declaration(name: str, init: dict | None = None, kind: str = 'var') -> dict: """Create a VariableDeclaration with a single declarator.""" return { 'type': 'VariableDeclaration', @@ -117,7 +117,7 @@ def make_var_declaration(name, init=None, kind='var'): _IDENT_RE = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$]*$') -def is_valid_identifier(name): +def is_valid_identifier(name: object) -> bool: """Check if a string is a valid JS identifier (for obj.prop access).""" if not isinstance(name, str) or not name: return False @@ -207,7 +207,7 @@ def is_valid_identifier(name): ) -def get_child_keys(node): +def get_child_keys(node: object) -> tuple | list: """Get keys of a node that may contain child nodes/arrays.""" if not isinstance(node, dict) or 'type' not in node: return () @@ -217,15 +217,15 @@ def get_child_keys(node): return keys # Fallback: return all keys that look like they might contain nodes return [ - k - for k, v in node.items() - if k not in _SKIP_KEYS - and not (k == 'expression' and node_type != 'ExpressionStatement') - and isinstance(v, (dict, list)) + key + for key, value in node.items() + if key not in _SKIP_KEYS + and not (key == 'expression' and node_type != 'ExpressionStatement') + and isinstance(value, (dict, list)) ] -def replace_identifiers(node, param_map): +def replace_identifiers(node: dict, param_map: dict) -> None: """Replace Identifier nodes whose names are in param_map with deep copies. Skips non-computed property names in MemberExpressions. @@ -238,10 +238,10 @@ def replace_identifiers(node, param_map): continue is_noncomputed_prop = key == 'property' and node.get('type') == 'MemberExpression' and not node.get('computed') if isinstance(child, list): - for i, item in enumerate(child): + for index, item in enumerate(child): if isinstance(item, dict) and item.get('type') == 'Identifier': if not is_noncomputed_prop and item.get('name', '') in param_map: - child[i] = copy.deepcopy(param_map[item['name']]) + child[index] = copy.deepcopy(param_map[item['name']]) elif isinstance(item, dict) and 'type' in item: replace_identifiers(item, param_map) elif isinstance(child, dict): @@ -252,12 +252,12 @@ def replace_identifiers(node, param_map): replace_identifiers(child, param_map) -def identifiers_match(a, b): +def identifiers_match(node_a: object, node_b: object) -> bool: """Check if two nodes are the same identifier.""" - return is_identifier(a) and is_identifier(b) and a.get('name') == b.get('name') + return is_identifier(node_a) and is_identifier(node_b) and node_a.get('name') == node_b.get('name') -def is_side_effect_free(node): +def is_side_effect_free(node: object) -> bool: """Check if an expression node is side-effect-free (safe to discard).""" if not isinstance(node, dict): return False @@ -288,7 +288,7 @@ def is_side_effect_free(node): return False -def get_member_names(node): +def get_member_names(node: object) -> tuple[str, str] | tuple[None, None]: """Extract (object_name, property_name) from a MemberExpression. Handles both computed (obj["prop"]) and non-computed (obj.prop) forms. @@ -311,17 +311,17 @@ def get_member_names(node): return None, None -def nodes_equal(a, b): +def nodes_equal(node_a: object, node_b: object) -> bool: """Check if two AST nodes are structurally equal (ignoring position info).""" - if type(a) != type(b): + if type(node_a) != type(node_b): return False - match a: + match node_a: case dict(): - keys_a = {k for k in a if k not in ('start', 'end', 'loc', 'range')} - keys_b = {k for k in b if k not in ('start', 'end', 'loc', 'range')} + keys_a = {key for key in node_a if key not in ('start', 'end', 'loc', 'range')} + keys_b = {key for key in node_b if key not in ('start', 'end', 'loc', 'range')} if keys_a != keys_b: return False - return all(nodes_equal(a[k], b[k]) for k in keys_a) + return all(nodes_equal(node_a[key], node_b[key]) for key in keys_a) case list(): - return len(a) == len(b) and all(nodes_equal(x, y) for x, y in zip(a, b)) - return a == b + return len(node_a) == len(node_b) and all(nodes_equal(x, y) for x, y in zip(node_a, node_b)) + return node_a == node_b diff --git a/pyjsclear/utils/string_decoders.py b/pyjsclear/utils/string_decoders.py index cb2b7fa..bbd428c 100644 --- a/pyjsclear/utils/string_decoders.py +++ b/pyjsclear/utils/string_decoders.py @@ -1,10 +1,9 @@ """String decoder implementations for obfuscator.io patterns.""" -from enum import Enum -from urllib.parse import unquote +from enum import StrEnum -class DecoderType(Enum): +class DecoderType(StrEnum): BASIC = 'basic' BASE_64 = 'base64' RC4 = 'rc4' @@ -13,7 +12,7 @@ class DecoderType(Enum): _BASE_64_ALPHABET = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/=' -def base64_transform(encoded_string): +def base64_transform(encoded_string: str) -> str: """Decode obfuscator.io's custom base64 encoding.""" # Decode 4 base64 chars into 3 bytes using 6-bit groups. # bit_buffer accumulates bits; every non-first char in a group yields a byte @@ -39,19 +38,19 @@ def base64_transform(encoded_string): class StringDecoder: """Base string decoder.""" - def __init__(self, string_array, index_offset): + def __init__(self, string_array: list[str], index_offset: int) -> None: self.string_array = string_array self.index_offset = index_offset self.is_first_call = True @property - def type(self): + def type(self) -> DecoderType: return DecoderType.BASIC - def get_string(self, index, *args): + def get_string(self, index: int, *args) -> str | None: raise NotImplementedError - def get_string_for_rotation(self, index, *args, **kwargs): + def get_string_for_rotation(self, index: int, *args, **kwargs) -> str | None: if self.is_first_call: self.is_first_call = False raise RuntimeError('First call') @@ -62,10 +61,10 @@ class BasicStringDecoder(StringDecoder): """Simple array index + offset decoder.""" @property - def type(self): + def type(self) -> DecoderType: return DecoderType.BASIC - def get_string(self, index, *args): + def get_string(self, index: int, *args) -> str | None: array_index = index + self.index_offset if 0 <= array_index < len(self.string_array): return self.string_array[array_index] @@ -75,15 +74,15 @@ def get_string(self, index, *args): class Base64StringDecoder(StringDecoder): """Base64 string decoder.""" - def __init__(self, string_array, index_offset): + def __init__(self, string_array: list[str], index_offset: int) -> None: super().__init__(string_array, index_offset) - self._cache = {} + self._cache: dict[int, str] = {} @property - def type(self): + def type(self) -> DecoderType: return DecoderType.BASE_64 - def get_string(self, index, *args): + def get_string(self, index: int, *args) -> str | None: if index in self._cache: return self._cache[index] array_index = index + self.index_offset @@ -97,15 +96,15 @@ def get_string(self, index, *args): class Rc4StringDecoder(StringDecoder): """RC4 string decoder.""" - def __init__(self, string_array, index_offset): + def __init__(self, string_array: list[str], index_offset: int) -> None: super().__init__(string_array, index_offset) - self._cache = {} + self._cache: dict[tuple[int, str], str] = {} @property - def type(self): + def type(self) -> DecoderType: return DecoderType.RC4 - def get_string(self, index, key=None): + def get_string(self, index: int, key: str | None = None) -> str | None: if not key: return None # Include key in cache to avoid collisions with different RC4 keys @@ -120,7 +119,7 @@ def get_string(self, index, key=None): self._cache[cache_key] = decoded return decoded - def _rc4_decode(self, encoded_string, key): + def _rc4_decode(self, encoded_string: str, key: str) -> str: """RC4 decryption with base64 pre-processing.""" encoded_string = base64_transform(encoded_string) # KSA diff --git a/tests/conftest.py b/tests/conftest.py index fed2159..7b08d41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ """Root test configuration.""" +import pytest -def pytest_addoption(parser): + +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( '--update-snapshots', action='store_true', diff --git a/tests/fuzz/conftest_fuzz.py b/tests/fuzz/conftest_fuzz.py index 99553d7..2ce846f 100644 --- a/tests/fuzz/conftest_fuzz.py +++ b/tests/fuzz/conftest_fuzz.py @@ -1,12 +1,14 @@ """Shared helpers for fuzz targets.""" +import argparse import os import random -import struct import sys +import time +from typing import Any, Callable -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) MAX_INPUT_SIZE = 102_400 # 100KB cap @@ -14,72 +16,71 @@ # ESTree node types for synthetic AST generation _STATEMENT_TYPES = [ - "ExpressionStatement", - "VariableDeclaration", - "ReturnStatement", - "IfStatement", - "WhileStatement", - "ForStatement", - "BlockStatement", - "EmptyStatement", - "BreakStatement", - "ContinueStatement", - "ThrowStatement", + 'ExpressionStatement', + 'VariableDeclaration', + 'ReturnStatement', + 'IfStatement', + 'WhileStatement', + 'ForStatement', + 'BlockStatement', + 'EmptyStatement', + 'BreakStatement', + 'ContinueStatement', + 'ThrowStatement', ] _EXPRESSION_TYPES = [ - "Literal", - "Identifier", - "BinaryExpression", - "UnaryExpression", - "CallExpression", - "MemberExpression", - "AssignmentExpression", - "ConditionalExpression", - "ArrayExpression", - "ObjectExpression", - "FunctionExpression", - "ThisExpression", + 'Literal', + 'Identifier', + 'BinaryExpression', + 'UnaryExpression', + 'CallExpression', + 'MemberExpression', + 'AssignmentExpression', + 'ConditionalExpression', + 'ArrayExpression', + 'ObjectExpression', + 'FunctionExpression', + 'ThisExpression', ] _BINARY_OPS = [ - "+", - "-", - "*", - "/", - "%", - "==", - "!=", - "===", - "!==", - "<", - ">", - "<=", - ">=", - "&&", - "||", - "&", - "|", - "^", - "<<", - ">>", - ">>>", + '+', + '-', + '*', + '/', + '%', + '==', + '!=', + '===', + '!==', + '<', + '>', + '<=', + '>=', + '&&', + '||', + '&', + '|', + '^', + '<<', + '>>', + '>>>', ] -_UNARY_OPS = ["-", "+", "!", "~", "typeof", "void"] +_UNARY_OPS = ['-', '+', '!', '~', 'typeof', 'void'] -def bytes_to_js(data): +def bytes_to_js(data: bytes) -> str: """Decode bytes to a JS string with size limit.""" - text = data[:MAX_INPUT_SIZE].decode("utf-8", errors="replace") - return text + return data[:MAX_INPUT_SIZE].decode('utf-8', errors='replace') -def bytes_to_ast_dict(data, max_depth=5, max_children=4): +def bytes_to_ast_dict(data: bytes, max_depth: int = 5, max_children: int = 4) -> dict: """Build a synthetic ESTree AST dict from bytes for testing generator/traverser.""" - rng = random.Random(int.from_bytes(data[:8].ljust(8, b"\x00"), "little")) + rng = random.Random(int.from_bytes(data[:8].ljust(8, b'\x00'), 'little')) pos = 8 - def consume_byte(): + def consume_byte() -> int: nonlocal pos if pos < len(data): val = data[pos] @@ -87,214 +88,218 @@ def consume_byte(): return val return rng.randint(0, 255) - def make_literal(): - kind = consume_byte() % 4 - if kind == 0: - return {"type": "Literal", "value": consume_byte(), "raw": str(consume_byte())} - elif kind == 1: - return {"type": "Literal", "value": True, "raw": "true"} - elif kind == 2: - return {"type": "Literal", "value": None, "raw": "null"} - else: - return {"type": "Literal", "value": "fuzz", "raw": '"fuzz"'} - - def make_identifier(): - names = ["a", "b", "c", "x", "y", "foo", "bar", "_", "$"] - return {"type": "Identifier", "name": names[consume_byte() % len(names)]} - - def make_node(depth=0): + def make_literal() -> dict: + match consume_byte() % 4: + case 0: + return {'type': 'Literal', 'value': consume_byte(), 'raw': str(consume_byte())} + case 1: + return {'type': 'Literal', 'value': True, 'raw': 'true'} + case 2: + return {'type': 'Literal', 'value': None, 'raw': 'null'} + case _: + return {'type': 'Literal', 'value': 'fuzz', 'raw': '"fuzz"'} + + def make_identifier() -> dict: + names = ['a', 'b', 'c', 'x', 'y', 'foo', 'bar', '_', '$'] + return {'type': 'Identifier', 'name': names[consume_byte() % len(names)]} + + def make_node(depth: int = 0) -> dict: if depth >= max_depth: return make_literal() if consume_byte() % 2 == 0 else make_identifier() type_idx = consume_byte() - if type_idx % 3 == 0: - # Expression - expr_type = _EXPRESSION_TYPES[consume_byte() % len(_EXPRESSION_TYPES)] - if expr_type == "Literal": + if type_idx % 3 != 0: + return make_statement(depth) + + expr_type = _EXPRESSION_TYPES[consume_byte() % len(_EXPRESSION_TYPES)] + match expr_type: + case 'Literal': return make_literal() - elif expr_type == "Identifier": + case 'Identifier': return make_identifier() - elif expr_type == "BinaryExpression": + case 'BinaryExpression': return { - "type": "BinaryExpression", - "operator": _BINARY_OPS[consume_byte() % len(_BINARY_OPS)], - "left": make_node(depth + 1), - "right": make_node(depth + 1), + 'type': 'BinaryExpression', + 'operator': _BINARY_OPS[consume_byte() % len(_BINARY_OPS)], + 'left': make_node(depth + 1), + 'right': make_node(depth + 1), } - elif expr_type == "UnaryExpression": + case 'UnaryExpression': return { - "type": "UnaryExpression", - "operator": _UNARY_OPS[consume_byte() % len(_UNARY_OPS)], - "argument": make_node(depth + 1), - "prefix": True, + 'type': 'UnaryExpression', + 'operator': _UNARY_OPS[consume_byte() % len(_UNARY_OPS)], + 'argument': make_node(depth + 1), + 'prefix': True, } - elif expr_type == "CallExpression": + case 'CallExpression': num_args = consume_byte() % max_children return { - "type": "CallExpression", - "callee": make_node(depth + 1), - "arguments": [make_node(depth + 1) for _ in range(num_args)], + 'type': 'CallExpression', + 'callee': make_node(depth + 1), + 'arguments': [make_node(depth + 1) for _ in range(num_args)], } - elif expr_type == "MemberExpression": + case 'MemberExpression': computed = consume_byte() % 2 == 0 return { - "type": "MemberExpression", - "object": make_node(depth + 1), - "property": make_node(depth + 1), - "computed": computed, + 'type': 'MemberExpression', + 'object': make_node(depth + 1), + 'property': make_node(depth + 1), + 'computed': computed, } - elif expr_type == "AssignmentExpression": + case 'AssignmentExpression': return { - "type": "AssignmentExpression", - "operator": "=", - "left": make_identifier(), - "right": make_node(depth + 1), + 'type': 'AssignmentExpression', + 'operator': '=', + 'left': make_identifier(), + 'right': make_node(depth + 1), } - elif expr_type == "ConditionalExpression": + case 'ConditionalExpression': return { - "type": "ConditionalExpression", - "test": make_node(depth + 1), - "consequent": make_node(depth + 1), - "alternate": make_node(depth + 1), + 'type': 'ConditionalExpression', + 'test': make_node(depth + 1), + 'consequent': make_node(depth + 1), + 'alternate': make_node(depth + 1), } - elif expr_type == "ArrayExpression": + case 'ArrayExpression': num = consume_byte() % max_children return { - "type": "ArrayExpression", - "elements": [make_node(depth + 1) for _ in range(num)], + 'type': 'ArrayExpression', + 'elements': [make_node(depth + 1) for _ in range(num)], } - elif expr_type == "ObjectExpression": + case 'ObjectExpression': num = consume_byte() % max_children return { - "type": "ObjectExpression", - "properties": [ + 'type': 'ObjectExpression', + 'properties': [ { - "type": "Property", - "key": make_identifier(), - "value": make_node(depth + 1), - "kind": "init", - "computed": False, - "method": False, - "shorthand": False, + 'type': 'Property', + 'key': make_identifier(), + 'value': make_node(depth + 1), + 'kind': 'init', + 'computed': False, + 'method': False, + 'shorthand': False, } for _ in range(num) ], } - elif expr_type == "FunctionExpression": + case 'FunctionExpression': return { - "type": "FunctionExpression", - "id": None, - "params": [], - "body": { - "type": "BlockStatement", - "body": [make_statement(depth + 1) for _ in range(consume_byte() % 3)], + 'type': 'FunctionExpression', + 'id': None, + 'params': [], + 'body': { + 'type': 'BlockStatement', + 'body': [make_statement(depth + 1) for _ in range(consume_byte() % 3)], }, - "generator": False, - "async": False, + 'generator': False, + 'async': False, } - else: - return {"type": "ThisExpression"} - else: - return make_statement(depth) + case _: + return {'type': 'ThisExpression'} - def make_statement(depth=0): + def make_statement(depth: int = 0) -> dict: if depth >= max_depth: - return { - "type": "ExpressionStatement", - "expression": make_literal(), - } + return {'type': 'ExpressionStatement', 'expression': make_literal()} stmt_type = _STATEMENT_TYPES[consume_byte() % len(_STATEMENT_TYPES)] - if stmt_type == "ExpressionStatement": - return {"type": "ExpressionStatement", "expression": make_node(depth + 1)} - elif stmt_type == "VariableDeclaration": - return { - "type": "VariableDeclaration", - "declarations": [ - { - "type": "VariableDeclarator", - "id": make_identifier(), - "init": make_node(depth + 1) if consume_byte() % 2 == 0 else None, - } - ], - "kind": ["var", "let", "const"][consume_byte() % 3], - } - elif stmt_type == "ReturnStatement": - return {"type": "ReturnStatement", "argument": make_node(depth + 1) if consume_byte() % 2 == 0 else None} - elif stmt_type == "IfStatement": - return { - "type": "IfStatement", - "test": make_node(depth + 1), - "consequent": {"type": "BlockStatement", "body": [make_statement(depth + 1)]}, - "alternate": ( - {"type": "BlockStatement", "body": [make_statement(depth + 1)]} if consume_byte() % 2 == 0 else None - ), - } - elif stmt_type == "WhileStatement": - return { - "type": "WhileStatement", - "test": make_node(depth + 1), - "body": {"type": "BlockStatement", "body": [make_statement(depth + 1)]}, - } - elif stmt_type == "ForStatement": - return { - "type": "ForStatement", - "init": None, - "test": make_node(depth + 1), - "update": None, - "body": {"type": "BlockStatement", "body": [make_statement(depth + 1)]}, - } - elif stmt_type == "BlockStatement": - num = consume_byte() % max_children - return {"type": "BlockStatement", "body": [make_statement(depth + 1) for _ in range(num)]} - elif stmt_type == "EmptyStatement": - return {"type": "EmptyStatement"} - elif stmt_type == "BreakStatement": - return {"type": "BreakStatement", "label": None} - elif stmt_type == "ContinueStatement": - return {"type": "ContinueStatement", "label": None} - elif stmt_type == "ThrowStatement": - return {"type": "ThrowStatement", "argument": make_node(depth + 1)} - return {"type": "EmptyStatement"} + match stmt_type: + case 'ExpressionStatement': + return {'type': 'ExpressionStatement', 'expression': make_node(depth + 1)} + case 'VariableDeclaration': + return { + 'type': 'VariableDeclaration', + 'declarations': [ + { + 'type': 'VariableDeclarator', + 'id': make_identifier(), + 'init': make_node(depth + 1) if consume_byte() % 2 == 0 else None, + } + ], + 'kind': ['var', 'let', 'const'][consume_byte() % 3], + } + case 'ReturnStatement': + return { + 'type': 'ReturnStatement', + 'argument': make_node(depth + 1) if consume_byte() % 2 == 0 else None, + } + case 'IfStatement': + return { + 'type': 'IfStatement', + 'test': make_node(depth + 1), + 'consequent': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]}, + 'alternate': ( + {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]} + if consume_byte() % 2 == 0 + else None + ), + } + case 'WhileStatement': + return { + 'type': 'WhileStatement', + 'test': make_node(depth + 1), + 'body': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]}, + } + case 'ForStatement': + return { + 'type': 'ForStatement', + 'init': None, + 'test': make_node(depth + 1), + 'update': None, + 'body': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]}, + } + case 'BlockStatement': + num = consume_byte() % max_children + return {'type': 'BlockStatement', 'body': [make_statement(depth + 1) for _ in range(num)]} + case 'EmptyStatement': + return {'type': 'EmptyStatement'} + case 'BreakStatement': + return {'type': 'BreakStatement', 'label': None} + case 'ContinueStatement': + return {'type': 'ContinueStatement', 'label': None} + case 'ThrowStatement': + return {'type': 'ThrowStatement', 'argument': make_node(depth + 1)} + case _: + return {'type': 'EmptyStatement'} num_stmts = max(1, consume_byte() % 6) return { - "type": "Program", - "body": [make_statement(0) for _ in range(num_stmts)], - "sourceType": "script", + 'type': 'Program', + 'body': [make_statement(0) for _ in range(num_stmts)], + 'sourceType': 'script', } class SimpleFuzzedDataProvider: """Minimal FuzzedDataProvider for when atheris is not available.""" - def __init__(self, data): + def __init__(self, data: bytes) -> None: self._data = data self._pos = 0 - def ConsumeUnicode(self, max_length): + def ConsumeUnicode(self, max_length: int) -> str: end = min(self._pos + max_length, len(self._data)) - chunk = self._data[self._pos : end] + chunk = self._data[self._pos:end] self._pos = end - return chunk.decode("utf-8", errors="replace") + return chunk.decode('utf-8', errors='replace') - def ConsumeBytes(self, max_length): + def ConsumeBytes(self, max_length: int) -> bytes: end = min(self._pos + max_length, len(self._data)) - chunk = self._data[self._pos : end] + chunk = self._data[self._pos:end] self._pos = end return chunk - def ConsumeIntInRange(self, min_val, max_val): + def ConsumeIntInRange(self, min_val: int, max_val: int) -> int: if self._pos < len(self._data): val = self._data[self._pos] self._pos += 1 return min_val + (val % (max_val - min_val + 1)) return min_val - def ConsumeBool(self): + def ConsumeBool(self) -> bool: return self.ConsumeIntInRange(0, 1) == 1 - def remaining_bytes(self): + def remaining_bytes(self) -> int: return len(self._data) - self._pos @@ -307,7 +312,11 @@ def remaining_bytes(self): FuzzedDataProvider = SimpleFuzzedDataProvider -def run_fuzzer(target_fn, argv=None, custom_setup=None): +def run_fuzzer( + target_fn: Callable[[bytes], None], + argv: list[str] | None = None, + custom_setup: Callable | None = None, +) -> None: """Run a fuzz target with atheris if available, otherwise with random inputs.""" if atheris is not None: if custom_setup: @@ -315,50 +324,46 @@ def run_fuzzer(target_fn, argv=None, custom_setup=None): custom_setup() atheris.Setup(argv or sys.argv, target_fn) atheris.Fuzz() - else: - # Standalone random-based fuzzing - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("corpus_dirs", nargs="*", default=[]) - parser.add_argument("-max_total_time", type=int, default=10) - parser.add_argument("-max_len", type=int, default=MAX_INPUT_SIZE) - parser.add_argument("-timeout", type=int, default=30) - parser.add_argument("-rss_limit_mb", type=int, default=2048) - parser.add_argument("-runs", type=int, default=0) - args = parser.parse_args(argv[1:] if argv else sys.argv[1:]) - - # First, run seed corpus files - seeds_run = 0 - for corpus_dir in args.corpus_dirs: - if os.path.isdir(corpus_dir): - for fname in sorted(os.listdir(corpus_dir)): - fpath = os.path.join(corpus_dir, fname) - if os.path.isfile(fpath): - with open(fpath, "rb") as f: - data = f.read() - try: - target_fn(data) - except Exception as e: - if not isinstance(e, SAFE_EXCEPTIONS): - print(f"FINDING in seed {fname}: {type(e).__name__}: {e}") - seeds_run += 1 - - # Then random inputs - import time - - rng = random.Random(42) - start = time.time() - runs = 0 - max_runs = args.runs if args.runs > 0 else float("inf") - while time.time() - start < args.max_total_time and runs < max_runs: - length = rng.randint(0, min(args.max_len, 4096)) - data = bytes(rng.randint(0, 255) for _ in range(length)) + return + + parser = argparse.ArgumentParser() + parser.add_argument('corpus_dirs', nargs='*', default=[]) + parser.add_argument('-max_total_time', type=int, default=10) + parser.add_argument('-max_len', type=int, default=MAX_INPUT_SIZE) + parser.add_argument('-timeout', type=int, default=30) + parser.add_argument('-rss_limit_mb', type=int, default=2048) + parser.add_argument('-runs', type=int, default=0) + args = parser.parse_args(argv[1:] if argv else sys.argv[1:]) + + seeds_run = 0 + for corpus_dir in args.corpus_dirs: + if not os.path.isdir(corpus_dir): + continue + for fname in sorted(os.listdir(corpus_dir)): + fpath = os.path.join(corpus_dir, fname) + if not os.path.isfile(fpath): + continue + with open(fpath, 'rb') as file_handle: + seed_data = file_handle.read() try: - target_fn(data) - except Exception as e: - if not isinstance(e, SAFE_EXCEPTIONS): - print(f"FINDING at run {runs}: {type(e).__name__}: {e}") - runs += 1 - - print(f"Fuzzing complete: {seeds_run} seeds + {runs} random inputs in {time.time() - start:.1f}s") + target_fn(seed_data) + except Exception as exc: + if not isinstance(exc, SAFE_EXCEPTIONS): + print(f'FINDING in seed {fname}: {type(exc).__name__}: {exc}') + seeds_run += 1 + + rng = random.Random(42) + start = time.time() + runs = 0 + max_runs = args.runs if args.runs > 0 else float('inf') + while time.time() - start < args.max_total_time and runs < max_runs: + length = rng.randint(0, min(args.max_len, 4096)) + random_data = bytes(rng.randint(0, 255) for _ in range(length)) + try: + target_fn(random_data) + except Exception as exc: + if not isinstance(exc, SAFE_EXCEPTIONS): + print(f'FINDING at run {runs}: {type(exc).__name__}: {exc}') + runs += 1 + + print(f'Fuzzing complete: {seeds_run} seeds + {runs} random inputs in {time.time() - start:.1f}s') diff --git a/tests/fuzz/fuzz_deobfuscate.py b/tests/fuzz/fuzz_deobfuscate.py index 5a950f3..5d2e243 100755 --- a/tests/fuzz/fuzz_deobfuscate.py +++ b/tests/fuzz/fuzz_deobfuscate.py @@ -9,7 +9,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import bytes_to_js @@ -18,7 +18,7 @@ from pyjsclear import deobfuscate -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 2: return @@ -29,11 +29,9 @@ def TestOneInput(data): except SAFE_EXCEPTIONS: return - # Core safety guarantee: result must never be None - assert result is not None, "deobfuscate() returned None" - # Result must be a string - assert isinstance(result, str), f"deobfuscate() returned {type(result)}, expected str" + assert result is not None, 'deobfuscate() returned None' + assert isinstance(result, str), f'deobfuscate() returned {type(result)}, expected str' -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_expression_simplifier.py b/tests/fuzz/fuzz_expression_simplifier.py index 001ca9f..f3b8438 100755 --- a/tests/fuzz/fuzz_expression_simplifier.py +++ b/tests/fuzz/fuzz_expression_simplifier.py @@ -9,7 +9,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import bytes_to_js @@ -20,7 +20,7 @@ from pyjsclear.transforms.expression_simplifier import ExpressionSimplifier -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 2: return @@ -42,8 +42,8 @@ def TestOneInput(data): except SAFE_EXCEPTIONS: return - assert isinstance(result, str), f"generate() returned {type(result)} after ExpressionSimplifier" + assert isinstance(result, str), f'generate() returned {type(result)} after ExpressionSimplifier' -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_generator.py b/tests/fuzz/fuzz_generator.py index be2a47a..1f3d228 100755 --- a/tests/fuzz/fuzz_generator.py +++ b/tests/fuzz/fuzz_generator.py @@ -10,7 +10,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import FuzzedDataProvider @@ -22,7 +22,7 @@ from pyjsclear.parser import parse -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 4: return @@ -30,7 +30,7 @@ def TestOneInput(data): mode = fdp.ConsumeIntInRange(0, 1) if mode == 0: - # Roundtrip mode: parse valid JS then generate + # Roundtrip: parse then generate code = bytes_to_js(fdp.ConsumeBytes(fdp.remaining_bytes())) try: ast = parse(code) @@ -42,22 +42,21 @@ def TestOneInput(data): except SAFE_EXCEPTIONS: return - assert isinstance(result, str), f"generate() returned {type(result)}, expected str" + assert isinstance(result, str), f'generate() returned {type(result)}, expected str' else: - # Synthetic AST mode: test with malformed input + # Synthetic AST: test with malformed input remaining = fdp.ConsumeBytes(fdp.remaining_bytes()) ast = bytes_to_ast_dict(remaining) try: result = generate(ast) except (KeyError, TypeError, AttributeError, ValueError): - # Expected for malformed ASTs return except SAFE_EXCEPTIONS: return - assert isinstance(result, str), f"generate() returned {type(result)}, expected str" + assert isinstance(result, str), f'generate() returned {type(result)}, expected str' -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_parser.py b/tests/fuzz/fuzz_parser.py index 40aa615..1ff4275 100755 --- a/tests/fuzz/fuzz_parser.py +++ b/tests/fuzz/fuzz_parser.py @@ -5,7 +5,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import bytes_to_js @@ -14,7 +14,7 @@ from pyjsclear.parser import parse -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 1: return @@ -23,15 +23,13 @@ def TestOneInput(data): try: result = parse(code) except SyntaxError: - # Expected for invalid JS return except SAFE_EXCEPTIONS: return - # Successful parse must return a Program or Module dict - assert isinstance(result, dict), f"parse() returned {type(result)}, expected dict" - assert result.get("type") in ("Program", "Module"), f"Unexpected root type: {result.get('type')}" + assert isinstance(result, dict), f'parse() returned {type(result)}, expected dict' + assert result.get('type') in ('Program', 'Module'), f"Unexpected root type: {result.get('type')}" -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_scope.py b/tests/fuzz/fuzz_scope.py index 21cb853..8abb477 100755 --- a/tests/fuzz/fuzz_scope.py +++ b/tests/fuzz/fuzz_scope.py @@ -5,7 +5,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import bytes_to_js @@ -15,7 +15,7 @@ from pyjsclear.scope import build_scope_tree -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 2: return @@ -31,9 +31,9 @@ def TestOneInput(data): except SAFE_EXCEPTIONS: return - assert root_scope is not None, "build_scope_tree returned None root_scope" - assert isinstance(node_scope_map, dict), f"node_scope_map is {type(node_scope_map)}, expected dict" + assert root_scope is not None, 'build_scope_tree returned None root_scope' + assert isinstance(node_scope_map, dict), f'node_scope_map is {type(node_scope_map)}, expected dict' -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_string_decoders.py b/tests/fuzz/fuzz_string_decoders.py index 25647a3..4b45754 100755 --- a/tests/fuzz/fuzz_string_decoders.py +++ b/tests/fuzz/fuzz_string_decoders.py @@ -9,7 +9,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import FuzzedDataProvider @@ -21,71 +21,65 @@ from pyjsclear.utils.string_decoders import base64_transform -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 8: return fdp = FuzzedDataProvider(data) decoder_choice = fdp.ConsumeIntInRange(0, 3) - if decoder_choice == 0: - # Test base64_transform with random strings - encoded = fdp.ConsumeUnicode(1024) - try: - result = base64_transform(encoded) - except SAFE_EXCEPTIONS: - return - assert isinstance(result, str), f"base64_transform returned {type(result)}" - - elif decoder_choice == 1: - # Test BasicStringDecoder - num_strings = fdp.ConsumeIntInRange(0, 20) - string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] - offset = fdp.ConsumeIntInRange(-5, 5) - decoder = BasicStringDecoder(string_array, offset) - if num_strings > 0: - idx = fdp.ConsumeIntInRange(-2, num_strings * 2) + match decoder_choice: + case 0: + encoded = fdp.ConsumeUnicode(1024) try: - result = decoder.get_string(idx) + result = base64_transform(encoded) except SAFE_EXCEPTIONS: return - # None is valid for out-of-range indices - if result is not None: - assert isinstance(result, str), f"BasicStringDecoder returned {type(result)}" + assert isinstance(result, str), f'base64_transform returned {type(result)}' - elif decoder_choice == 2: - # Test Base64StringDecoder - num_strings = fdp.ConsumeIntInRange(0, 20) - string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] - offset = fdp.ConsumeIntInRange(-5, 5) - decoder = Base64StringDecoder(string_array, offset) - if num_strings > 0: - idx = fdp.ConsumeIntInRange(-2, num_strings * 2) - try: - result = decoder.get_string(idx) - except SAFE_EXCEPTIONS: - return - # None is valid for out-of-range indices - if result is not None: - assert isinstance(result, str), f"Base64StringDecoder returned {type(result)}" + case 1: + num_strings = fdp.ConsumeIntInRange(0, 20) + string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] + offset = fdp.ConsumeIntInRange(-5, 5) + decoder = BasicStringDecoder(string_array, offset) + if num_strings > 0: + idx = fdp.ConsumeIntInRange(-2, num_strings * 2) + try: + result = decoder.get_string(idx) + except SAFE_EXCEPTIONS: + return + if result is not None: + assert isinstance(result, str), f'BasicStringDecoder returned {type(result)}' - elif decoder_choice == 3: - # Test Rc4StringDecoder - potential ZeroDivisionError with empty key - num_strings = fdp.ConsumeIntInRange(0, 20) - string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] - offset = fdp.ConsumeIntInRange(-5, 5) - decoder = Rc4StringDecoder(string_array, offset) - if num_strings > 0: - idx = fdp.ConsumeIntInRange(-2, num_strings * 2) - key = fdp.ConsumeUnicode(32) # May be empty - tests empty key guard - try: - result = decoder.get_string(idx, key=key) - except SAFE_EXCEPTIONS: - return - # None is valid for out-of-range or None key - if result is not None: - assert isinstance(result, str), f"Rc4StringDecoder returned {type(result)}" + case 2: + num_strings = fdp.ConsumeIntInRange(0, 20) + string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] + offset = fdp.ConsumeIntInRange(-5, 5) + decoder = Base64StringDecoder(string_array, offset) + if num_strings > 0: + idx = fdp.ConsumeIntInRange(-2, num_strings * 2) + try: + result = decoder.get_string(idx) + except SAFE_EXCEPTIONS: + return + if result is not None: + assert isinstance(result, str), f'Base64StringDecoder returned {type(result)}' + + case 3: + num_strings = fdp.ConsumeIntInRange(0, 20) + string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)] + offset = fdp.ConsumeIntInRange(-5, 5) + decoder = Rc4StringDecoder(string_array, offset) + if num_strings > 0: + idx = fdp.ConsumeIntInRange(-2, num_strings * 2) + key = fdp.ConsumeUnicode(32) # may be empty - tests empty key guard + try: + result = decoder.get_string(idx, key=key) + except SAFE_EXCEPTIONS: + return + if result is not None: + assert isinstance(result, str), f'Rc4StringDecoder returned {type(result)}' -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_transforms.py b/tests/fuzz/fuzz_transforms.py index f5b0eeb..ac9378e 100755 --- a/tests/fuzz/fuzz_transforms.py +++ b/tests/fuzz/fuzz_transforms.py @@ -9,7 +9,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import FuzzedDataProvider @@ -54,7 +54,7 @@ ] -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 4: return @@ -84,5 +84,5 @@ def TestOneInput(data): return -if __name__ == "__main__": +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/fuzz/fuzz_traverser.py b/tests/fuzz/fuzz_traverser.py index 64ed65e..b076fd1 100755 --- a/tests/fuzz/fuzz_traverser.py +++ b/tests/fuzz/fuzz_traverser.py @@ -9,7 +9,7 @@ import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) from conftest_fuzz import SAFE_EXCEPTIONS from conftest_fuzz import FuzzedDataProvider @@ -25,7 +25,7 @@ MAX_VISITED = 10_000 -def TestOneInput(data): +def TestOneInput(data: bytes) -> None: if len(data) < 8: return @@ -36,55 +36,53 @@ def TestOneInput(data): visited = 0 - if mode == 0: - # traverse with enter that sometimes returns SKIP - action_byte = remaining[0] if remaining else 0 - - def enter(node, parent, key, index): - nonlocal visited - visited += 1 - if visited > MAX_VISITED: - return SKIP - if isinstance(node, dict) and node.get("type") == "Literal" and action_byte % 3 == 0: - return SKIP - return None - - try: - traverse(ast, {"enter": enter}) - except SAFE_EXCEPTIONS: - return - - elif mode == 1: - # traverse with enter that returns REMOVE for some nodes - action_byte = remaining[1] if len(remaining) > 1 else 0 - - def enter(node, parent, key, index): - nonlocal visited - visited += 1 - if visited > MAX_VISITED: - return SKIP - if isinstance(node, dict) and node.get("type") == "EmptyStatement" and action_byte % 2 == 0: - return REMOVE - return None - - try: - traverse(ast, {"enter": enter}) - except SAFE_EXCEPTIONS: - return - - elif mode == 2: - # simple_traverse - just visit all nodes - def callback(node, parent): - nonlocal visited - visited += 1 - if visited > MAX_VISITED: - raise StopIteration("too many nodes") - - try: - simple_traverse(ast, callback) - except (StopIteration, SAFE_EXCEPTIONS[0]): - return - - -if __name__ == "__main__": + match mode: + case 0: + action_byte = remaining[0] if remaining else 0 + + def enter(node, parent, key, index): + nonlocal visited + visited += 1 + if visited > MAX_VISITED: + return SKIP + if isinstance(node, dict) and node.get('type') == 'Literal' and action_byte % 3 == 0: + return SKIP + return None + + try: + traverse(ast, {'enter': enter}) + except SAFE_EXCEPTIONS: + return + + case 1: + action_byte = remaining[1] if len(remaining) > 1 else 0 + + def enter(node, parent, key, index): + nonlocal visited + visited += 1 + if visited > MAX_VISITED: + return SKIP + if isinstance(node, dict) and node.get('type') == 'EmptyStatement' and action_byte % 2 == 0: + return REMOVE + return None + + try: + traverse(ast, {'enter': enter}) + except SAFE_EXCEPTIONS: + return + + case 2: + def callback(node, parent): + nonlocal visited + visited += 1 + if visited > MAX_VISITED: + raise StopIteration('too many nodes') + + try: + simple_traverse(ast, callback) + except (StopIteration, SAFE_EXCEPTIONS[0]): + return + + +if __name__ == '__main__': run_fuzzer(TestOneInput) diff --git a/tests/test_regression.py b/tests/test_regression.py index 304abed..6fb8703 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -13,7 +13,6 @@ - Cross-cutting quality invariants """ -import os import re from pathlib import Path @@ -31,17 +30,17 @@ RE_HEX_NUMERIC = re.compile(r'\b0x[0-9a-fA-F]+\b') -def _deobfuscate(filename): +def _deobfuscate(filename: str) -> tuple[str, str]: code = (SAMPLES_DIR / filename).read_text() result = pyjsclear.deobfuscate(code) return code, result -def _count_0x(text): +def _count_0x(text: str) -> int: return len(RE_0X.findall(text)) -def _count_hex(text): +def _count_hex(text: str) -> int: return len(RE_HEX.findall(text)) @@ -176,7 +175,7 @@ def test_animatedflatlist_property_simplifier(self): assert result != code, 'PropertySimplifier should transform the code' assert len(result) < len(code), 'Output should be smaller (bracket -> dot)' # String decode should NOT fire — the array literal should still be present - assert "createAnimatedComponent" in result + assert 'createAnimatedComponent' in result def test_animatedimage_2element_array_no_decode(self): """AnimatedImage: 2-element array — below Strategy 2b threshold. @@ -187,8 +186,8 @@ def test_animatedimage_2element_array_no_decode(self): code, result = _deobfuscate('AnimatedImage-obfuscated.js') assert result != code, 'PropertySimplifier should still fire' # String decode should NOT fire — the array literal should still be present - assert "createAnimatedComponent" in result - assert "exports" in result + assert 'createAnimatedComponent' in result + assert 'exports' in result # ================================================================ @@ -525,7 +524,7 @@ def test_multiple_decoders_clean_output(self): # ================================================================ -def _deobfuscate_resource(filename): +def _deobfuscate_resource(filename: str) -> tuple[str, str]: """Load and deobfuscate a file from tests/resources/.""" code = (RESOURCES_DIR / filename).read_text() result = pyjsclear.deobfuscate(code) @@ -579,7 +578,7 @@ def test_no_extremely_long_lines(self, sample_result): def test_few_long_lines(self, sample_result): _, result = sample_result - long_lines = sum(1 for l in result.splitlines() if len(l) > 500) + long_lines = sum(1 for line in result.splitlines() if len(line) > 500) assert long_lines <= 5, f'{long_lines} lines > 500 chars' @@ -698,7 +697,7 @@ class TestSampleRegressionGuards: def test_no_proxy_inliner_blowup(self, sample_result): """OptionalChaining + ProxyFunctionInliner interaction guard.""" _, result = sample_result - max_line_len = max(len(l) for l in result.splitlines()) + max_line_len = max(len(line) for line in result.splitlines()) assert max_line_len < 2000, f'Max line length {max_line_len} suggests proxy inliner blowup' def test_helper_functions_preserved(self, sample_result): @@ -756,19 +755,19 @@ class TestQualityInvariants: def test_no_empty_output(self): """No sample should produce empty output.""" - for f in SAMPLES_DIR.glob('*.js'): - code = f.read_text() + for sample_file in SAMPLES_DIR.glob('*.js'): + code = sample_file.read_text() result = pyjsclear.deobfuscate(code) - assert len(result.strip()) > 0, f'{f.name} produced empty output' + assert len(result.strip()) > 0, f'{sample_file.name} produced empty output' def test_no_hex_increase(self): """Deobfuscation should never introduce new hex escapes.""" - for f in SAMPLES_DIR.glob('*.js'): - code = f.read_text() + for sample_file in SAMPLES_DIR.glob('*.js'): + code = sample_file.read_text() result = pyjsclear.deobfuscate(code) assert _count_hex(result) <= _count_hex( code - ), f'{f.name}: hex escapes increased from {_count_hex(code)} to {_count_hex(result)}' + ), f'{sample_file.name}: hex escapes increased from {_count_hex(code)} to {_count_hex(result)}' def test_output_not_larger_than_input(self): """Deobfuscated output should never be larger than the input. @@ -776,11 +775,11 @@ def test_output_not_larger_than_input(self): Deobfuscation removes string arrays, dead code, and infrastructure. If output grows, something is wrong (e.g., proxy inlining blowup). """ - for f in SAMPLES_DIR.glob('*.js'): - code = f.read_text() + for sample_file in SAMPLES_DIR.glob('*.js'): + code = sample_file.read_text() result = pyjsclear.deobfuscate(code) assert len(result) <= len(code) * 1.1, ( - f'{f.name}: output ({len(result)}) > 110% of input ({len(code)}). ' + f'{sample_file.name}: output ({len(result)}) > 110% of input ({len(code)}). ' f'Ratio: {len(result)/len(code):.2f}' ) @@ -790,10 +789,8 @@ def test_output_parseable(self): If the output doesn't parse, a transform likely corrupted the AST. We skip files whose input doesn't parse (ES modules with import). """ - from pyjsclear.parser import parse - - for f in SAMPLES_DIR.glob('*.js'): - code = f.read_text() + for sample_file in SAMPLES_DIR.glob('*.js'): + code = sample_file.read_text() # Skip files that don't parse as input (ES modules) try: parse(code) @@ -802,8 +799,8 @@ def test_output_parseable(self): result = pyjsclear.deobfuscate(code) try: parse(result) - except SyntaxError as e: - pytest.fail(f'{f.name}: output does not parse: {e}') + except SyntaxError as error: + pytest.fail(f'{sample_file.name}: output does not parse: {error}') def test_no_extremely_long_lines(self): """No output line should exceed 5000 chars. @@ -812,10 +809,10 @@ def test_no_extremely_long_lines(self): or other expansion bugs. The limit is generous (5000) to accommodate files with legitimately long array literals. """ - for f in SAMPLES_DIR.glob('*.js'): - code = f.read_text() + for sample_file in SAMPLES_DIR.glob('*.js'): + code = sample_file.read_text() result = pyjsclear.deobfuscate(code) for i, line in enumerate(result.splitlines(), 1): assert len(line) <= 5000, ( - f'{f.name} line {i}: {len(line)} chars (max 5000). ' f'Preview: {line[:80]}...' + f'{sample_file.name} line {i}: {len(line)} chars (max 5000). ' f'Preview: {line[:80]}...' ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2b16919..d188766 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,10 +1,12 @@ """Shared test helpers for pyjsclear unit tests.""" +from typing import Any + from pyjsclear.generator import generate from pyjsclear.parser import parse -def roundtrip(js_code, transform_class): +def roundtrip(js_code: str, transform_class: type) -> tuple[str, bool]: """Parse JS, apply a transform, return (generated_code, changed).""" ast = parse(js_code) t = transform_class(ast) @@ -12,12 +14,12 @@ def roundtrip(js_code, transform_class): return generate(ast), changed -def parse_expr(js_expr): +def parse_expr(js_expr: str) -> dict[str, Any]: """Parse a JS expression and return the expression AST node.""" ast = parse(js_expr + ';') return ast['body'][0]['expression'] -def normalize(js_code): +def normalize(js_code: str) -> str: """Collapse all whitespace to single spaces for comparison.""" return ' '.join(js_code.split()) diff --git a/tests/unit/deobfuscator_test.py b/tests/unit/deobfuscator_test.py index c5e4bfe..3c4e9d5 100644 --- a/tests/unit/deobfuscator_test.py +++ b/tests/unit/deobfuscator_test.py @@ -10,6 +10,9 @@ from pyjsclear.deobfuscator import TRANSFORM_CLASSES from pyjsclear.deobfuscator import Deobfuscator from pyjsclear.deobfuscator import _count_nodes +from pyjsclear.transforms.control_flow import ControlFlowRecoverer +from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner +from pyjsclear.transforms.string_revealer import StringRevealer class TestTransformClasses: @@ -20,14 +23,10 @@ def test_transform_classes_length(self): assert len(TRANSFORM_CLASSES) == 37 def test_string_revealer_appears_twice(self): - from pyjsclear.transforms.string_revealer import StringRevealer - occurrences = [cls for cls in TRANSFORM_CLASSES if cls is StringRevealer] assert len(occurrences) == 2 def test_string_revealer_is_first_and_last(self): - from pyjsclear.transforms.string_revealer import StringRevealer - assert TRANSFORM_CLASSES[0] is StringRevealer assert TRANSFORM_CLASSES[-1] is StringRevealer @@ -261,8 +260,6 @@ def test_lite_mode_strips_expensive_transforms(self, mock_transforms, mock_gener mock_ast = MagicMock() mock_parse.return_value = mock_ast - from pyjsclear.transforms.control_flow import ControlFlowRecoverer - # One cheap transform that changes, one expensive that should be skipped cheap_instance = MagicMock() cheap_instance.execute.return_value = False @@ -292,8 +289,6 @@ def test_large_node_count_strips_expensive_transforms(self, mock_transforms, moc cheap_instance.execute.return_value = False cheap_transform = MagicMock(return_value=cheap_instance) - from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner - mock_transforms.__iter__ = lambda self: iter([cheap_transform, ProxyFunctionInliner]) # Code > _LARGE_FILE_SIZE to trigger node counting, but < _MAX_CODE_SIZE (not lite mode) diff --git a/tests/unit/generator_test.py b/tests/unit/generator_test.py index 79a6e75..17f35a4 100644 --- a/tests/unit/generator_test.py +++ b/tests/unit/generator_test.py @@ -2,6 +2,9 @@ import pytest +from pyjsclear.generator import _expr_precedence +from pyjsclear.generator import _gen_property +from pyjsclear.generator import _gen_stmt from pyjsclear.generator import generate from pyjsclear.parser import parse @@ -1497,8 +1500,6 @@ class TestGenStmtNoneNode: """Line 113: _gen_stmt with None node returns ''.""" def test_gen_stmt_none(self): - from pyjsclear.generator import _gen_stmt - assert _gen_stmt(None, 0) == '' @@ -1509,8 +1510,6 @@ def test_statement_ending_with_semicolon(self): # EmptyStatement generates ';' and is in _NO_SEMI_TYPES, # but we can construct a node whose generate() output ends with ';' # that is NOT in _NO_SEMI_TYPES. Use a manual approach. - from pyjsclear.generator import _gen_stmt - # Create a fake node type that generates code ending with ';' # A VariableDeclaration ending with ';' (by appending manually) # Actually, let's just test the path: _gen_stmt appends ';' only if code doesn't already end with it @@ -1625,8 +1624,6 @@ class TestRestElementInProperty: def test_gen_property_rest_element(self): # _gen_property generates key: value for a Property node - from pyjsclear.generator import _gen_property - node = { 'type': 'Property', 'key': {'type': 'Identifier', 'name': 'a'}, @@ -1822,8 +1819,6 @@ def test_assignment_expression_precedence(self): def test_yield_expression_precedence_in_binary(self): """YieldExpression has precedence 2, needs parens in binary context.""" - from pyjsclear.generator import _expr_precedence - yield_node = {'type': 'YieldExpression', 'argument': _id('x'), 'delegate': False} assert _expr_precedence(yield_node) == 2 @@ -1849,19 +1844,13 @@ def test_nested_precedence_conditional_in_assignment(self): assert 'x = a ? b : c' in result def test_arrow_function_precedence(self): - from pyjsclear.generator import _expr_precedence - arrow_node = {'type': 'ArrowFunctionExpression'} assert _expr_precedence(arrow_node) == 3 def test_unknown_type_precedence(self): - from pyjsclear.generator import _expr_precedence - node = {'type': 'SomeUnknownExpression'} assert _expr_precedence(node) == 0 def test_non_dict_precedence(self): - from pyjsclear.generator import _expr_precedence - assert _expr_precedence(42) == 20 assert _expr_precedence('str') == 20 diff --git a/tests/unit/transforms/aa_decode_test.py b/tests/unit/transforms/aa_decode_test.py index 8404858..24c0f52 100644 --- a/tests/unit/transforms/aa_decode_test.py +++ b/tests/unit/transforms/aa_decode_test.py @@ -2,8 +2,8 @@ import pytest -from pyjsclear.transforms.aa_decode import is_aa_encoded from pyjsclear.transforms.aa_decode import aa_decode +from pyjsclear.transforms.aa_decode import is_aa_encoded class TestIsAAEncoded: @@ -42,12 +42,12 @@ def test_synthetic_simple(self): """Synthetic AAEncode for 'Hi' (H=110 octal, i=151 octal). This builds a minimal AAEncoded payload that the decoder can parse. - Note: real AAEncode uses U+FF70 (\uff70 halfwidth), NOT U+30FC (fullwidth). + Note: real AAEncode uses U+FF70 (\\uff70 halfwidth), NOT U+30FC (fullwidth). """ # H = 0x48 = 110 octal, i = 0x69 = 151 octal - # Digit 1 = (\uff9f\uff70\uff9f), Digit 0 = (c^_^o), - # Digit 5 = ((\uff9f\uff70\uff9f) + (\uff9f\uff70\uff9f) + (\uff9f\u0398\uff9f)) - sep = '(\uff9f\u0414\uff9f)[\uff9f\u03b5\uff9f]+' + # Digit 1 = (\\uff9f\\uff70\\uff9f), Digit 0 = (c^_^o), + # Digit 5 = ((\\uff9f\\uff70\\uff9f) + (\\uff9f\\uff70\\uff9f) + (\\uff9f\\u0398\\uff9f)) + separator = '(\uff9f\u0414\uff9f)[\uff9f\u03b5\uff9f]+' h_digits = '(\uff9f\uff70\uff9f)+(\uff9f\uff70\uff9f)+(c^_^o)' # 1 1 0 i_digits = ( '(\uff9f\uff70\uff9f)+' @@ -55,7 +55,7 @@ def test_synthetic_simple(self): '(\uff9f\uff70\uff9f)' ) # 1 5 1 - data = sep + h_digits + sep + i_digits + data = separator + h_digits + separator + i_digits # Add execution wrapper with the signature code = data + "(\uff9f\u0414\uff9f)['_'](\uff9f\u0398\uff9f)" diff --git a/tests/unit/transforms/anti_tamper_test.py b/tests/unit/transforms/anti_tamper_test.py index ed51f40..06c19b6 100644 --- a/tests/unit/transforms/anti_tamper_test.py +++ b/tests/unit/transforms/anti_tamper_test.py @@ -2,6 +2,7 @@ import pytest +from pyjsclear.parser import parse from pyjsclear.transforms.anti_tamper import AntiTamperRemover from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -101,10 +102,10 @@ def test_multiple_statements_only_anti_tamper_removed(self): code = 'var a = 1;(function() { console["log"] = function(){}; })();var b = 2;' result, changed = roundtrip(code, AntiTamperRemover) assert changed is True - norm = normalize(result) - assert 'var a = 1' in norm - assert 'var b = 2' in norm - assert 'console' not in norm + normalized = normalize(result) + assert 'var a = 1' in normalized + assert 'var b = 2' in normalized + assert 'console' not in normalized class TestAntiTamperRemoverEdgeCases: @@ -139,8 +140,7 @@ def test_debug_protection_setInterval(self): def test_exception_during_generate(self): """Lines 87-88: Exception during generate() should be caught gracefully.""" - # This is hard to trigger directly via roundtrip since we'd need a malformed AST. - # We test indirectly that normal IIFE processing doesn't crash. + # Hard to trigger directly; test that normal IIFE processing doesn't crash code = '(function() { var x = 1; var y = 2; })();' result, changed = roundtrip(code, AntiTamperRemover) assert changed is False @@ -154,32 +154,26 @@ def test_debugger_with_setInterval_removed(self): def test_expression_statement_no_expression(self): """Line 70: ExpressionStatement with expression set to None.""" - from pyjsclear.parser import parse - ast = parse('a();') # Manually set expression to None to trigger early return ast['body'][0]['expression'] = None - t = AntiTamperRemover(ast) - changed = t.execute() + transform = AntiTamperRemover(ast) + changed = transform.execute() assert changed is False - def test_call_without_callee(self): + def test_call_without_callee_ast(self): """Line 78: Call node without callee.""" - from pyjsclear.parser import parse - ast = parse('(function() { x(); })();') # Find the outer CallExpression and remove its callee call = ast['body'][0]['expression'] if call.get('type') == 'CallExpression': call['callee'] = None - t = AntiTamperRemover(ast) - changed = t.execute() + transform = AntiTamperRemover(ast) + changed = transform.execute() assert changed is False def test_exception_during_generate_malformed_callee(self): """Lines 87-88: Exception during generate() with malformed AST.""" - from pyjsclear.parser import parse - ast = parse('(function() { x(); })();') # Find the IIFE callee (FunctionExpression) and corrupt its body call = ast['body'][0]['expression'] @@ -188,7 +182,7 @@ def test_exception_during_generate_malformed_callee(self): if callee and callee.get('type') == 'FunctionExpression': # Corrupt the body to make generate() raise callee['body'] = 'not_a_valid_body' - t = AntiTamperRemover(ast) - changed = t.execute() + transform = AntiTamperRemover(ast) + changed = transform.execute() # Should not crash, just skip the node assert changed is False diff --git a/tests/unit/transforms/base_test.py b/tests/unit/transforms/base_test.py index 4668bfe..5752f3b 100644 --- a/tests/unit/transforms/base_test.py +++ b/tests/unit/transforms/base_test.py @@ -6,50 +6,50 @@ class TestTransformInit: def test_stores_ast(self): ast = {'type': 'Program'} - t = Transform(ast) - assert t.ast is ast + transform = Transform(ast) + assert transform.ast is ast def test_stores_scope_tree(self): scope_tree = {'root': True} - t = Transform('ast', scope_tree=scope_tree) - assert t.scope_tree is scope_tree + transform = Transform('ast', scope_tree=scope_tree) + assert transform.scope_tree is scope_tree def test_stores_node_scope(self): node_scope = {'node': 'scope'} - t = Transform('ast', node_scope=node_scope) - assert t.node_scope is node_scope + transform = Transform('ast', node_scope=node_scope) + assert transform.node_scope is node_scope def test_scope_tree_defaults_to_none(self): - t = Transform('ast') - assert t.scope_tree is None + transform = Transform('ast') + assert transform.scope_tree is None def test_node_scope_defaults_to_none(self): - t = Transform('ast') - assert t.node_scope is None + transform = Transform('ast') + assert transform.node_scope is None class TestTransformExecute: def test_raises_not_implemented(self): - t = Transform('ast') + transform = Transform('ast') with pytest.raises(NotImplementedError): - t.execute() + transform.execute() class TestTransformChangedTracking: def test_has_changed_initially_false(self): - t = Transform('ast') - assert t.has_changed() is False + transform = Transform('ast') + assert transform.has_changed() is False def test_set_changed_makes_has_changed_true(self): - t = Transform('ast') - t.set_changed() - assert t.has_changed() is True + transform = Transform('ast') + transform.set_changed() + assert transform.has_changed() is True def test_set_changed_is_idempotent(self): - t = Transform('ast') - t.set_changed() - t.set_changed() - assert t.has_changed() is True + transform = Transform('ast') + transform.set_changed() + transform.set_changed() + assert transform.has_changed() is True class TestTransformRebuildScope: @@ -57,8 +57,8 @@ def test_class_default_is_false(self): assert Transform.rebuild_scope is False def test_instance_inherits_default(self): - t = Transform('ast') - assert t.rebuild_scope is False + transform = Transform('ast') + assert transform.rebuild_scope is False def test_subclass_can_override(self): class MyTransform(Transform): @@ -68,5 +68,5 @@ def execute(self): pass assert MyTransform.rebuild_scope is True - t = MyTransform('ast') - assert t.rebuild_scope is True + transform = MyTransform('ast') + assert transform.rebuild_scope is True diff --git a/tests/unit/transforms/class_string_decoder_test.py b/tests/unit/transforms/class_string_decoder_test.py index e2bc2e4..4926283 100644 --- a/tests/unit/transforms/class_string_decoder_test.py +++ b/tests/unit/transforms/class_string_decoder_test.py @@ -1,6 +1,8 @@ """Tests for the ClassStringDecoder transform.""" from pyjsclear.transforms.class_string_decoder import ClassStringDecoder +from pyjsclear.transforms.dead_class_props import DeadClassPropRemover +from pyjsclear.utils.ast_helpers import get_member_names from tests.unit.conftest import roundtrip @@ -26,7 +28,7 @@ class TestClassPropCollection: def test_string_prop_collected(self): """String assignments on class vars should be tracked internally.""" - # This is an indirect test — if props aren't collected, resolution won't work + # Indirect test — if props aren't collected, resolution won't work code = ''' var Cls = class {}; Cls.p1 = "foo"; @@ -41,15 +43,11 @@ class TestDeadClassPropRemover: """Tests for the DeadClassPropRemover transform.""" def test_no_classes_returns_false(self): - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - result, changed = roundtrip('var x = 1;', DeadClassPropRemover) assert changed is False def test_dead_prop_removed(self): """Property written but never read should be removed.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls = class {}; Cls.deadProp = "never_used"; @@ -60,8 +58,6 @@ def test_dead_prop_removed(self): def test_read_prop_preserved(self): """Property that is read should NOT be removed.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls = class {}; Cls.liveProp = "used"; @@ -73,8 +69,6 @@ def test_read_prop_preserved(self): def test_fully_dead_class(self): """Class that is only used via property assignments — all props dead.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls = class {}; Cls.a = "x"; @@ -87,8 +81,6 @@ def test_fully_dead_class(self): def test_assignment_class_detected(self): """Class assigned via `X = class {}` (not var) should be detected.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls; Cls = class {}; @@ -100,8 +92,6 @@ def test_assignment_class_detected(self): def test_sequence_expression_dead_props(self): """Dead props inside SequenceExpression should be stripped.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls = class {}; Cls.dead1 = "a", Cls.dead2 = "b"; @@ -113,8 +103,6 @@ def test_sequence_expression_dead_props(self): def test_sequence_expression_partial_removal(self): """Only dead props removed from a sequence; live ones kept.""" - from pyjsclear.transforms.dead_class_props import DeadClassPropRemover - code = ''' var Cls = class {}; Cls.dead = "a", Cls.live = "b"; @@ -130,51 +118,41 @@ class TestClassStringDecoderHelpers: """Tests for ClassStringDecoder helper functions.""" def test_get_member_names_computed_string(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'obj'}, 'property': {'type': 'Literal', 'value': 'prop'}, 'computed': True, } - assert _get_member_names(node) == ('obj', 'prop') + assert get_member_names(node) == ('obj', 'prop') def test_get_member_names_non_identifier_prop(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'obj'}, 'property': {'type': 'Literal', 'value': 42}, 'computed': True, } - assert _get_member_names(node) == (None, None) + assert get_member_names(node) == (None, None) def test_get_member_names_dot_notation(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'obj'}, 'property': {'type': 'Identifier', 'name': 'prop'}, 'computed': False, } - assert _get_member_names(node) == ('obj', 'prop') + assert get_member_names(node) == ('obj', 'prop') def test_get_member_names_no_prop(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - - assert _get_member_names(None) == (None, None) - assert _get_member_names({'type': 'Literal'}) == (None, None) + assert get_member_names(None) == (None, None) + assert get_member_names({'type': 'Literal'}) == (None, None) def test_get_member_names_non_identifier_object(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Literal', 'value': 1}, 'property': {'type': 'Identifier', 'name': 'prop'}, 'computed': False, } - assert _get_member_names(node) == (None, None) + assert get_member_names(node) == (None, None) diff --git a/tests/unit/transforms/constant_prop_test.py b/tests/unit/transforms/constant_prop_test.py index 6be27d8..a132f68 100644 --- a/tests/unit/transforms/constant_prop_test.py +++ b/tests/unit/transforms/constant_prop_test.py @@ -3,6 +3,7 @@ import pytest from pyjsclear.transforms.constant_prop import ConstantProp +from pyjsclear.transforms.constant_prop import _should_skip_reference from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -60,8 +61,7 @@ def test_skip_update_target(self): """Line 20: Reference used as update target should be skipped.""" code = 'var x = 5; x++;' result, changed = roundtrip(code, ConstantProp) - # x++ should not become 5++ - # x has writes (x++) so it's not constant, so no propagation + # x++ should not become 5++; x has writes so it's not constant assert changed is False def test_skip_declarator_id(self): @@ -112,42 +112,30 @@ class TestConstantPropSkipReferenceEdgeCases: def test_ref_parent_is_none(self): """Line 15: ref_parent is None → return True (skip).""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - assert _should_skip_reference(None, 'left') is True def test_update_expression_parent(self): """Line 20: UpdateExpression parent → return True.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - parent = {'type': 'UpdateExpression', 'operator': '++', 'argument': {}} assert _should_skip_reference(parent, 'argument') is True def test_variable_declarator_id(self): """Line 22: VariableDeclarator id parent → return True.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - parent = {'type': 'VariableDeclarator', 'id': {}, 'init': None} assert _should_skip_reference(parent, 'id') is True def test_variable_declarator_init_not_skipped(self): """VariableDeclarator with key='init' should NOT be skipped.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - parent = {'type': 'VariableDeclarator', 'id': {}, 'init': {}} assert _should_skip_reference(parent, 'init') is False def test_assignment_expression_right_not_skipped(self): """AssignmentExpression with key='right' should NOT be skipped.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - parent = {'type': 'AssignmentExpression', 'operator': '=', 'left': {}, 'right': {}} assert _should_skip_reference(parent, 'right') is False def test_normal_parent_not_skipped(self): """Normal parent (e.g., CallExpression) should NOT be skipped.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - parent = {'type': 'CallExpression', 'callee': {}, 'arguments': []} assert _should_skip_reference(parent, 'arguments') is False @@ -163,10 +151,7 @@ def test_binding_with_assignments_not_removed(self): assert changed is False def test_non_dict_decl_node_skip(self): - """Line 82: decl_node not a dict — should be skipped.""" - from pyjsclear.transforms.constant_prop import _should_skip_reference - - # This is a defensive check; test it doesn't crash in normal flow + """Line 82: decl_node not a dict — defensive check doesn't crash.""" code = 'const a = 1; y(a);' result, changed = roundtrip(code, ConstantProp) assert changed is True diff --git a/tests/unit/transforms/control_flow_test.py b/tests/unit/transforms/control_flow_test.py index e3f3fac..247c1c2 100644 --- a/tests/unit/transforms/control_flow_test.py +++ b/tests/unit/transforms/control_flow_test.py @@ -7,8 +7,8 @@ from tests.unit.conftest import roundtrip -def rt(js_code): - """Shorthand roundtrip for ControlFlowRecoverer.""" +def roundtrip_cff(js_code: str) -> tuple[str, bool]: + """Run a roundtrip through ControlFlowRecoverer and normalize output.""" code, changed = roundtrip(js_code, ControlFlowRecoverer) return normalize(code), changed @@ -161,50 +161,50 @@ class TestIsSplitCall: """Test the _is_split_call detection method.""" def setup_method(self): - self.t = ControlFlowRecoverer(_program([])) + self.transform = ControlFlowRecoverer(_program([])) def test_valid_split_call(self): node = _split_call('1|0|3|2') - assert self.t._is_split_call(node) is True + assert self.transform._is_split_call(node) is True def test_non_dict_returns_false(self): - assert self.t._is_split_call(None) is False - assert self.t._is_split_call('string') is False + assert self.transform._is_split_call(None) is False + assert self.transform._is_split_call('string') is False def test_non_call_expression(self): - assert self.t._is_split_call({'type': 'Identifier', 'name': 'x'}) is False + assert self.transform._is_split_call({'type': 'Identifier', 'name': 'x'}) is False def test_callee_not_member_expression(self): node = _call_expr(_identifier('split'), [_literal('|')]) - assert self.t._is_split_call(node) is False + assert self.transform._is_split_call(node) is False def test_object_not_string_literal(self): node = _call_expr( callee=_member_expr(_identifier('arr'), _identifier('split')), arguments=[_literal('|')], ) - assert self.t._is_split_call(node) is False + assert self.transform._is_split_call(node) is False def test_property_not_split(self): node = _call_expr( callee=_member_expr(_literal('1|2'), _identifier('join')), arguments=[_literal('|')], ) - assert self.t._is_split_call(node) is False + assert self.transform._is_split_call(node) is False def test_no_arguments(self): node = _call_expr( callee=_member_expr(_literal('1|2'), _identifier('split')), arguments=[], ) - assert self.t._is_split_call(node) is False + assert self.transform._is_split_call(node) is False def test_argument_not_string(self): node = _call_expr( callee=_member_expr(_literal('1|2'), _identifier('split')), arguments=[_literal(1)], ) - assert self.t._is_split_call(node) is False + assert self.transform._is_split_call(node) is False # --------------------------------------------------------------------------- @@ -216,23 +216,23 @@ class TestExtractSplitStates: """Test the _extract_split_states method.""" def setup_method(self): - self.t = ControlFlowRecoverer(_program([])) + self.transform = ControlFlowRecoverer(_program([])) def test_basic_extraction(self): node = _split_call('1|0|3|2') - assert self.t._extract_split_states(node) == ['1', '0', '3', '2'] + assert self.transform._extract_split_states(node) == ['1', '0', '3', '2'] def test_single_state(self): node = _split_call('0') - assert self.t._extract_split_states(node) == ['0'] + assert self.transform._extract_split_states(node) == ['0'] def test_five_states(self): node = _split_call('4|2|0|1|3') - assert self.t._extract_split_states(node) == ['4', '2', '0', '1', '3'] + assert self.transform._extract_split_states(node) == ['4', '2', '0', '1', '3'] def test_custom_separator(self): node = _split_call('a-b-c', separator='-') - assert self.t._extract_split_states(node) == ['a', 'b', 'c'] + assert self.transform._extract_split_states(node) == ['a', 'b', 'c'] # --------------------------------------------------------------------------- @@ -250,8 +250,8 @@ def test_two_states_reordered(self): '1': [_expr_stmt(_call_expr(_identifier('a'), []))], } ast = _make_cff_ast_var_pattern('1|0', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -270,8 +270,8 @@ def test_three_states(self): '2': [_expr_stmt(_call_expr(_identifier('c'), []))], } ast = _make_cff_ast_var_pattern('2|0|1', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -288,8 +288,8 @@ def test_sequential_order(self): '2': [_expr_stmt(_call_expr(_identifier('c'), []))], } ast = _make_cff_ast_var_pattern('0|1|2', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -304,8 +304,8 @@ def test_continue_statements_filtered(self): '0': [_expr_stmt(_call_expr(_identifier('a'), []))], } ast = _make_cff_ast_var_pattern('0', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - t.execute() + transform = ControlFlowRecoverer(ast) + transform.execute() body = ast['body'] for stmt in body: @@ -318,8 +318,8 @@ def test_return_statement_preserved(self): '1': [_return_stmt(_literal(42))], } ast = _make_cff_ast_var_pattern('0|1', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -344,8 +344,8 @@ def test_expression_pattern_two_states(self): '1': [_expr_stmt(_call_expr(_identifier('a'), []))], } ast = _make_cff_ast_expr_pattern('1|0', '_a', '_i', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -360,8 +360,8 @@ def test_expression_pattern_three_states(self): '2': [_expr_stmt(_call_expr(_identifier('z'), []))], } ast = _make_cff_ast_expr_pattern('2|1|0', '_arr', '_idx', cases) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is True body = ast['body'] @@ -386,8 +386,8 @@ def test_plain_statements_unchanged(self): _expr_stmt(_call_expr(_identifier('bar'), [])), ] ) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False assert len(ast['body']) == 2 @@ -395,16 +395,16 @@ def test_plain_statements_unchanged(self): def test_while_without_switch_unchanged(self): loop = _while_true([_expr_stmt(_call_expr(_identifier('doStuff'), []))]) ast = _program([loop]) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False def test_var_decl_without_split_unchanged(self): decl = _var_declaration([_var_declarator('x', _literal(10))]) ast = _program([decl]) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False @@ -424,7 +424,7 @@ def test_basic_cff_roundtrip(self): ' switch (_a[_i++]) { case "0": b(); continue; case "1": a(); continue; }' ' break; }' ) - code, changed = rt(js) + code, changed = roundtrip_cff(js) assert changed is True assert 'a();' in code assert 'b();' in code @@ -438,7 +438,7 @@ def test_three_state_roundtrip(self): ' switch (_s[_c++]) { case "0": first(); continue; case "1": second(); continue; case "2": third(); continue; }' ' break; }' ) - code, changed = rt(js) + code, changed = roundtrip_cff(js) assert changed is True # Order: "2|0|1" => third, first, second assert code.index('third()') < code.index('first()') @@ -495,35 +495,35 @@ class TestIsTruthy: """Test the _is_truthy helper method.""" def setup_method(self): - self.t = ControlFlowRecoverer(_program([])) + self.transform = ControlFlowRecoverer(_program([])) def test_literal_true(self): - assert self.t._is_truthy(_literal(True)) is True + assert self.transform._is_truthy(_literal(True)) is True def test_literal_1(self): - assert self.t._is_truthy(_literal(1)) is True + assert self.transform._is_truthy(_literal(1)) is True def test_literal_false(self): - assert self.t._is_truthy(_literal(False)) is False + assert self.transform._is_truthy(_literal(False)) is False def test_literal_0(self): - assert self.t._is_truthy(_literal(0)) is False + assert self.transform._is_truthy(_literal(0)) is False def test_not_zero_is_truthy(self): """!0 should be recognized as truthy.""" node = {'type': 'UnaryExpression', 'operator': '!', 'argument': _literal(0), 'prefix': True} - assert self.t._is_truthy(node) is True + assert self.transform._is_truthy(node) is True def test_not_dict(self): - assert self.t._is_truthy(None) is False - assert self.t._is_truthy('string') is False + assert self.transform._is_truthy(None) is False + assert self.transform._is_truthy('string') is False def test_double_not_array_is_truthy(self): """!![] should be truthy.""" inner = {'type': 'ArrayExpression', 'elements': []} not_inner = {'type': 'UnaryExpression', 'operator': '!', 'argument': inner, 'prefix': True} double_not = {'type': 'UnaryExpression', 'operator': '!', 'argument': not_inner, 'prefix': True} - assert self.t._is_truthy(double_not) is True + assert self.transform._is_truthy(double_not) is True # --------------------------------------------------------------------------- @@ -537,8 +537,8 @@ class TestNonDictInBody: def test_non_dict_in_body_skipped(self): """Non-dict items in body should be skipped without crashing.""" ast = _program([None, 'not_a_dict', _expr_stmt(_call_expr(_identifier('a'), []))]) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False @@ -553,8 +553,8 @@ def test_expression_pattern_no_split(self): _while_true([_break_stmt()]), ] ) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False def test_expression_pattern_missing_loop(self): @@ -563,8 +563,8 @@ def test_expression_pattern_missing_loop(self): counter_stmt = _expr_stmt(_assignment('_i', _literal(0))) # No loop after counter, just ends ast = _program([assign_stmt, counter_stmt]) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False @@ -572,28 +572,28 @@ class TestFindCounterInit: """Lines 175-183, 194: _find_counter_init with VariableDeclaration and ExpressionStatement.""" def setup_method(self): - self.t = ControlFlowRecoverer(_program([])) + self.transform = ControlFlowRecoverer(_program([])) def test_variable_declaration_counter(self): stmt = _var_declaration([_var_declarator('_i', _literal(0))]) - result = self.t._find_counter_init(stmt) + result = self.transform._find_counter_init(stmt) assert result == '_i' def test_expression_statement_counter(self): stmt = _expr_stmt(_assignment('_i', _literal(0))) - result = self.t._find_counter_init(stmt) + result = self.transform._find_counter_init(stmt) assert result == '_i' def test_non_numeric_init_ignored(self): stmt = _var_declaration([_var_declarator('_i', _literal('hello'))]) - result = self.t._find_counter_init(stmt) + result = self.transform._find_counter_init(stmt) assert result is None def test_non_dict_returns_none(self): - result = self.t._find_counter_init(None) + result = self.transform._find_counter_init(None) assert result is None - result = self.t._find_counter_init('not a dict') + result = self.transform._find_counter_init('not a dict') assert result is None @@ -608,7 +608,7 @@ def test_for_statement_recovery(self): ' switch (_a[_j++]) { case "0": b(); continue; case "1": a(); continue; }' ' break; }' ) - code, changed = rt(js) + code, changed = roundtrip_cff(js) assert changed is True assert 'a()' in code assert 'b()' in code @@ -655,25 +655,25 @@ class TestExtractSwitchFromLoopBody: """Lines 302, 308-310: _extract_switch_from_loop_body edge cases.""" def setup_method(self): - self.t = ControlFlowRecoverer(_program([])) + self.transform = ControlFlowRecoverer(_program([])) def test_non_block_statement(self): """Non-BlockStatement body returns None.""" - result = self.t._extract_switch_from_loop_body(_expr_stmt(_call_expr(_identifier('a'), []))) + result = self.transform._extract_switch_from_loop_body(_expr_stmt(_call_expr(_identifier('a'), []))) assert result is None def test_switch_directly_as_body(self): """SwitchStatement directly as loop body.""" switch = _switch_stmt(_identifier('x'), []) - result = self.t._extract_switch_from_loop_body(switch) + result = self.transform._extract_switch_from_loop_body(switch) assert result is not None assert result['type'] == 'SwitchStatement' def test_non_dict_body(self): - result = self.t._extract_switch_from_loop_body(None) + result = self.transform._extract_switch_from_loop_body(None) assert result is None - result = self.t._extract_switch_from_loop_body('not a dict') + result = self.transform._extract_switch_from_loop_body('not a dict') assert result is None @@ -688,7 +688,7 @@ def test_while_not_zero(self): ' switch(_a[_i++]) { case "0": b(); continue; case "1": a(); continue; }' ' break; }' ) - result, changed = rt(code) + result, changed = roundtrip_cff(code) assert changed assert 'a()' in result assert 'b()' in result @@ -701,7 +701,7 @@ def test_while_double_not_array(self): ' switch(_a[_i++]) { case "0": a(); continue; case "1": b(); continue; }' ' break; }' ) - result, changed = rt(code) + result, changed = roundtrip_cff(code) assert changed assert 'a()' in result assert 'b()' in result @@ -714,27 +714,27 @@ def test_case_with_return_in_roundtrip(self): ' while(true) { switch(_a[_i++]) { case "0": a(); continue; case "1": return b(); } break; }' ' }' ) - result, changed = rt(code) + result, changed = roundtrip_cff(code) assert changed assert 'return' in result def test_is_truthy_not_array_is_false(self): """![] is falsy (line 324).""" - t = ControlFlowRecoverer(_program([])) + transform = ControlFlowRecoverer(_program([])) node = { 'type': 'UnaryExpression', 'operator': '!', 'argument': {'type': 'ArrayExpression', 'elements': []}, 'prefix': True, } - assert t._is_truthy(node) is False + assert transform._is_truthy(node) is False def test_is_truthy_literal_non_bool(self): """Literal with non-bool truthy value (line 317).""" - t = ControlFlowRecoverer(_program([])) - assert t._is_truthy(_literal(42)) is True - assert t._is_truthy(_literal('')) is False - assert t._is_truthy(_literal('hello')) is True + transform = ControlFlowRecoverer(_program([])) + assert transform._is_truthy(_literal(42)) is True + assert transform._is_truthy(_literal('')) is False + assert transform._is_truthy(_literal('hello')) is True def test_visited_set_dedup(self): """Lines 36-37: visited set prevents re-processing the same node.""" @@ -742,8 +742,8 @@ def test_visited_set_dedup(self): shared = _expr_stmt(_call_expr(_identifier('a'), [])) block = {'type': 'BlockStatement', 'body': [shared]} ast = _program([block]) - t = ControlFlowRecoverer(ast) - changed = t.execute() + transform = ControlFlowRecoverer(ast) + changed = transform.execute() assert changed is False diff --git a/tests/unit/transforms/expression_simplifier_test.py b/tests/unit/transforms/expression_simplifier_test.py index 5469b49..f2e6fed 100644 --- a/tests/unit/transforms/expression_simplifier_test.py +++ b/tests/unit/transforms/expression_simplifier_test.py @@ -1,13 +1,16 @@ """Unit tests for ExpressionSimplifier transform.""" +import math + import pytest from pyjsclear.transforms.expression_simplifier import ExpressionSimplifier +from pyjsclear.transforms.expression_simplifier import _JS_NULL from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip -def rt(js_code): +def rt(js_code: str) -> tuple: """Shorthand roundtrip for ExpressionSimplifier.""" code, changed = roundtrip(js_code, ExpressionSimplifier) return normalize(code), changed @@ -265,7 +268,7 @@ def test_non_dict_returns_false(self): val, ok = es._get_resolvable_value(None) assert ok is False - val, ok = es._get_resolvable_value("string") + val, ok = es._get_resolvable_value('string') assert ok is False @@ -462,14 +465,10 @@ def test_string_numeric(self): assert es._js_to_number('5') == 5 def test_null_to_zero(self): - from pyjsclear.transforms.expression_simplifier import _JS_NULL - es = ExpressionSimplifier({'type': 'Program', 'body': []}) assert es._js_to_number(_JS_NULL) == 0 def test_undefined_to_nan(self): - import math - es = ExpressionSimplifier({'type': 'Program', 'body': []}) result = es._js_to_number(None) assert math.isnan(result) @@ -483,8 +482,6 @@ class TestJsToString: """Lines 343-358: _js_to_string for various types.""" def test_null_to_string(self): - from pyjsclear.transforms.expression_simplifier import _JS_NULL - es = ExpressionSimplifier({'type': 'Program', 'body': []}) assert es._js_to_string(_JS_NULL) == 'null' @@ -516,8 +513,6 @@ def test_string_comparison(self): assert es._js_compare('a', 'a') == 0 def test_nan_comparison(self): - import math - es = ExpressionSimplifier({'type': 'Program', 'body': []}) result = es._js_compare(float('nan'), 1) assert math.isnan(result) @@ -530,8 +525,6 @@ class TestValueToNode: """Lines 384-403: _value_to_node for various types.""" def test_null_value(self): - from pyjsclear.transforms.expression_simplifier import _JS_NULL - es = ExpressionSimplifier({'type': 'Program', 'body': []}) result = es._value_to_node(_JS_NULL) assert result['type'] == 'Literal' @@ -622,7 +615,6 @@ def test_unary_minus_undefined_returns_nan_no_node(self): """-undefined → NaN, which _value_to_node returns None for.""" code, changed = rt('-undefined;') # -undefined produces NaN, which can't be represented as a literal - # so it stays unchanged assert changed is False diff --git a/tests/unit/transforms/hex_numerics_test.py b/tests/unit/transforms/hex_numerics_test.py index cb797ba..1342487 100644 --- a/tests/unit/transforms/hex_numerics_test.py +++ b/tests/unit/transforms/hex_numerics_test.py @@ -1,5 +1,7 @@ """Tests for the HexNumerics transform.""" +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.hex_numerics import HexNumerics from tests.unit.conftest import roundtrip @@ -47,9 +49,6 @@ def test_float_unchanged(self): def test_no_raw_field_unchanged(self): """Numeric literal without a raw field should not be transformed.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('var x = 1;') # Manually remove raw to simulate a synthetic node for stmt in ast['body']: @@ -62,9 +61,6 @@ def test_no_raw_field_unchanged(self): def test_negative_hex_value(self): """Hex literal with negative value uses plain str() path (line 26).""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('var x = 0x1;') # Manually set the value to a negative to test the else branch decl = ast['body'][0]['declarations'][0] diff --git a/tests/unit/transforms/logical_to_if_test.py b/tests/unit/transforms/logical_to_if_test.py index 98791f5..bab9130 100644 --- a/tests/unit/transforms/logical_to_if_test.py +++ b/tests/unit/transforms/logical_to_if_test.py @@ -1,5 +1,7 @@ -import pytest +"""Tests for the LogicalToIf transform.""" +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.logical_to_if import LogicalToIf from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -150,9 +152,6 @@ def test_return_single_element_sequence(self): """Line 104: Return with single-element sequence (len <= 1) returns None.""" # Manually constructing is tricky; a single-element SequenceExpression # is unusual. We test via the AST directly. - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('function f() { return a; }') # Manually make the return argument a SequenceExpression with 1 element ret_stmt = ast['body'][0]['body']['body'][0] @@ -181,9 +180,6 @@ def test_return_logical_right_not_sequence(self): def test_return_logical_right_sequence_single_element(self): """Line 123: Return logical where right side is sequence with <=1 elements.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('function f() { return a || b; }') ret_stmt = ast['body'][0]['body']['body'][0] ret_stmt['argument'] = { @@ -201,9 +197,6 @@ def test_return_logical_right_sequence_single_element(self): def test_nullish_coalescing_not_converted(self): """Lines 147-148: _logical_to_if with unknown operator (e.g. '??') returns None.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('a ?? b();') # Esprima may not parse ?? as LogicalExpression, so force it expr_stmt = ast['body'][0] @@ -219,8 +212,6 @@ def test_nullish_coalescing_not_converted(self): def test_expression_stmt_non_dict_expression(self): """Line 75: ExpressionStatement with non-dict expression returns None.""" - from pyjsclear.parser import parse - ast = parse('a();') ast['body'][0]['expression'] = 42 t = LogicalToIf(ast) diff --git a/tests/unit/transforms/member_chain_resolver_test.py b/tests/unit/transforms/member_chain_resolver_test.py index afad3b3..1d43c72 100644 --- a/tests/unit/transforms/member_chain_resolver_test.py +++ b/tests/unit/transforms/member_chain_resolver_test.py @@ -1,6 +1,7 @@ """Tests for the MemberChainResolver transform.""" from pyjsclear.transforms.member_chain_resolver import MemberChainResolver +from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -90,13 +91,9 @@ class TestHelperFunctions: """Direct tests for _get_member_names helper.""" def test_get_member_names_none(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - assert _get_member_names(None) == (None, None) def test_get_member_names_no_prop(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'x'}, @@ -106,8 +103,6 @@ def test_get_member_names_no_prop(self): assert _get_member_names(node) == (None, None) def test_get_member_names_computed_non_string(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'x'}, @@ -117,8 +112,6 @@ def test_get_member_names_computed_non_string(self): assert _get_member_names(node) == (None, None) def test_get_member_names_non_identifier_obj(self): - from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names - node = { 'type': 'MemberExpression', 'object': {'type': 'Literal', 'value': 1}, diff --git a/tests/unit/transforms/object_packer_test.py b/tests/unit/transforms/object_packer_test.py index d9afd53..118dc1b 100644 --- a/tests/unit/transforms/object_packer_test.py +++ b/tests/unit/transforms/object_packer_test.py @@ -1,5 +1,8 @@ +"""Tests for the ObjectPacker transform.""" + import pytest +from pyjsclear.parser import parse from pyjsclear.transforms.object_packer import ObjectPacker from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -128,7 +131,6 @@ def test_object_name_mismatch(self): def test_property_node_is_none(self): """Line 87: Property node is None stops packing.""" from pyjsclear.generator import generate - from pyjsclear.parser import parse ast = parse('var o = {}; o.x = 1;') # Manually set the property of the MemberExpression to None @@ -165,8 +167,6 @@ def test_references_name_identifier_match(self): def test_non_dict_in_body_direct_ast(self): """Line 22/57: non-dict in body triggers skip in _process_bodies and _try_pack_body.""" - from pyjsclear.parser import parse - ast = parse('var o = {}; o.x = 1;') ast['body'].append(42) t = ObjectPacker(ast) diff --git a/tests/unit/transforms/object_simplifier_test.py b/tests/unit/transforms/object_simplifier_test.py index feec01a..b2ea715 100644 --- a/tests/unit/transforms/object_simplifier_test.py +++ b/tests/unit/transforms/object_simplifier_test.py @@ -1,5 +1,9 @@ +"""Tests for the ObjectSimplifier transform.""" + import pytest +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.object_simplifier import ObjectSimplifier from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -182,9 +186,6 @@ def test_recurse_into_child_scopes(self): def test_is_proxy_object_non_property_type(self): """Lines 121, 124: _is_proxy_object with non-Property type or missing value.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('const o = {x: 1}; y(o.x);') t = ObjectSimplifier(ast) # SpreadElement is not a Property @@ -194,8 +195,6 @@ def test_is_proxy_object_non_property_type(self): def test_get_property_key_no_key(self): """Lines 136: _get_property_key returns None when key is missing.""" - from pyjsclear.parser import parse - ast = parse('const o = {x: 1};') t = ObjectSimplifier(ast) assert t._get_property_key({}) is None @@ -203,8 +202,6 @@ def test_get_property_key_no_key(self): def test_computed_property_key_ignored(self): """Line 46: property key returns None for computed key (not Identifier or string Literal).""" - from pyjsclear.parser import parse - ast = parse('const o = {x: 1}; var y = o.x;') t = ObjectSimplifier(ast) # A property with a computed (non-string, non-identifier) key returns None @@ -212,8 +209,6 @@ def test_computed_property_key_ignored(self): def test_has_property_assignment_me_parent_info_none(self): """Line 97: _has_property_assignment where find_parent returns None for me.""" - from pyjsclear.parser import parse - # Build a scenario with a detached member expression ast = parse('const o = {x: 1}; var y = o.x;') t = ObjectSimplifier(ast) @@ -223,8 +218,6 @@ def test_has_property_assignment_me_parent_info_none(self): def test_try_inline_function_call_me_parent_info_none(self): """Line 107: _try_inline_function_call where find_parent returns None.""" - from pyjsclear.parser import parse - ast = parse('const o = {fn: function(a) { return a; }}; o.fn(1);') t = ObjectSimplifier(ast) changed = t.execute() @@ -232,8 +225,6 @@ def test_try_inline_function_call_me_parent_info_none(self): def test_get_member_prop_name_no_property(self): """Line 148: _get_member_prop_name with no property returns None.""" - from pyjsclear.parser import parse - ast = parse('const o = {x: 1};') t = ObjectSimplifier(ast) assert t._get_member_prop_name({}) is None @@ -241,12 +232,9 @@ def test_get_member_prop_name_no_property(self): def test_body_not_block_not_expression(self): """Line 183: body that's not BlockStatement and not expression for non-arrow.""" - from pyjsclear.parser import parse - ast = parse('const o = {fn: function(a) { return a; }}; o.fn(1);') t = ObjectSimplifier(ast) # Manually set the function body to something that is not a BlockStatement - # Find the function in the object props = ast['body'][0]['declarations'][0]['init']['properties'] func = props[0]['value'] func['body'] = {'type': 'Literal', 'value': 1} # Not a BlockStatement diff --git a/tests/unit/transforms/optional_chaining_test.py b/tests/unit/transforms/optional_chaining_test.py index b0b931c..387f739 100644 --- a/tests/unit/transforms/optional_chaining_test.py +++ b/tests/unit/transforms/optional_chaining_test.py @@ -1,5 +1,7 @@ """Tests for the OptionalChaining transform.""" +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.optional_chaining import OptionalChaining from pyjsclear.transforms.optional_chaining import _nodes_match from pyjsclear.utils.ast_helpers import is_null_literal @@ -10,7 +12,7 @@ class TestSimplePattern: """Tests for X === null || X === undefined ? undefined : X.prop → X?.prop.""" - def test_basic_member_access(self): + def test_basic_member_access(self) -> None: code, changed = roundtrip( 'var y = x === null || x === undefined ? undefined : x.foo;', OptionalChaining, @@ -18,7 +20,7 @@ def test_basic_member_access(self): assert changed is True assert 'x?.foo' in code - def test_computed_member_access(self): + def test_computed_member_access(self) -> None: code, changed = roundtrip( 'var y = x === null || x === undefined ? undefined : x["foo"];', OptionalChaining, @@ -26,7 +28,7 @@ def test_computed_member_access(self): assert changed is True assert '?.["foo"]' in code - def test_reversed_null_undefined(self): + def test_reversed_null_undefined(self) -> None: code, changed = roundtrip( 'var y = x === undefined || x === null ? undefined : x.foo;', OptionalChaining, @@ -34,7 +36,7 @@ def test_reversed_null_undefined(self): assert changed is True assert 'x?.foo' in code - def test_void_0_as_undefined(self): + def test_void_0_as_undefined(self) -> None: code, changed = roundtrip( 'var y = x === null || x === void 0 ? void 0 : x.foo;', OptionalChaining, @@ -46,7 +48,7 @@ def test_void_0_as_undefined(self): class TestTempAssignmentPattern: """Tests for (_tmp = expr) === null || _tmp === undefined ? undefined : _tmp.prop.""" - def test_temp_var_member(self): + def test_temp_var_member(self) -> None: code, changed = roundtrip( 'var y = (_tmp = obj.prop) === null || _tmp === undefined ? undefined : _tmp.nested;', OptionalChaining, @@ -54,7 +56,7 @@ def test_temp_var_member(self): assert changed is True assert 'obj.prop?.nested' in code - def test_temp_var_eliminates_temp(self): + def test_temp_var_eliminates_temp(self) -> None: """The temp variable should not appear in the output.""" code, changed = roundtrip( 'var y = (_tmp = obj.a) === null || _tmp === undefined ? undefined : _tmp.b;', @@ -67,21 +69,21 @@ def test_temp_var_eliminates_temp(self): class TestNoTransform: """Cases that should NOT trigger the transform.""" - def test_consequent_not_undefined(self): + def test_consequent_not_undefined(self) -> None: code, changed = roundtrip( 'var y = x === null || x === undefined ? 0 : x.foo;', OptionalChaining, ) assert changed is False - def test_different_variables_in_checks(self): + def test_different_variables_in_checks(self) -> None: code, changed = roundtrip( 'var y = x === null || z === undefined ? undefined : x.foo;', OptionalChaining, ) assert changed is False - def test_alternate_is_plain_identifier(self): + def test_alternate_is_plain_identifier(self) -> None: """x?. would require the alternate to be a member/call, not just x.""" code, changed = roundtrip( 'var y = x === null || x === undefined ? undefined : x;', @@ -89,7 +91,7 @@ def test_alternate_is_plain_identifier(self): ) assert changed is False - def test_and_operator_not_or(self): + def test_and_operator_not_or(self) -> None: """&& instead of || should not match (that's the nullish coalescing pattern).""" code, changed = roundtrip( 'var y = x === null && x === undefined ? undefined : x.foo;', @@ -97,7 +99,7 @@ def test_and_operator_not_or(self): ) assert changed is False - def test_not_equality_check(self): + def test_not_equality_check(self) -> None: """!== instead of === should not match.""" code, changed = roundtrip( 'var y = x !== null || x !== undefined ? undefined : x.foo;', @@ -109,7 +111,7 @@ def test_not_equality_check(self): class TestOptionalCall: """Tests for X === null || X === undefined ? undefined : X() → X?.().""" - def test_optional_call(self): + def test_optional_call(self) -> None: code, changed = roundtrip( 'var y = fn === null || fn === undefined ? undefined : fn();', OptionalChaining, @@ -117,7 +119,7 @@ def test_optional_call(self): assert changed is True assert 'fn?.()' in code - def test_optional_call_with_args(self): + def test_optional_call_with_args(self) -> None: code, changed = roundtrip( 'var y = fn === null || fn === undefined ? undefined : fn(1, 2);', OptionalChaining, @@ -129,7 +131,7 @@ def test_optional_call_with_args(self): class TestYodaStyle: """Tests for null/undefined on the left side of comparison.""" - def test_null_on_left(self): + def test_null_on_left(self) -> None: """null === x || undefined === x ? undefined : x.foo → x?.foo.""" code, changed = roundtrip( 'var y = null === x || undefined === x ? undefined : x.foo;', @@ -138,7 +140,7 @@ def test_null_on_left(self): assert changed is True assert 'x?.foo' in code - def test_void_0_on_consequent(self): + def test_void_0_on_consequent(self) -> None: """void 0 as consequent should be recognized as undefined.""" code, changed = roundtrip( 'var y = x === null || x === void 0 ? void 0 : x.foo;', @@ -150,11 +152,11 @@ def test_void_0_on_consequent(self): class TestHelperFunctions: """Direct tests for helper functions to cover edge cases.""" - def test_is_undefined_with_non_dict(self): + def test_is_undefined_with_non_dict(self) -> None: assert is_undefined(None) is False assert is_undefined('string') is False - def test_is_undefined_with_void_0(self): + def test_is_undefined_with_void_0(self) -> None: node = { 'type': 'UnaryExpression', 'operator': 'void', @@ -162,7 +164,7 @@ def test_is_undefined_with_void_0(self): } assert is_undefined(node) is True - def test_is_undefined_with_void_non_zero(self): + def test_is_undefined_with_void_non_zero(self) -> None: node = { 'type': 'UnaryExpression', 'operator': 'void', @@ -170,23 +172,23 @@ def test_is_undefined_with_void_non_zero(self): } assert is_undefined(node) is False - def test_is_null_literal_true(self): + def test_is_null_literal_true(self) -> None: assert is_null_literal({'type': 'Literal', 'value': None, 'raw': 'null'}) is True - def test_is_null_literal_false(self): + def test_is_null_literal_false(self) -> None: assert is_null_literal({'type': 'Literal', 'value': 0}) is False assert is_null_literal(None) is False - def test_nodes_match_non_dict(self): + def test_nodes_match_non_dict(self) -> None: assert _nodes_match(None, None) is False assert _nodes_match({}, None) is False - def test_nodes_match_different_types(self): + def test_nodes_match_different_types(self) -> None: a = {'type': 'Identifier', 'name': 'x'} b = {'type': 'Literal', 'value': 1} assert _nodes_match(a, b) is False - def test_nodes_match_member_expression(self): + def test_nodes_match_member_expression(self) -> None: a = { 'type': 'MemberExpression', 'object': {'type': 'Identifier', 'name': 'obj'}, @@ -201,7 +203,7 @@ def test_nodes_match_member_expression(self): } assert _nodes_match(a, b) is True - def test_nodes_match_unknown_type_returns_false(self): + def test_nodes_match_unknown_type_returns_false(self) -> None: a = {'type': 'CallExpression'} b = {'type': 'CallExpression'} assert _nodes_match(a, b) is False @@ -210,34 +212,22 @@ def test_nodes_match_unknown_type_returns_false(self): class TestGeneratorOutput: """Tests that ?. roundtrips through parse → generate correctly.""" - def test_optional_member_roundtrip(self): - from pyjsclear.generator import generate - from pyjsclear.parser import parse - + def test_optional_member_roundtrip(self) -> None: code = 'var x = a?.b;' result = generate(parse(code)) assert 'a?.b' in result - def test_optional_computed_roundtrip(self): - from pyjsclear.generator import generate - from pyjsclear.parser import parse - + def test_optional_computed_roundtrip(self) -> None: code = 'var x = a?.[0];' result = generate(parse(code)) assert 'a?.[0]' in result - def test_optional_call_roundtrip(self): - from pyjsclear.generator import generate - from pyjsclear.parser import parse - + def test_optional_call_roundtrip(self) -> None: code = 'var x = fn?.();' result = generate(parse(code)) assert 'fn?.()' in result - def test_chained_optional(self): - from pyjsclear.generator import generate - from pyjsclear.parser import parse - + def test_chained_optional(self) -> None: code = 'var x = a?.b?.c;' result = generate(parse(code)) assert 'a?.b?.c' in result diff --git a/tests/unit/transforms/property_simplifier_test.py b/tests/unit/transforms/property_simplifier_test.py index ff2132f..ac44eef 100644 --- a/tests/unit/transforms/property_simplifier_test.py +++ b/tests/unit/transforms/property_simplifier_test.py @@ -1,5 +1,3 @@ -import pytest - from pyjsclear.transforms.property_simplifier import PropertySimplifier from tests.unit.conftest import roundtrip @@ -7,27 +5,27 @@ class TestBracketToDot: """Tests for converting obj["prop"] to obj.prop.""" - def test_simple_bracket_to_dot(self): + def test_simple_bracket_to_dot(self) -> None: code, changed = roundtrip('obj["foo"];', PropertySimplifier) assert changed is True assert code == 'obj.foo;' - def test_invalid_identifier_unchanged(self): + def test_invalid_identifier_unchanged(self) -> None: code, changed = roundtrip('obj["0abc"];', PropertySimplifier) assert changed is False assert '"0abc"' in code or "'0abc'" in code - def test_reserved_word_is_valid_identifier(self): + def test_reserved_word_is_valid_identifier(self) -> None: code, changed = roundtrip('obj["class"];', PropertySimplifier) assert changed is True assert code == 'obj.class;' - def test_non_string_computed_unchanged(self): + def test_non_string_computed_unchanged(self) -> None: code, changed = roundtrip('obj[x];', PropertySimplifier) assert changed is False assert 'obj[x]' in code - def test_already_dot_notation_unchanged(self): + def test_already_dot_notation_unchanged(self) -> None: code, changed = roundtrip('obj.foo;', PropertySimplifier) assert changed is False assert 'obj.foo' in code @@ -36,18 +34,18 @@ def test_already_dot_notation_unchanged(self): class TestObjectLiteralKeys: """Tests for simplifying string literal keys in object literals.""" - def test_string_key_to_identifier(self): + def test_string_key_to_identifier(self) -> None: """String key becomes Identifier.""" code, changed = roundtrip('var x = {"foo": 1};', PropertySimplifier) assert changed is True assert 'foo: 1' in code or 'foo:' in code - def test_invalid_identifier_key_unchanged(self): + def test_invalid_identifier_key_unchanged(self) -> None: code, changed = roundtrip('var x = {"0abc": 1};', PropertySimplifier) assert changed is False assert '0abc' in code - def test_already_identifier_key_no_change(self): + def test_already_identifier_key_no_change(self) -> None: code, changed = roundtrip('var x = {foo: 1};', PropertySimplifier) assert changed is False assert 'foo' in code @@ -56,14 +54,14 @@ def test_already_identifier_key_no_change(self): class TestMultipleProperties: """Tests for mixed property access patterns.""" - def test_mixed_bracket_and_dot(self): + def test_mixed_bracket_and_dot(self) -> None: code, changed = roundtrip('obj["foo"]; obj.bar; obj["0bad"];', PropertySimplifier) assert changed is True assert 'obj.foo' in code assert 'obj.bar' in code assert '0bad' in code - def test_multiple_object_literal_keys(self): + def test_multiple_object_literal_keys(self) -> None: code, changed = roundtrip('var x = {"good": 1, "0bad": 2, ok: 3};', PropertySimplifier) assert changed is True # "good" converted to identifier key diff --git a/tests/unit/transforms/proxy_functions_test.py b/tests/unit/transforms/proxy_functions_test.py index 660b0a6..89567e8 100644 --- a/tests/unit/transforms/proxy_functions_test.py +++ b/tests/unit/transforms/proxy_functions_test.py @@ -1,14 +1,14 @@ """Unit tests for ProxyFunctionInliner transform.""" -import pytest - +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip class TestProxyFunctionInlinerBasic: - def test_binary_proxy_inlined(self): + def test_binary_proxy_inlined(self) -> None: code, changed = roundtrip( 'function p(a, b) { return a + b; } p(1, 2);', ProxyFunctionInliner, @@ -19,7 +19,7 @@ def test_binary_proxy_inlined(self): assert '1 + 2;' in norm assert 'p(1, 2)' not in norm - def test_call_proxy_inlined(self): + def test_call_proxy_inlined(self) -> None: code, changed = roundtrip( 'function p(a, b) { return a(b); } p(foo, 1);', ProxyFunctionInliner, @@ -29,7 +29,7 @@ def test_call_proxy_inlined(self): assert 'foo(1);' in norm assert 'p(foo, 1)' not in norm - def test_arrow_proxy_inlined(self): + def test_arrow_proxy_inlined(self) -> None: code, changed = roundtrip( 'var p = (a) => a * 2; p(3);', ProxyFunctionInliner, @@ -41,7 +41,7 @@ def test_arrow_proxy_inlined(self): class TestProxyFunctionInlinerSkips: - def test_multi_statement_body_not_inlined(self): + def test_multi_statement_body_not_inlined(self) -> None: code, changed = roundtrip( 'function p(a) { var x = 1; return a + x; } p(5);', ProxyFunctionInliner, @@ -49,7 +49,7 @@ def test_multi_statement_body_not_inlined(self): assert changed is False assert 'p(5)' in normalize(code) - def test_non_constant_binding_not_inlined(self): + def test_non_constant_binding_not_inlined(self) -> None: code, changed = roundtrip( 'var p = (a) => a * 2; p = something; p(3);', ProxyFunctionInliner, @@ -57,7 +57,7 @@ def test_non_constant_binding_not_inlined(self): assert changed is False assert 'p(3)' in normalize(code) - def test_no_proxy_functions_returns_false(self): + def test_no_proxy_functions_returns_false(self) -> None: code, changed = roundtrip( 'var x = 1 + 2;', ProxyFunctionInliner, @@ -66,7 +66,7 @@ def test_no_proxy_functions_returns_false(self): class TestProxyFunctionInlinerEdgeCases: - def test_missing_args_substitutes_undefined(self): + def test_missing_args_substitutes_undefined(self) -> None: code, changed = roundtrip( 'function p(a, b) { return a + b; } p(1);', ProxyFunctionInliner, @@ -75,7 +75,7 @@ def test_missing_args_substitutes_undefined(self): norm = normalize(code) assert '1 + undefined' in norm - def test_function_with_no_return_value_gives_undefined(self): + def test_function_with_no_return_value_gives_undefined(self) -> None: code, changed = roundtrip( 'function p() { return; } var x = p();', ProxyFunctionInliner, @@ -83,7 +83,7 @@ def test_function_with_no_return_value_gives_undefined(self): assert changed is True assert 'undefined' in code - def test_nested_calls_processed_innermost_first(self): + def test_nested_calls_processed_innermost_first(self) -> None: code, changed = roundtrip( 'function p(a, b) { return a + b; } p(p(1, 2), 3);', ProxyFunctionInliner, @@ -98,21 +98,21 @@ def test_nested_calls_processed_innermost_first(self): class TestProxyFunctionInlinerDisallowed: - def test_function_expression_in_return_not_inlined(self): + def test_function_expression_in_return_not_inlined(self) -> None: code, changed = roundtrip( 'function p() { return function() {}; } p();', ProxyFunctionInliner, ) assert changed is False - def test_assignment_expression_in_return_not_inlined(self): + def test_assignment_expression_in_return_not_inlined(self) -> None: code, changed = roundtrip( 'function p(a) { return a = 1; } p(x);', ProxyFunctionInliner, ) assert changed is False - def test_sequence_expression_in_return_not_inlined(self): + def test_sequence_expression_in_return_not_inlined(self) -> None: code, changed = roundtrip( 'function p(a) { return (1, a); } p(x);', ProxyFunctionInliner, @@ -123,7 +123,7 @@ def test_sequence_expression_in_return_not_inlined(self): class TestCoverageGaps: """Tests targeting uncovered lines in proxy_functions.py.""" - def test_callee_not_identifier(self): + def test_callee_not_identifier(self) -> None: """Line 42: CallExpression with non-identifier callee (e.g., member expression).""" code, changed = roundtrip( 'function p(a) { return a; } obj.p(1);', @@ -132,7 +132,7 @@ def test_callee_not_identifier(self): # obj.p(1) callee is MemberExpression, not Identifier — should not inline assert 'obj.p(1)' in normalize(code) - def test_destructuring_params_not_proxy(self): + def test_destructuring_params_not_proxy(self) -> None: """Line 109: Function with non-identifier params (destructuring) is not a proxy.""" code, changed = roundtrip( 'function f({a, b}) { return a + b; } f({a: 1, b: 2});', @@ -140,11 +140,8 @@ def test_destructuring_params_not_proxy(self): ) assert changed is False - def test_body_is_none(self): + def test_body_is_none(self) -> None: """Line 113: Function body is None — not a proxy.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('function f() { return 1; } f();') # Manually remove the body func = ast['body'][0] @@ -153,7 +150,7 @@ def test_body_is_none(self): changed = t.execute() assert changed is False - def test_block_body_non_return_statement(self): + def test_block_body_non_return_statement(self) -> None: """Line 126: Block body with a non-return statement (e.g., expression).""" code, changed = roundtrip( 'function f() { console.log(1); } f();', @@ -161,7 +158,7 @@ def test_block_body_non_return_statement(self): ) assert changed is False - def test_arrow_in_return_not_proxy(self): + def test_arrow_in_return_not_proxy(self) -> None: """Lines 158-159: _is_proxy_value rejects ArrowFunctionExpression.""" code, changed = roundtrip( 'function f() { return () => 1; } f();', @@ -169,7 +166,7 @@ def test_arrow_in_return_not_proxy(self): ) assert changed is False - def test_list_child_with_disallowed_type(self): + def test_list_child_with_disallowed_type(self) -> None: """Line 157: _is_proxy_value rejects disallowed type in list child.""" # Array with a function expression as element code, changed = roundtrip( @@ -178,10 +175,8 @@ def test_list_child_with_disallowed_type(self): ) assert changed is False - def test_get_replacement_body_none(self): + def test_get_replacement_body_none(self) -> None: """Line 166: _get_replacement when body is None returns undefined.""" - from pyjsclear.parser import parse - ast = parse('function f() { return 1; } f();') t = ProxyFunctionInliner(ast) # Manually create a func_node with no body @@ -190,10 +185,8 @@ def test_get_replacement_body_none(self): assert result is not None assert result.get('name') == 'undefined' - def test_get_replacement_block_empty(self): + def test_get_replacement_block_empty(self) -> None: """Line 174: _get_replacement with empty block body returns None.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) func_node = { @@ -204,10 +197,8 @@ def test_get_replacement_block_empty(self): result = t._get_replacement(func_node, []) assert result is None - def test_get_replacement_block_non_return(self): + def test_get_replacement_block_non_return(self) -> None: """Line 174: _get_replacement with block body that starts with non-return.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) func_node = { @@ -221,10 +212,8 @@ def test_get_replacement_block_non_return(self): result = t._get_replacement(func_node, []) assert result is None - def test_get_replacement_not_block_not_arrow(self): + def test_get_replacement_not_block_not_arrow(self) -> None: """Line 180: _get_replacement with non-block, non-arrow body returns None.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) func_node = { @@ -235,10 +224,8 @@ def test_get_replacement_not_block_not_arrow(self): result = t._get_replacement(func_node, []) assert result is None - def test_get_replacement_return_no_argument(self): + def test_get_replacement_return_no_argument(self) -> None: """Line 177: _get_replacement with return but no argument gives undefined.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) func_node = { @@ -253,20 +240,16 @@ def test_get_replacement_return_no_argument(self): assert result is not None assert result.get('name') == 'undefined' - def test_is_proxy_value_non_dict(self): + def test_is_proxy_value_non_dict(self) -> None: """Line 148: _is_proxy_value with non-dict returns False.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) assert t._is_proxy_value('not_a_dict') is False assert t._is_proxy_value(None) is False assert t._is_proxy_value(42) is False - def test_function_non_block_body_not_arrow(self): + def test_function_non_block_body_not_arrow(self) -> None: """Line 132: body not BlockStatement and func is not ArrowFunction.""" - from pyjsclear.parser import parse - ast = parse('function f(a) { return a; } f(1);') func = ast['body'][0] # Replace body with a non-block to trigger line 132 @@ -276,10 +259,8 @@ def test_function_non_block_body_not_arrow(self): # Should not crash; function is not a proxy since body is not block or arrow expr assert not changed - def test_is_proxy_value_child_dict_disallowed(self): + def test_is_proxy_value_child_dict_disallowed(self) -> None: """Line 158-159: _is_proxy_value child dict with disallowed type.""" - from pyjsclear.parser import parse - ast = parse('var x = 1;') t = ProxyFunctionInliner(ast) # A node containing a child dict with a disallowed type diff --git a/tests/unit/transforms/reassignment_test.py b/tests/unit/transforms/reassignment_test.py index 5bcece6..2e7c253 100644 --- a/tests/unit/transforms/reassignment_test.py +++ b/tests/unit/transforms/reassignment_test.py @@ -1,7 +1,5 @@ """Unit tests for ReassignmentRemover transform.""" -import pytest - from pyjsclear.transforms.reassignment import ReassignmentRemover from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -10,37 +8,37 @@ class TestReassignmentRemoverDeclaratorInline: """Tests for var x = y; declarator-based inlining.""" - def test_well_known_global_json(self): + def test_well_known_global_json(self) -> None: code = 'var x = JSON; x.parse(y);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'JSON.parse(y)' in normalize(result) assert 'var x' not in normalize(result) or 'x.parse' not in normalize(result) - def test_well_known_global_console(self): + def test_well_known_global_console(self) -> None: code = 'var x = console; x.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'console.log("hi")' in normalize(result) - def test_constant_binding_inline(self): + def test_constant_binding_inline(self) -> None: code = 'const a = 1; var b = a; c(b);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'c(a)' in normalize(result) - def test_unknown_target_unchanged(self): + def test_unknown_target_unchanged(self) -> None: code = 'var x = y; x.foo();' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False assert 'x.foo()' in normalize(result) - def test_self_assignment_skipped(self): + def test_self_assignment_skipped(self) -> None: code = 'var x = x;' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False - def test_non_constant_target_unchanged(self): + def test_non_constant_target_unchanged(self) -> None: code = 'var y = 1; y = 2; var x = y; x.foo();' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False @@ -50,7 +48,7 @@ def test_non_constant_target_unchanged(self): class TestReassignmentRemoverAssignmentAlias: """Tests for var x; ... x = y; assignment alias inlining.""" - def test_assignment_alias_console(self): + def test_assignment_alias_console(self) -> None: code = 'var x; x = console; x.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True @@ -58,19 +56,19 @@ def test_assignment_alias_console(self): # The assignment statement should be removed assert 'x = console' not in normalize(result) - def test_assignment_alias_json(self): + def test_assignment_alias_json(self) -> None: code = 'var x; x = JSON; x.parse(s);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'JSON.parse(s)' in normalize(result) - def test_assignment_alias_unknown_target_unchanged(self): + def test_assignment_alias_unknown_target_unchanged(self) -> None: code = 'var x; x = y; x.foo();' result, changed = roundtrip(code, ReassignmentRemover) # y is not well-known and not a constant binding, so no change assert 'x.foo()' in normalize(result) - def test_assignment_alias_non_constant_target_unchanged(self): + def test_assignment_alias_non_constant_target_unchanged(self) -> None: code = 'var y = 1; y = 2; var x; x = y; x.foo();' result, changed = roundtrip(code, ReassignmentRemover) assert 'x.foo()' in normalize(result) @@ -79,12 +77,12 @@ def test_assignment_alias_non_constant_target_unchanged(self): class TestReassignmentRemoverNoChange: """Tests for cases where nothing changes.""" - def test_no_reassignments_returns_false(self): + def test_no_reassignments_returns_false(self) -> None: code = 'var a = 1; console.log(a);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False - def test_param_not_inlined(self): + def test_param_not_inlined(self) -> None: code = 'function f(x) { var y = x; return y; }' result, changed = roundtrip(code, ReassignmentRemover) # x is a param binding; it is constant, so inlining y=x should work @@ -92,7 +90,7 @@ def test_param_not_inlined(self): # Let's just verify it doesn't crash assert isinstance(changed, bool) - def test_no_variable_declarations(self): + def test_no_variable_declarations(self) -> None: code = 'console.log(1);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False @@ -101,28 +99,28 @@ def test_no_variable_declarations(self): class TestReassignmentRemoverEdgeCases: """Edge case tests.""" - def test_multiple_well_known_globals(self): + def test_multiple_well_known_globals(self) -> None: code = 'var a = Object; var b = Array; a.keys(x); b.isArray(y);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'Object.keys(x)' in normalize(result) assert 'Array.isArray(y)' in normalize(result) - def test_rebuild_scope_flag(self): + def test_rebuild_scope_flag(self) -> None: assert ReassignmentRemover.rebuild_scope is True class TestReassignmentRemoverSkipConditions: """Tests for skip conditions on reference replacement.""" - def test_reference_as_assignment_lhs_skipped(self): + def test_reference_as_assignment_lhs_skipped(self) -> None: """Line 93: Reference used as assignment left-hand side should be skipped.""" code = 'var x = console; x = something; x.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) # x has writes so it is not constant, no inlining should happen assert isinstance(changed, bool) - def test_reference_as_declarator_id_skipped(self): + def test_reference_as_declarator_id_skipped(self) -> None: """Line 95: Reference used as VariableDeclarator id should be skipped.""" # When a variable is reassigned via declarator pattern, the id ref should be skipped code = 'var x = JSON; x.parse(s);' @@ -134,7 +132,7 @@ def test_reference_as_declarator_id_skipped(self): class TestReassignmentRemoverAssignmentAliasEdgeCases: """Tests for assignment alias edge cases.""" - def test_assignment_alias_with_init_skipped(self): + def test_assignment_alias_with_init_skipped(self) -> None: """Line 131: VariableDeclarator with init should be skipped for assignment alias.""" code = 'var _0x1 = undefined; _0x1 = console; _0x1.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) @@ -142,34 +140,34 @@ def test_assignment_alias_with_init_skipped(self): # but it might be handled by the declarator path instead assert isinstance(changed, bool) - def test_assignment_alias_multiple_writes_skipped(self): + def test_assignment_alias_multiple_writes_skipped(self) -> None: """Line 151: Assignment alias with != 1 writes should be skipped.""" code = 'var _0x1; _0x1 = console; _0x1 = JSON; _0x1.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) # Two writes means it won't be inlined via assignment alias assert '_0x1' in result or changed is False - def test_assignment_alias_rhs_not_identifier(self): + def test_assignment_alias_rhs_not_identifier(self) -> None: """Line 157: Assignment alias where right side is not an identifier.""" code = 'var _0x1; _0x1 = 123; console.log(_0x1);' result, changed = roundtrip(code, ReassignmentRemover) # RHS is a literal, not an identifier — should be skipped assert '_0x1' in result - def test_assignment_alias_self_assignment_skipped(self): + def test_assignment_alias_self_assignment_skipped(self) -> None: """Line 161: target_name equals name should be skipped.""" code = 'var _0x1; _0x1 = _0x1;' result, changed = roundtrip(code, ReassignmentRemover) assert changed is False - def test_assignment_alias_replacement_with_index(self): + def test_assignment_alias_replacement_with_index(self) -> None: """Line 175: Assignment alias replacement in array position (with index).""" code = 'var _0x1; _0x1 = console; foo([_0x1]);' result, changed = roundtrip(code, ReassignmentRemover) assert changed is True assert 'console' in result - def test_assignment_alias_pattern(self): + def test_assignment_alias_pattern(self) -> None: """Assignment alias: var x; x = console; log(x);""" code = 'var _0x1 = undefined; var _0x2; _0x2 = console; _0x2.log("hi");' result, changed = roundtrip(code, ReassignmentRemover) diff --git a/tests/unit/transforms/require_inliner_test.py b/tests/unit/transforms/require_inliner_test.py index ed17858..cb892b0 100644 --- a/tests/unit/transforms/require_inliner_test.py +++ b/tests/unit/transforms/require_inliner_test.py @@ -7,7 +7,7 @@ class TestRequirePolyfillDetection: """Tests for detecting and inlining require polyfill wrappers.""" - def test_typeof_require_polyfill(self): + def test_typeof_require_polyfill(self) -> None: code = ''' var _0x544bfe = (function() { return typeof require !== "undefined" ? require : null; })(); _0x544bfe("fs"); @@ -18,17 +18,17 @@ def test_typeof_require_polyfill(self): assert 'require("fs")' in result assert 'require("path")' in result - def test_preserves_non_polyfill_calls(self): + def test_preserves_non_polyfill_calls(self) -> None: """Regular function calls should not be changed.""" result, changed = roundtrip('myFunc("fs");', RequireInliner) assert changed is False assert 'require' not in result - def test_no_polyfills_returns_false(self): + def test_no_polyfills_returns_false(self) -> None: result, changed = roundtrip('var x = 1;', RequireInliner) assert changed is False - def test_multi_arg_call_unchanged(self): + def test_multi_arg_call_unchanged(self) -> None: """Polyfill calls with != 1 arg should not be replaced.""" code = ''' var _0x544bfe = (function() { return typeof require !== "undefined" ? require : null; })(); diff --git a/tests/unit/transforms/sequence_splitter_test.py b/tests/unit/transforms/sequence_splitter_test.py index c91aaa1..43a6b50 100644 --- a/tests/unit/transforms/sequence_splitter_test.py +++ b/tests/unit/transforms/sequence_splitter_test.py @@ -1,18 +1,18 @@ -import pytest - +from pyjsclear.generator import generate +from pyjsclear.parser import parse from pyjsclear.transforms.sequence_splitter import SequenceSplitter from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip class TestSequenceSplittingInExpressionStatements: - def test_splits_sequence_into_separate_statements(self): + def test_splits_sequence_into_separate_statements(self) -> None: code = '(a(), b(), c());' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('a(); b(); c();') - def test_splits_two_element_sequence(self): + def test_splits_two_element_sequence(self) -> None: code = '(a(), b());' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -20,25 +20,25 @@ def test_splits_two_element_sequence(self): class TestMultiVarSplitting: - def test_splits_multi_declarator_var(self): + def test_splits_multi_declarator_var(self) -> None: code = 'var a = 1, b = 2;' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('var a = 1; var b = 2;') - def test_splits_three_declarators(self): + def test_splits_three_declarators(self) -> None: code = 'var a = 1, b = 2, c = 3;' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('var a = 1; var b = 2; var c = 3;') - def test_preserves_kind_let(self): + def test_preserves_kind_let(self) -> None: code = 'let a = 1, b = 2;' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('let a = 1; let b = 2;') - def test_preserves_kind_const(self): + def test_preserves_kind_const(self) -> None: code = 'const a = 1, b = 2;' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -46,13 +46,13 @@ def test_preserves_kind_const(self): class TestSingleDeclaratorSequenceInitSplitting: - def test_splits_sequence_in_var_init(self): + def test_splits_sequence_in_var_init(self) -> None: code = 'var x = (a, b, expr());' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('a; b; var x = expr();') - def test_splits_two_element_sequence_in_init(self): + def test_splits_two_element_sequence_in_init(self) -> None: code = 'var x = (a, expr());' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -60,13 +60,13 @@ def test_splits_two_element_sequence_in_init(self): class TestIndirectCallPrefixExtraction: - def test_extracts_zero_prefix(self): + def test_extracts_zero_prefix(self) -> None: code = '(0, fn)(args);' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('0; fn(args);') - def test_extracts_multiple_prefixes(self): + def test_extracts_multiple_prefixes(self) -> None: code = '(0, 1, fn)(args);' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -74,19 +74,19 @@ def test_extracts_multiple_prefixes(self): class TestBodyNormalization: - def test_if_body_normalized_to_block(self): + def test_if_body_normalized_to_block(self) -> None: code = 'if (x) y();' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('if (x) { y(); }') - def test_while_body_normalized_to_block(self): + def test_while_body_normalized_to_block(self) -> None: code = 'while (x) y();' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('while (x) { y(); }') - def test_for_body_normalized_to_block(self): + def test_for_body_normalized_to_block(self) -> None: code = 'for (;;) y();' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -94,13 +94,13 @@ def test_for_body_normalized_to_block(self): class TestIfBranchNormalization: - def test_alternate_normalized(self): + def test_alternate_normalized(self) -> None: code = 'if (x) { y(); } else z();' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('if (x) { y(); } else { z(); }') - def test_else_if_not_wrapped(self): + def test_else_if_not_wrapped(self) -> None: code = 'if (x) { y(); } else if (z) { w(); }' result, changed = roundtrip(code, SequenceSplitter) # else-if should not be wrapped in a block @@ -108,38 +108,38 @@ def test_else_if_not_wrapped(self): class TestNoChangeWhenNothingToSplit: - def test_single_expression_statement(self): + def test_single_expression_statement(self) -> None: code = 'a();' result, changed = roundtrip(code, SequenceSplitter) assert changed is False assert normalize(result) == normalize('a();') - def test_single_var_declaration(self): + def test_single_var_declaration(self) -> None: code = 'var a = 1;' result, changed = roundtrip(code, SequenceSplitter) assert changed is False assert normalize(result) == normalize('var a = 1;') - def test_already_block_if(self): + def test_already_block_if(self) -> None: code = 'if (x) { y(); }' result, changed = roundtrip(code, SequenceSplitter) assert changed is False assert normalize(result) == normalize('if (x) { y(); }') - def test_empty_program(self): + def test_empty_program(self) -> None: code = '' result, changed = roundtrip(code, SequenceSplitter) assert changed is False class TestAwaitWrappedSequenceSplitting: - def test_splits_await_sequence_in_var_init(self): + def test_splits_await_sequence_in_var_init(self) -> None: code = 'async function f() { var x = await (a, b, expr()); }' result, changed = roundtrip(code, SequenceSplitter) assert changed is True assert normalize(result) == normalize('async function f() { a; b; var x = await expr(); }') - def test_splits_await_two_element_sequence(self): + def test_splits_await_two_element_sequence(self) -> None: code = 'async function f() { var x = await (a, expr()); }' result, changed = roundtrip(code, SequenceSplitter) assert changed is True @@ -149,14 +149,14 @@ def test_splits_await_two_element_sequence(self): class TestSequenceSplitterEdgeCases: """Tests for uncovered edge cases.""" - def test_non_dict_in_body_arrays(self): + def test_non_dict_in_body_arrays(self) -> None: """Line 59: Non-dict in _split_in_body_arrays should be skipped gracefully.""" # This is handled internally; just ensure no crash with normal code code = 'a(); b();' result, changed = roundtrip(code, SequenceSplitter) assert isinstance(changed, bool) - def test_return_indirect_call(self): + def test_return_indirect_call(self) -> None: """Line 107: ReturnStatement with indirect call: return (0, fn)(args).""" code = 'function f() { return (0, g)("hello"); }' result, changed = roundtrip(code, SequenceSplitter) @@ -165,7 +165,7 @@ def test_return_indirect_call(self): assert '0' in norm assert 'g("hello")' in norm - def test_assignment_indirect_call(self): + def test_assignment_indirect_call(self) -> None: """Line 113: Assignment with indirect call: x = (0, fn)(args).""" code = 'var x; x = (0, g)("hello");' result, changed = roundtrip(code, SequenceSplitter) @@ -174,21 +174,21 @@ def test_assignment_indirect_call(self): assert '0' in norm assert 'g("hello")' in norm - def test_non_dict_statement_in_process_stmt_array(self): + def test_non_dict_statement_in_process_stmt_array(self) -> None: """Lines 123-124: Non-dict statement in _process_stmt_array should be skipped.""" # Internal handling; verify no crash with various patterns code = 'var a = 1;' result, changed = roundtrip(code, SequenceSplitter) assert changed is False - def test_non_dict_init_in_single_declarator(self): + def test_non_dict_init_in_single_declarator(self) -> None: """Line 189: _try_split_single_declarator_init with non-dict init.""" # Variable with no init (init is None, not a dict) code = 'var x;' result, changed = roundtrip(code, SequenceSplitter) assert changed is False - def test_sequence_callee_with_single_expression(self): + def test_sequence_callee_with_single_expression(self) -> None: """Line 96: SequenceExpression callee with <=1 expression should not be extracted.""" # Construct a case where there's a single-element sequence callee # This is a degenerate case; normal JS won't produce it, but we can test @@ -197,28 +197,28 @@ def test_sequence_callee_with_single_expression(self): result, changed = roundtrip(code, SequenceSplitter) assert changed is False - def test_direct_sequence_init_single_expression(self): + def test_direct_sequence_init_single_expression(self) -> None: """Line 195: Direct SequenceExpression init with <=1 expression should return None.""" # A single-element sequence is degenerate; just test normal single-init var code = 'var x = 1;' result, changed = roundtrip(code, SequenceSplitter) assert changed is False - def test_await_wrapped_sequence_single_expression(self): + def test_await_wrapped_sequence_single_expression(self) -> None: """Line 209: Await-wrapped sequence with <=1 expression should return None.""" code = 'async function f() { var x = await expr(); }' result, changed = roundtrip(code, SequenceSplitter) # No sequence to split, just a simple await assert normalize(result) == normalize('async function f() { var x = await expr(); }') - def test_extract_from_call_non_dict(self): + def test_extract_from_call_non_dict(self) -> None: """Line 85: Non-dict in extract_from_call should be skipped.""" # When expression is not a dict (e.g., expression missing), no crash code = ';' result, changed = roundtrip(code, SequenceSplitter) assert isinstance(changed, bool) - def test_var_declaration_indirect_call_in_init(self): + def test_var_declaration_indirect_call_in_init(self) -> None: """VariableDeclaration path for extracting indirect call prefixes.""" code = 'var x = (0, fn)("hello");' result, changed = roundtrip(code, SequenceSplitter) @@ -231,11 +231,8 @@ def test_var_declaration_indirect_call_in_init(self): class TestSequenceSplitterDirectASTCoverage: """Tests using direct AST manipulation to hit remaining uncovered lines.""" - def test_sequence_callee_with_single_expression(self): + def test_sequence_callee_with_single_expression(self) -> None: """Line 95-96: SequenceExpression callee with <=1 expression.""" - from pyjsclear.generator import generate - from pyjsclear.parser import parse - ast = parse('fn("hello");') # Manually wrap callee in a SequenceExpression with 1 element call_expr = ast['body'][0]['expression'] @@ -249,10 +246,8 @@ def test_sequence_callee_with_single_expression(self): # Single-element sequence callee should not be extracted # (but body normalization may still trigger changes) - def test_single_element_await_sequence(self): + def test_single_element_await_sequence(self) -> None: """Line 208-209: single-element await sequence returns None.""" - from pyjsclear.parser import parse - ast = parse('async function f() { var x = await expr(); }') # Find the var declaration inside the function body func_body = ast['body'][0]['body']['body'] diff --git a/tests/unit/transforms/single_use_vars_test.py b/tests/unit/transforms/single_use_vars_test.py index e229d33..cb4bbc5 100644 --- a/tests/unit/transforms/single_use_vars_test.py +++ b/tests/unit/transforms/single_use_vars_test.py @@ -7,7 +7,7 @@ class TestRequireInlining: """Tests for single-use require() inlining.""" - def test_simple_require_inlined(self): + def test_simple_require_inlined(self) -> None: code = ''' function f() { const x = require("fs"); @@ -19,7 +19,7 @@ def test_simple_require_inlined(self): assert 'require("fs").readFileSync' in result assert 'const x' not in result - def test_require_member_access(self): + def test_require_member_access(self) -> None: code = ''' function f() { const proc = require("process"); @@ -30,7 +30,7 @@ def test_require_member_access(self): assert changed is True assert 'require("process").env.HOME' in result - def test_var_require_inlined(self): + def test_var_require_inlined(self) -> None: code = ''' function f() { var x = require("path"); @@ -46,7 +46,7 @@ def test_var_require_inlined(self): class TestExpressionInlining: """Tests for single-use non-require expression inlining.""" - def test_property_access_inlined(self): + def test_property_access_inlined(self) -> None: code = ''' function f() { const x = obj.prop; @@ -58,7 +58,7 @@ def test_property_access_inlined(self): assert 'obj.prop.foo' in result assert 'const x' not in result - def test_method_call_inlined(self): + def test_method_call_inlined(self) -> None: code = ''' function f(arr) { const x = Buffer.from(arr); @@ -69,7 +69,7 @@ def test_method_call_inlined(self): assert changed is True assert 'Buffer.from(arr).toString()' in result - def test_new_expression_inlined(self): + def test_new_expression_inlined(self) -> None: code = ''' function f() { const d = new Date(); @@ -80,7 +80,7 @@ def test_new_expression_inlined(self): assert changed is True assert 'new Date().getTime()' in result - def test_string_literal_inlined(self): + def test_string_literal_inlined(self) -> None: code = ''' function f() { const url = "https://example.com"; @@ -92,7 +92,7 @@ def test_string_literal_inlined(self): assert 'fetch("https://example.com")' in result assert 'const url' not in result - def test_simple_call_inlined(self): + def test_simple_call_inlined(self) -> None: code = ''' function f(x) { const n = parseInt(x); @@ -107,7 +107,7 @@ def test_simple_call_inlined(self): class TestNoInlining: """Tests where inlining should NOT occur.""" - def test_multi_use_not_inlined(self): + def test_multi_use_not_inlined(self) -> None: code = ''' function f() { const fs = require("fs"); @@ -118,7 +118,7 @@ def test_multi_use_not_inlined(self): result, changed = roundtrip(code, SingleUseVarInliner) assert changed is False - def test_reassigned_var_not_inlined(self): + def test_reassigned_var_not_inlined(self) -> None: code = ''' function f() { let x = require("fs"); @@ -129,11 +129,11 @@ def test_reassigned_var_not_inlined(self): result, changed = roundtrip(code, SingleUseVarInliner) assert changed is False - def test_no_init_returns_false(self): + def test_no_init_returns_false(self) -> None: result, changed = roundtrip('var x;', SingleUseVarInliner) assert changed is False - def test_large_init_not_inlined(self): + def test_large_init_not_inlined(self) -> None: """Init expressions with too many AST nodes should not be inlined.""" # Build a deeply nested expression that exceeds the node limit code = ''' @@ -145,7 +145,7 @@ def test_large_init_not_inlined(self): result, changed = roundtrip(code, SingleUseVarInliner) assert changed is False - def test_assignment_target_not_inlined(self): + def test_assignment_target_not_inlined(self) -> None: code = ''' function f() { const x = obj.prop; @@ -155,7 +155,7 @@ def test_assignment_target_not_inlined(self): result, changed = roundtrip(code, SingleUseVarInliner) assert changed is False - def test_mutated_member_not_inlined(self): + def test_mutated_member_not_inlined(self) -> None: """var x = {}; x[key] = val; should NOT inline to {}[key] = val.""" code = ''' function f() { @@ -167,7 +167,7 @@ def test_mutated_member_not_inlined(self): assert changed is False assert 'var x' in result - def test_mutated_dot_member_not_inlined(self): + def test_mutated_dot_member_not_inlined(self) -> None: """var x = {}; x.foo = val; should NOT inline.""" code = ''' function f() { @@ -182,7 +182,7 @@ def test_mutated_dot_member_not_inlined(self): class TestNestedContexts: """Tests for inlining in nested scopes (classes, closures).""" - def test_class_method_inlined(self): + def test_class_method_inlined(self) -> None: code = ''' var Cls = class { static method() { @@ -195,7 +195,7 @@ def test_class_method_inlined(self): assert changed is True assert 'require("process").env.HOME' in result - def test_nested_function_inlined(self): + def test_nested_function_inlined(self) -> None: code = ''' function outer() { function inner() { @@ -208,7 +208,7 @@ def test_nested_function_inlined(self): assert changed is True assert 'require("fs").existsSync' in result - def test_multiple_scopes_inlined(self): + def test_multiple_scopes_inlined(self) -> None: code = ''' function a() { const x = obj.foo; diff --git a/tests/unit/transforms/string_revealer_test.py b/tests/unit/transforms/string_revealer_test.py index 7931105..2f1ad00 100644 --- a/tests/unit/transforms/string_revealer_test.py +++ b/tests/unit/transforms/string_revealer_test.py @@ -11,6 +11,8 @@ from pyjsclear.transforms.string_revealer import _js_parse_int from pyjsclear.transforms.string_revealer import _resolve_arg_value from pyjsclear.transforms.string_revealer import _resolve_string_arg +from pyjsclear.utils.string_decoders import Base64StringDecoder +from pyjsclear.utils.string_decoders import BasicStringDecoder from tests.unit.conftest import normalize from tests.unit.conftest import parse_expr from tests.unit.conftest import roundtrip @@ -24,59 +26,59 @@ class TestEvalNumeric: """Tests for the _eval_numeric helper.""" - def test_integer_literal(self): + def test_integer_literal(self) -> None: node = parse_expr('42') assert _eval_numeric(node) == 42 - def test_float_literal(self): + def test_float_literal(self) -> None: node = parse_expr('3.14') assert _eval_numeric(node) == pytest.approx(3.14) - def test_unary_negative(self): + def test_unary_negative(self) -> None: node = parse_expr('-7') assert _eval_numeric(node) == -7 - def test_unary_positive(self): + def test_unary_positive(self) -> None: node = parse_expr('+5') assert _eval_numeric(node) == 5 - def test_binary_addition(self): + def test_binary_addition(self) -> None: node = parse_expr('3 + 4') assert _eval_numeric(node) == 7 - def test_binary_subtraction(self): + def test_binary_subtraction(self) -> None: node = parse_expr('10 - 3') assert _eval_numeric(node) == 7 - def test_binary_multiplication(self): + def test_binary_multiplication(self) -> None: node = parse_expr('6 * 7') assert _eval_numeric(node) == 42 - def test_binary_division(self): + def test_binary_division(self) -> None: node = parse_expr('20 / 4') assert _eval_numeric(node) == 5.0 - def test_division_by_zero_returns_none(self): + def test_division_by_zero_returns_none(self) -> None: node = parse_expr('1 / 0') assert _eval_numeric(node) is None - def test_nested_expression(self): + def test_nested_expression(self) -> None: node = parse_expr('(2 + 3) * 4') assert _eval_numeric(node) == 20 - def test_string_literal_returns_none(self): + def test_string_literal_returns_none(self) -> None: node = parse_expr('"hello"') assert _eval_numeric(node) is None - def test_identifier_returns_none(self): + def test_identifier_returns_none(self) -> None: node = parse_expr('x') assert _eval_numeric(node) is None - def test_non_dict_returns_none(self): + def test_non_dict_returns_none(self) -> None: assert _eval_numeric(None) is None assert _eval_numeric(42) is None - def test_unsupported_operator_returns_none(self): + def test_unsupported_operator_returns_none(self) -> None: node = parse_expr('2 << 3') assert _eval_numeric(node) is None @@ -89,34 +91,34 @@ def test_unsupported_operator_returns_none(self): class TestJsParseInt: """Tests for the _js_parse_int helper.""" - def test_pure_integer(self): + def test_pure_integer(self) -> None: assert _js_parse_int('123') == 123 - def test_leading_digits_with_trailing_chars(self): + def test_leading_digits_with_trailing_chars(self) -> None: assert _js_parse_int('12abc') == 12 - def test_no_leading_digits_returns_nan(self): + def test_no_leading_digits_returns_nan(self) -> None: result = _js_parse_int('abc') assert math.isnan(result) - def test_negative_number(self): + def test_negative_number(self) -> None: assert _js_parse_int('-42') == -42 - def test_positive_sign(self): + def test_positive_sign(self) -> None: assert _js_parse_int('+99') == 99 - def test_whitespace_stripped(self): + def test_whitespace_stripped(self) -> None: assert _js_parse_int(' 56 ') == 56 - def test_empty_string_returns_nan(self): + def test_empty_string_returns_nan(self) -> None: result = _js_parse_int('') assert math.isnan(result) - def test_non_string_returns_nan(self): + def test_non_string_returns_nan(self) -> None: result = _js_parse_int(42) assert math.isnan(result) - def test_none_returns_nan(self): + def test_none_returns_nan(self) -> None: result = _js_parse_int(None) assert math.isnan(result) @@ -129,39 +131,39 @@ def test_none_returns_nan(self): class TestWrapperInfo: """Tests for WrapperInfo.get_effective_index.""" - def test_basic_offset(self): + def test_basic_offset(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=10, func_node={}) assert info.get_effective_index([5]) == 15 - def test_negative_offset(self): + def test_negative_offset(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=-3, func_node={}) assert info.get_effective_index([10]) == 7 - def test_zero_offset(self): + def test_zero_offset(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}) assert info.get_effective_index([7]) == 7 - def test_param_index_selects_correct_arg(self): + def test_param_index_selects_correct_arg(self) -> None: info = WrapperInfo('w', param_index=1, wrapper_offset=100, func_node={}) assert info.get_effective_index(['ignored', 5]) == 105 - def test_param_index_out_of_bounds_returns_none(self): + def test_param_index_out_of_bounds_returns_none(self) -> None: info = WrapperInfo('w', param_index=2, wrapper_offset=0, func_node={}) assert info.get_effective_index([1]) is None - def test_non_numeric_arg_returns_none(self): + def test_non_numeric_arg_returns_none(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}) assert info.get_effective_index(['not_a_number']) is None - def test_get_key_with_key_param(self): + def test_get_key_with_key_param(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}, key_param_index=1) assert info.get_key([10, 'secret']) == 'secret' - def test_get_key_without_key_param(self): + def test_get_key_without_key_param(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}) assert info.get_key([10]) is None - def test_get_key_index_out_of_bounds(self): + def test_get_key_index_out_of_bounds(self) -> None: info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}, key_param_index=5) assert info.get_key([10]) is None @@ -174,14 +176,14 @@ def test_get_key_index_out_of_bounds(self): class TestDirectArrays: """Tests for direct array declaration and replacement.""" - def test_basic_direct_array(self): + def test_basic_direct_array(self) -> None: js = 'var arr = ["hello", "world"]; x(arr[0]); y(arr[1]);' code, changed = roundtrip(js, StringRevealer) assert changed is True assert 'x("hello")' in code assert 'y("world")' in code - def test_direct_array_multiple_accesses(self): + def test_direct_array_multiple_accesses(self) -> None: js = 'var arr = ["a", "b", "c"]; f(arr[0]); g(arr[1]); h(arr[2]);' code, changed = roundtrip(js, StringRevealer) assert changed is True @@ -189,30 +191,30 @@ def test_direct_array_multiple_accesses(self): assert '"b"' in code assert '"c"' in code - def test_direct_array_out_of_bounds_no_replacement(self): + def test_direct_array_out_of_bounds_no_replacement(self) -> None: js = 'var arr = ["hello"]; x(arr[5]);' code, changed = roundtrip(js, StringRevealer) assert changed is False assert 'arr[5]' in code - def test_direct_array_non_numeric_index_no_replacement(self): + def test_direct_array_non_numeric_index_no_replacement(self) -> None: js = 'var arr = ["hello", "world"]; x(arr[y]);' code, changed = roundtrip(js, StringRevealer) assert changed is False assert 'arr[y]' in code - def test_direct_array_single_element(self): + def test_direct_array_single_element(self) -> None: js = 'var arr = ["only"]; x(arr[0]);' code, changed = roundtrip(js, StringRevealer) assert changed is True assert '"only"' in code - def test_direct_array_preserves_non_array_vars(self): + def test_direct_array_preserves_non_array_vars(self) -> None: js = 'var x = 42; console.log(x);' code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_mixed_element_array_not_replaced(self): + def test_mixed_element_array_not_replaced(self) -> None: js = 'var arr = ["hello", 42]; x(arr[0]);' code, changed = roundtrip(js, StringRevealer) assert changed is False @@ -227,12 +229,12 @@ def test_mixed_element_array_not_replaced(self): class TestNoStringArrays: """Tests for code with no string array patterns.""" - def test_empty_program_returns_false(self): + def test_empty_program_returns_false(self) -> None: js = '' _, changed = roundtrip(js, StringRevealer) assert changed is False - def test_non_string_array_returns_false(self): + def test_non_string_array_returns_false(self) -> None: js = 'var arr = [1, 2, 3]; x(arr[0]);' _, changed = roundtrip(js, StringRevealer) assert changed is False @@ -246,7 +248,7 @@ def test_non_string_array_returns_false(self): class TestObfuscatorIoShortArray: """Short string arrays (< 5 elements) should not trigger obfuscator.io strategy.""" - def test_short_array_function_decoded(self): + def test_short_array_function_decoded(self) -> None: # Arrays with >= 2 elements in obfuscator.io function pattern are decoded. js = """ function _0xArr() { @@ -274,28 +276,28 @@ def test_short_array_function_decoded(self): class TestApplyArith: """Tests for the _apply_arith helper.""" - def test_addition(self): + def test_addition(self) -> None: assert _apply_arith('+', 3, 4) == 7 - def test_subtraction(self): + def test_subtraction(self) -> None: assert _apply_arith('-', 10, 3) == 7 - def test_multiplication(self): + def test_multiplication(self) -> None: assert _apply_arith('*', 6, 7) == 42 - def test_division(self): + def test_division(self) -> None: assert _apply_arith('/', 20, 4) == 5.0 - def test_division_by_zero(self): + def test_division_by_zero(self) -> None: assert _apply_arith('/', 1, 0) is None - def test_modulo(self): + def test_modulo(self) -> None: assert _apply_arith('%', 10, 3) == 1 - def test_modulo_by_zero(self): + def test_modulo_by_zero(self) -> None: assert _apply_arith('%', 10, 0) is None - def test_unsupported_operator_returns_none(self): + def test_unsupported_operator_returns_none(self) -> None: assert _apply_arith('**', 2, 3) is None assert _apply_arith('<<', 2, 3) is None assert _apply_arith('>>', 8, 1) is None @@ -310,41 +312,41 @@ def test_unsupported_operator_returns_none(self): class TestCollectObjectLiterals: """Tests for the _collect_object_literals helper.""" - def test_numeric_properties(self): + def test_numeric_properties(self) -> None: ast = parse('var obj = {a: 0x1b1, b: 42};') result = _collect_object_literals(ast) assert ('obj', 'a') in result assert result[('obj', 'a')] == 0x1B1 assert result[('obj', 'b')] == 42 - def test_string_properties(self): + def test_string_properties(self) -> None: ast = parse('var obj = {a: "hello", b: "world"};') result = _collect_object_literals(ast) assert result[('obj', 'a')] == 'hello' assert result[('obj', 'b')] == 'world' - def test_mixed_properties(self): + def test_mixed_properties(self) -> None: ast = parse('var obj = {a: 0x1b1, b: "hello"};') result = _collect_object_literals(ast) assert result[('obj', 'a')] == 0x1B1 assert result[('obj', 'b')] == 'hello' - def test_string_key_properties(self): + def test_string_key_properties(self) -> None: ast = parse('var obj = {"myKey": 42};') result = _collect_object_literals(ast) assert result[('obj', 'myKey')] == 42 - def test_empty_object(self): + def test_empty_object(self) -> None: ast = parse('var obj = {};') result = _collect_object_literals(ast) assert len(result) == 0 - def test_non_object_init_ignored(self): + def test_non_object_init_ignored(self) -> None: ast = parse('var x = 42;') result = _collect_object_literals(ast) assert len(result) == 0 - def test_multiple_objects(self): + def test_multiple_objects(self) -> None: ast = parse('var a = {x: 1}; var b = {y: 2};') result = _collect_object_literals(ast) assert result[('a', 'x')] == 1 @@ -359,23 +361,23 @@ def test_multiple_objects(self): class TestResolveArgValue: """Tests for the _resolve_arg_value helper.""" - def test_numeric_literal(self): + def test_numeric_literal(self) -> None: arg = parse_expr('42') assert _resolve_arg_value(arg, {}) == 42 - def test_string_hex_literal(self): + def test_string_hex_literal(self) -> None: arg = parse_expr('"0x1a"') assert _resolve_arg_value(arg, {}) == 0x1A - def test_string_decimal_literal(self): + def test_string_decimal_literal(self) -> None: arg = parse_expr('"10"') assert _resolve_arg_value(arg, {}) == 10 - def test_string_non_numeric_returns_none(self): + def test_string_non_numeric_returns_none(self) -> None: arg = parse_expr('"hello"') assert _resolve_arg_value(arg, {}) is None - def test_member_expression_numeric(self): + def test_member_expression_numeric(self) -> None: obj_literals = {('obj', 'x'): 0x42} arg = { 'type': 'MemberExpression', @@ -385,7 +387,7 @@ def test_member_expression_numeric(self): } assert _resolve_arg_value(arg, obj_literals) == 0x42 - def test_member_expression_string_numeric(self): + def test_member_expression_string_numeric(self) -> None: obj_literals = {('obj', 'x'): '0x10'} arg = { 'type': 'MemberExpression', @@ -395,7 +397,7 @@ def test_member_expression_string_numeric(self): } assert _resolve_arg_value(arg, obj_literals) == 0x10 - def test_member_expression_string_non_numeric(self): + def test_member_expression_string_non_numeric(self) -> None: obj_literals = {('obj', 'x'): 'hello'} arg = { 'type': 'MemberExpression', @@ -405,7 +407,7 @@ def test_member_expression_string_non_numeric(self): } assert _resolve_arg_value(arg, obj_literals) is None - def test_member_expression_unknown_key(self): + def test_member_expression_unknown_key(self) -> None: obj_literals = {('obj', 'x'): 42} arg = { 'type': 'MemberExpression', @@ -415,7 +417,7 @@ def test_member_expression_unknown_key(self): } assert _resolve_arg_value(arg, obj_literals) is None - def test_computed_member_expression_not_resolved(self): + def test_computed_member_expression_not_resolved(self) -> None: obj_literals = {('obj', 'x'): 42} arg = { 'type': 'MemberExpression', @@ -425,7 +427,7 @@ def test_computed_member_expression_not_resolved(self): } assert _resolve_arg_value(arg, obj_literals) is None - def test_identifier_returns_none(self): + def test_identifier_returns_none(self) -> None: arg = parse_expr('x') assert _resolve_arg_value(arg, {}) is None @@ -438,11 +440,11 @@ def test_identifier_returns_none(self): class TestResolveStringArg: """Tests for the _resolve_string_arg helper.""" - def test_string_literal(self): + def test_string_literal(self) -> None: arg = parse_expr('"hello"') assert _resolve_string_arg(arg, {}) == 'hello' - def test_member_expression_string(self): + def test_member_expression_string(self) -> None: obj_literals = {('obj', 'key'): 'secret'} arg = { 'type': 'MemberExpression', @@ -452,7 +454,7 @@ def test_member_expression_string(self): } assert _resolve_string_arg(arg, obj_literals) == 'secret' - def test_member_expression_numeric_returns_none(self): + def test_member_expression_numeric_returns_none(self) -> None: obj_literals = {('obj', 'key'): 42} arg = { 'type': 'MemberExpression', @@ -462,15 +464,15 @@ def test_member_expression_numeric_returns_none(self): } assert _resolve_string_arg(arg, obj_literals) is None - def test_numeric_literal_returns_none(self): + def test_numeric_literal_returns_none(self) -> None: arg = parse_expr('42') assert _resolve_string_arg(arg, {}) is None - def test_identifier_returns_none(self): + def test_identifier_returns_none(self) -> None: arg = parse_expr('x') assert _resolve_string_arg(arg, {}) is None - def test_member_expression_unknown_returns_none(self): + def test_member_expression_unknown_returns_none(self) -> None: arg = { 'type': 'MemberExpression', 'computed': False, @@ -488,7 +490,7 @@ def test_member_expression_unknown_returns_none(self): class TestVarArrayPattern: """Tests for var-based string array with rotation and decoder (Strategy 2b).""" - def test_var_array_with_rotation_and_decoder(self): + def test_var_array_with_rotation_and_decoder(self) -> None: js = """ var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux']; (function(arr, count) { @@ -514,7 +516,7 @@ def test_var_array_with_rotation_and_decoder(self): # After rotation 2: ['foo', 'bar', 'baz', 'qux', 'hello', 'world'] assert '"foo"' in code - def test_var_array_without_rotation(self): + def test_var_array_without_rotation(self) -> None: js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; var _0xdec = function(a) { @@ -530,7 +532,7 @@ def test_var_array_without_rotation(self): assert '"hello"' in code assert '"world"' in code - def test_var_array_with_offset(self): + def test_var_array_with_offset(self) -> None: js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; var _0xdec = function(a) { @@ -546,7 +548,7 @@ def test_var_array_with_offset(self): assert '"hello"' in code assert '"world"' in code - def test_var_array_too_short_ignored(self): + def test_var_array_too_short_ignored(self) -> None: # Arrays with < 3 elements should not match _find_var_string_array js = """ var _0xarr = ['hello', 'world']; @@ -571,7 +573,7 @@ def test_var_array_too_short_ignored(self): class TestObfuscatorIoFullPattern: """Tests for the full obfuscator.io string array pattern.""" - def test_basic_obfuscator_io_pattern(self): + def test_basic_obfuscator_io_pattern(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -591,7 +593,7 @@ def test_basic_obfuscator_io_pattern(self): assert '"hello"' in code assert '"world"' in code - def test_obfuscator_io_with_offset(self): + def test_obfuscator_io_with_offset(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -611,7 +613,7 @@ def test_obfuscator_io_with_offset(self): assert '"hello"' in code assert '"world"' in code - def test_obfuscator_io_multiple_calls(self): + def test_obfuscator_io_multiple_calls(self) -> None: js = """ function _0xArr() { var a = ['alpha', 'beta', 'gamma', 'delta', 'epsilon']; @@ -631,7 +633,7 @@ def test_obfuscator_io_multiple_calls(self): assert '"alpha"' in code assert '"epsilon"' in code - def test_obfuscator_io_removes_infrastructure(self): + def test_obfuscator_io_removes_infrastructure(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -650,7 +652,7 @@ def test_obfuscator_io_removes_infrastructure(self): assert '_0xArr' not in code assert '_0xDec' not in code - def test_obfuscator_io_with_wrapper_function(self): + def test_obfuscator_io_with_wrapper_function(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -671,7 +673,7 @@ def test_obfuscator_io_with_wrapper_function(self): assert changed is True assert '"world"' in code - def test_obfuscator_io_wrapper_with_key_param(self): + def test_obfuscator_io_wrapper_with_key_param(self) -> None: # Wrapper that passes two args to decoder (index + key) js = """ function _0xArr() { @@ -695,7 +697,7 @@ def test_obfuscator_io_wrapper_with_key_param(self): assert '"hello"' in code assert '"world"' in code - def test_obfuscator_io_with_decoder_alias(self): + def test_obfuscator_io_with_decoder_alias(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -716,7 +718,7 @@ def test_obfuscator_io_with_decoder_alias(self): assert '"hello"' in code assert '"world"' in code - def test_obfuscator_io_with_transitive_alias(self): + def test_obfuscator_io_with_transitive_alias(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -745,7 +747,7 @@ def test_obfuscator_io_with_transitive_alias(self): class TestObfuscatorIoRotation: """Tests for the obfuscator.io rotation IIFE pattern.""" - def test_rotation_with_while_loop(self): + def test_rotation_with_while_loop(self) -> None: js = """ function _0xArr() { var a = ['100', 'hello', 'world', 'foo', 'bar', 'baz']; @@ -789,7 +791,7 @@ def test_rotation_with_while_loop(self): class TestWrapperAnalysis: """Tests for wrapper function analysis (_analyze_wrapper_expr).""" - def test_var_function_expression_wrapper(self): + def test_var_function_expression_wrapper(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -810,7 +812,7 @@ def test_var_function_expression_wrapper(self): assert changed is True assert '"foo"' in code - def test_arrow_function_wrapper(self): + def test_arrow_function_wrapper(self) -> None: # ArrowFunctionExpression with block body as wrapper js = """ function _0xArr() { @@ -841,7 +843,7 @@ def test_arrow_function_wrapper(self): class TestExtractWrapperOffset: """Tests for wrapper offset extraction patterns.""" - def test_wrapper_with_subtraction_offset(self): + def test_wrapper_with_subtraction_offset(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -862,7 +864,7 @@ def test_wrapper_with_subtraction_offset(self): assert changed is True assert '"hello"' in code - def test_wrapper_with_second_param_index(self): + def test_wrapper_with_second_param_index(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -892,7 +894,7 @@ def test_wrapper_with_second_param_index(self): class TestObjectLiteralResolution: """Tests for resolving member expressions via object literals.""" - def test_decoder_call_with_object_member_arg(self): + def test_decoder_call_with_object_member_arg(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -911,7 +913,7 @@ def test_decoder_call_with_object_member_arg(self): assert changed is True assert '"hello"' in code - def test_wrapper_call_with_object_member_arg(self): + def test_wrapper_call_with_object_member_arg(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -942,20 +944,20 @@ def test_wrapper_call_with_object_member_arg(self): class TestEvalNumericModulo: """Additional _eval_numeric tests for modulo operator.""" - def test_modulo(self): + def test_modulo(self) -> None: node = parse_expr('10 % 3') assert _eval_numeric(node) == 1 - def test_modulo_by_zero(self): + def test_modulo_by_zero(self) -> None: node = parse_expr('10 % 0') assert _eval_numeric(node) is None - def test_unary_unsupported_operator(self): + def test_unary_unsupported_operator(self) -> None: # The ~ operator is unsupported node = parse_expr('~5') assert _eval_numeric(node) is None - def test_deeply_nested_expression(self): + def test_deeply_nested_expression(self) -> None: node = parse_expr('(1 + 2) * (3 - 1) + 4 / 2') assert _eval_numeric(node) == 8.0 @@ -968,7 +970,7 @@ def test_deeply_nested_expression(self): class TestCollectRotationLocals: """Tests for _collect_rotation_locals static method.""" - def test_collects_object_from_iife(self): + def test_collects_object_from_iife(self) -> None: ast = parse( """ (function(arr, stop) { @@ -991,7 +993,7 @@ def test_collects_object_from_iife(self): assert result['J']['S'] == 0xA7 assert result['J']['D'] == 'M8Y&' - def test_empty_iife_returns_empty(self): + def test_empty_iife_returns_empty(self) -> None: ast = parse( """ (function() { @@ -1012,7 +1014,7 @@ def test_empty_iife_returns_empty(self): class TestExpressionFromTryBlock: """Tests for _expression_from_try_block static method.""" - def test_variable_declaration(self): + def test_variable_declaration(self) -> None: ast = parse('var x = 42;') stmt = ast['body'][0] result = StringRevealer._expression_from_try_block(stmt) @@ -1020,7 +1022,7 @@ def test_variable_declaration(self): assert result.get('type') == 'Literal' assert result.get('value') == 42 - def test_assignment_expression(self): + def test_assignment_expression(self) -> None: ast = parse('x = 42;') stmt = ast['body'][0] result = StringRevealer._expression_from_try_block(stmt) @@ -1028,13 +1030,13 @@ def test_assignment_expression(self): assert result.get('type') == 'Literal' assert result.get('value') == 42 - def test_non_matching_returns_none(self): + def test_non_matching_returns_none(self) -> None: ast = parse('if (true) {}') stmt = ast['body'][0] result = StringRevealer._expression_from_try_block(stmt) assert result is None - def test_expression_statement_non_assignment(self): + def test_expression_statement_non_assignment(self) -> None: ast = parse('foo();') stmt = ast['body'][0] result = StringRevealer._expression_from_try_block(stmt) @@ -1049,7 +1051,7 @@ def test_expression_statement_non_assignment(self): class TestDirectArrayEdgeCases: """Additional edge case tests for direct array access replacement.""" - def test_direct_array_in_function_scope(self): + def test_direct_array_in_function_scope(self) -> None: # Direct array strategy only processes the root scope bindings, # so arrays inside function scopes are not replaced. js = """ @@ -1061,7 +1063,7 @@ def test_direct_array_in_function_scope(self): code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_direct_array_not_used_as_member(self): + def test_direct_array_not_used_as_member(self) -> None: # Using arr as a standalone identifier (not arr[N]) should not trigger js = 'var arr = ["hello"]; f(arr);' code, changed = roundtrip(js, StringRevealer) @@ -1076,7 +1078,7 @@ def test_direct_array_not_used_as_member(self): class TestVarPatternWithWrappersAndAliases: """Tests for var-based array with wrappers and alias resolution.""" - def test_var_array_with_decoder_alias(self): + def test_var_array_with_decoder_alias(self) -> None: js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; var _0xdec = function(a) { @@ -1091,7 +1093,7 @@ def test_var_array_with_decoder_alias(self): assert changed is True assert '"hello"' in code - def test_var_array_with_wrapper(self): + def test_var_array_with_wrapper(self) -> None: js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; var _0xdec = function(a) { @@ -1117,7 +1119,7 @@ def test_var_array_with_wrapper(self): class TestHexStringResolution: """Tests for hex string resolution in decoder calls.""" - def test_hex_string_arg_to_decoder(self): + def test_hex_string_arg_to_decoder(self) -> None: js = """ function _0xArr() { var a = ['hello', 'world', 'foo', 'bar', 'baz']; @@ -1148,7 +1150,7 @@ def test_hex_string_arg_to_decoder(self): class TestRotationLogicFull: """Tests for the full rotation pipeline with parseInt-based expressions.""" - def test_rotation_with_direct_decoder_call_in_parseint(self): + def test_rotation_with_direct_decoder_call_in_parseint(self) -> None: """Rotation where parseInt calls the decoder directly (via alias in decoder_aliases).""" js = """ function _0xArr() { @@ -1180,7 +1182,7 @@ def test_rotation_with_direct_decoder_call_in_parseint(self): # '200' is already at position 0, so parseInt('200') == 200 == stop_value assert '"200"' in code - def test_rotation_with_binary_expression(self): + def test_rotation_with_binary_expression(self) -> None: """Rotation with binary expression: parseInt(dec(0)) + parseInt(dec(1)).""" js = """ function _0xArr() { @@ -1212,7 +1214,7 @@ def test_rotation_with_binary_expression(self): # parseInt('100') + parseInt('50') = 150 = stop, no rotation needed assert '"100"' in code - def test_rotation_with_subtraction_expression(self): + def test_rotation_with_subtraction_expression(self) -> None: """Rotation with subtraction: parseInt(dec(0)) - parseInt(dec(1)).""" js = """ function _0xArr() { @@ -1243,7 +1245,7 @@ def test_rotation_with_subtraction_expression(self): assert changed is True assert '"300"' in code - def test_rotation_with_multiplication_expression(self): + def test_rotation_with_multiplication_expression(self) -> None: """Rotation with multiply: parseInt(dec(0)) * parseInt(dec(1)).""" js = """ function _0xArr() { @@ -1273,7 +1275,7 @@ def test_rotation_with_multiplication_expression(self): code, changed = roundtrip(js, StringRevealer) assert changed is True - def test_rotation_with_wrapper_in_parseint(self): + def test_rotation_with_wrapper_in_parseint(self) -> None: """Rotation where parseInt calls a wrapper function.""" js = """ function _0xArr() { @@ -1307,7 +1309,7 @@ def test_rotation_with_wrapper_in_parseint(self): assert changed is True assert '"500"' in code - def test_rotation_needs_one_shift(self): + def test_rotation_needs_one_shift(self) -> None: """Rotation that needs exactly one shift before parseInt matches. Uses a wrapper in the rotation expression so _parse_parseInt_call can match. @@ -1346,7 +1348,7 @@ def test_rotation_needs_one_shift(self): # dec(0) returns '42' assert '"42"' in code - def test_rotation_with_negate_expression(self): + def test_rotation_with_negate_expression(self) -> None: """Rotation with negation: -parseInt(dec(0)) + literal.""" js = """ function _0xArr() { @@ -1378,7 +1380,7 @@ def test_rotation_with_negate_expression(self): # -parseInt('-100') + 200 = -(-100) + 200 = 100 + 200 = 300 = stop_value assert '"-100"' in code - def test_rotation_with_literal_node_in_try(self): + def test_rotation_with_literal_node_in_try(self) -> None: """Rotation expression that is just a literal value (no parseInt call).""" js = """ function _0xArr() { @@ -1415,7 +1417,7 @@ def test_rotation_with_literal_node_in_try(self): class TestRotationArgResolution: """Tests for _resolve_rotation_arg with various argument types.""" - def test_rotation_with_member_expression_arg(self): + def test_rotation_with_member_expression_arg(self) -> None: """Rotation IIFE that has local objects referenced in parseInt args.""" js = """ function _0xArr() { @@ -1447,7 +1449,7 @@ def test_rotation_with_member_expression_arg(self): assert changed is True assert '"300"' in code - def test_rotation_with_string_hex_arg(self): + def test_rotation_with_string_hex_arg(self) -> None: """Rotation with hex string literal as argument to decoder.""" js = """ function _0xArr() { @@ -1487,7 +1489,7 @@ def test_rotation_with_string_hex_arg(self): class TestSequenceExpressionRotation: """Tests for rotation inside a SequenceExpression.""" - def test_rotation_in_sequence_expression_obfuscatorio(self): + def test_rotation_in_sequence_expression_obfuscatorio(self) -> None: """Rotation IIFE as part of a SequenceExpression (obfuscator.io pattern).""" js = """ function _0xArr() { @@ -1518,7 +1520,7 @@ def test_rotation_in_sequence_expression_obfuscatorio(self): assert changed is True assert '"777"' in code - def test_var_rotation_in_sequence_expression(self): + def test_var_rotation_in_sequence_expression(self) -> None: """Var-based rotation IIFE inside a SequenceExpression.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux']; @@ -1541,7 +1543,7 @@ def test_var_rotation_in_sequence_expression(self): class TestWrapperAnalysisEdgeCases: """Tests for _analyze_wrapper_expr edge cases.""" - def test_wrapper_with_non_block_body_ignored(self): + def test_wrapper_with_non_block_body_ignored(self) -> None: """Function expression with expression body (not BlockStatement) is not a wrapper.""" js = """ function _0xArr() { @@ -1560,7 +1562,7 @@ def test_wrapper_with_non_block_body_ignored(self): assert changed is True assert '"hello"' in code - def test_wrapper_with_multiple_statements_not_wrapper(self): + def test_wrapper_with_multiple_statements_not_wrapper(self) -> None: """Function with more than one statement is not recognized as a wrapper.""" js = """ function _0xArr() { @@ -1586,7 +1588,7 @@ def test_wrapper_with_multiple_statements_not_wrapper(self): # _0xNotWrap(1) should NOT be replaced since it's not a valid wrapper assert '_0xNotWrap' in code - def test_wrapper_non_return_statement_not_wrapper(self): + def test_wrapper_non_return_statement_not_wrapper(self) -> None: """Function with single non-return statement is not a wrapper.""" js = """ function _0xArr() { @@ -1610,7 +1612,7 @@ def test_wrapper_non_return_statement_not_wrapper(self): assert '"hello"' in code assert '_0xNotWrap' in code - def test_wrapper_return_not_call_not_wrapper(self): + def test_wrapper_return_not_call_not_wrapper(self) -> None: """Wrapper that returns a non-call expression is not a wrapper.""" js = """ function _0xArr() { @@ -1634,7 +1636,7 @@ def test_wrapper_return_not_call_not_wrapper(self): assert '"hello"' in code assert '_0xNotWrap' in code - def test_wrapper_calls_wrong_decoder_not_wrapper(self): + def test_wrapper_calls_wrong_decoder_not_wrapper(self) -> None: """Wrapper that calls a different function is not recognized as decoder wrapper.""" js = """ function _0xArr() { @@ -1659,7 +1661,7 @@ def test_wrapper_calls_wrong_decoder_not_wrapper(self): assert '"hello"' in code assert '_0xNotWrap' in code - def test_wrapper_no_call_args_not_wrapper(self): + def test_wrapper_no_call_args_not_wrapper(self) -> None: """Wrapper with no arguments to decoder call is not a wrapper.""" js = """ function _0xArr() { @@ -1682,7 +1684,7 @@ def test_wrapper_no_call_args_not_wrapper(self): assert changed is True assert '"hello"' in code - def test_wrapper_with_non_identifier_first_arg(self): + def test_wrapper_with_non_identifier_first_arg(self) -> None: """Wrapper where first arg to decoder is not a param reference.""" js = """ function _0xArr() { @@ -1706,7 +1708,7 @@ def test_wrapper_with_non_identifier_first_arg(self): assert '"hello"' in code assert '_0xNotWrap' in code - def test_extract_wrapper_offset_non_plus_minus_operator(self): + def test_extract_wrapper_offset_non_plus_minus_operator(self) -> None: """Wrapper arg expression with unsupported operator (e.g. *).""" js = """ function _0xArr() { @@ -1730,7 +1732,7 @@ def test_extract_wrapper_offset_non_plus_minus_operator(self): assert '"hello"' in code assert '_0xNotWrap' in code - def test_extract_wrapper_offset_non_numeric_right(self): + def test_extract_wrapper_offset_non_numeric_right(self) -> None: """Wrapper arg expression p + x where x is not numeric.""" js = """ function _0xArr() { @@ -1755,7 +1757,7 @@ def test_extract_wrapper_offset_non_numeric_right(self): # _0xNotWrap should remain since p + q doesn't have a numeric right side assert '_0xNotWrap' in code - def test_extract_wrapper_offset_left_not_param(self): + def test_extract_wrapper_offset_left_not_param(self) -> None: """Wrapper arg expression where left side of binary is not a param.""" js = """ function _0xArr() { @@ -1789,7 +1791,7 @@ def test_extract_wrapper_offset_left_not_param(self): class TestReplacementEdgeCases: """Edge cases in _replace_all_wrapper_calls and _replace_direct_decoder_calls.""" - def test_wrapper_call_with_insufficient_args(self): + def test_wrapper_call_with_insufficient_args(self) -> None: """Wrapper call with fewer args than expected param_index.""" js = """ function _0xArr() { @@ -1814,7 +1816,7 @@ def test_wrapper_call_with_insufficient_args(self): # _0xWrap() called with no args should remain unreplaced assert '_0xWrap()' in code - def test_decoder_call_with_no_args(self): + def test_decoder_call_with_no_args(self) -> None: """Direct decoder call with no arguments is not replaced.""" js = """ function _0xArr() { @@ -1834,7 +1836,7 @@ def test_decoder_call_with_no_args(self): assert changed is True assert '"hello"' in code - def test_decoder_call_with_unresolvable_arg(self): + def test_decoder_call_with_unresolvable_arg(self) -> None: """Decoder call with variable (not literal) argument is not replaced.""" js = """ function _0xArr() { @@ -1855,7 +1857,7 @@ def test_decoder_call_with_unresolvable_arg(self): assert changed is True assert '"hello"' in code - def test_decoder_call_with_out_of_bounds_index(self): + def test_decoder_call_with_out_of_bounds_index(self) -> None: """Decoder call with an index beyond the array doesn't crash.""" js = """ function _0xArr() { @@ -1877,7 +1879,7 @@ def test_decoder_call_with_out_of_bounds_index(self): # Out of bounds call should remain assert '999' in code - def test_decoder_call_with_string_key_second_arg(self): + def test_decoder_call_with_string_key_second_arg(self) -> None: """Decoder direct call with a second string argument (key).""" js = """ function _0xArr() { @@ -1896,7 +1898,7 @@ def test_decoder_call_with_string_key_second_arg(self): assert changed is True assert '"hello"' in code - def test_wrapper_call_with_object_member_key_arg(self): + def test_wrapper_call_with_object_member_key_arg(self) -> None: """Wrapper call where the key param is resolved via object literal.""" js = """ function _0xArr() { @@ -1928,7 +1930,7 @@ def test_wrapper_call_with_object_member_key_arg(self): class TestVarPatternEdgeCases: """Edge cases for _find_var_string_array, _find_simple_rotation, _find_var_decoder.""" - def test_var_array_not_in_first_three_statements(self): + def test_var_array_not_in_first_three_statements(self) -> None: """Var string array declared after the first 3 statements is not found.""" js = """ var a = 1; @@ -1942,7 +1944,7 @@ def test_var_array_not_in_first_three_statements(self): # Array is at index 3, beyond the first 3 statements checked assert changed is False - def test_var_array_with_non_string_elements(self): + def test_var_array_with_non_string_elements(self) -> None: """Var array with mixed types is not recognized as string array.""" js = """ var _0xarr = ['hello', 42, 'foo', 'bar']; @@ -1952,7 +1954,7 @@ def test_var_array_with_non_string_elements(self): code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_var_array_with_non_identifier_declaration(self): + def test_var_array_with_non_identifier_declaration(self) -> None: """Var declaration with destructuring pattern is not matched.""" js = """ var [a, b] = ['hello', 'world', 'foo', 'bar']; @@ -1962,7 +1964,7 @@ def test_var_array_with_non_identifier_declaration(self): code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_find_var_decoder_function_expression(self): + def test_find_var_decoder_function_expression(self) -> None: """Var decoder as function expression referencing array name.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -1977,7 +1979,7 @@ def test_find_var_decoder_function_expression(self): assert changed is True assert '"hello"' in code - def test_var_decoder_with_no_matching_array_ref(self): + def test_var_decoder_with_no_matching_array_ref(self) -> None: """Decoder function that doesn't reference the array name is not found.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -1991,7 +1993,7 @@ def test_var_decoder_with_no_matching_array_ref(self): code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_var_decoder_not_function_expression(self): + def test_var_decoder_not_function_expression(self) -> None: """Var declaration that is not a function expression is not decoder.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -2002,7 +2004,7 @@ def test_var_decoder_not_function_expression(self): # Direct array strategy should handle this assert isinstance(code, str) - def test_simple_rotation_with_for_statement(self): + def test_simple_rotation_with_for_statement(self) -> None: """Simple rotation pattern matching checks for push/shift in source.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux']; @@ -2013,7 +2015,7 @@ def test_simple_rotation_with_for_statement(self): code, changed = roundtrip(js, StringRevealer) assert changed is True - def test_simple_rotation_no_push_shift_not_rotation(self): + def test_simple_rotation_no_push_shift_not_rotation(self) -> None: """IIFE without push/shift in source is not recognized as rotation.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -2025,7 +2027,7 @@ def test_simple_rotation_no_push_shift_not_rotation(self): assert changed is True assert '"hello"' in code # No rotation applied - def test_simple_rotation_wrong_array_name(self): + def test_simple_rotation_wrong_array_name(self) -> None: """Rotation IIFE that references different array name is not matched.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -2037,7 +2039,7 @@ def test_simple_rotation_wrong_array_name(self): assert changed is True assert '"hello"' in code # No rotation since IIFE references _0xother - def test_simple_rotation_non_numeric_count(self): + def test_simple_rotation_non_numeric_count(self) -> None: """Rotation IIFE with non-numeric count is not matched.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -2058,13 +2060,13 @@ def test_simple_rotation_non_numeric_count(self): class TestDirectArrayAccessEdgeCases: """Edge cases for _try_replace_array_access and _process_direct_arrays_in_scope.""" - def test_direct_array_non_computed_member(self): + def test_direct_array_non_computed_member(self) -> None: """arr.length style access (non-computed) is not replaced.""" js = 'var arr = ["hello", "world"]; f(arr.length);' code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_direct_array_used_in_child_scope(self): + def test_direct_array_used_in_child_scope(self) -> None: """Direct array used inside a function (child scope).""" js = """ var arr = ["hello", "world"]; @@ -2077,7 +2079,7 @@ def test_direct_array_used_in_child_scope(self): assert changed is True assert '"hello"' in code - def test_direct_array_in_child_scope_no_binding(self): + def test_direct_array_in_child_scope_no_binding(self) -> None: """Child scope that doesn't reference the array leaves it alone.""" js = """ var arr = ["hello", "world"]; @@ -2089,7 +2091,7 @@ def test_direct_array_in_child_scope_no_binding(self): code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_replace_node_in_ast_index_path(self): + def test_replace_node_in_ast_index_path(self) -> None: """Verify replacement works when target is in an array (index != None).""" js = 'var arr = ["a", "b"]; f(arr[0], arr[1]);' code, changed = roundtrip(js, StringRevealer) @@ -2106,14 +2108,14 @@ def test_replace_node_in_ast_index_path(self): class TestFindArrayExpressionInStatement: """Tests for _find_array_expression_in_statement.""" - def test_array_in_variable_declaration(self): + def test_array_in_variable_declaration(self) -> None: ast = parse("var x = [1, 2, 3];") stmt = ast['body'][0] result = StringRevealer._find_array_expression_in_statement(stmt) assert result is not None assert result['type'] == 'ArrayExpression' - def test_array_in_assignment_expression(self): + def test_array_in_assignment_expression(self) -> None: """Array in ExpressionStatement with AssignmentExpression.""" ast = parse("x = [1, 2, 3];") stmt = ast['body'][0] @@ -2121,28 +2123,28 @@ def test_array_in_assignment_expression(self): assert result is not None assert result['type'] == 'ArrayExpression' - def test_assignment_non_array_rhs(self): + def test_assignment_non_array_rhs(self) -> None: """Assignment with non-array right side returns None.""" ast = parse("x = 42;") stmt = ast['body'][0] result = StringRevealer._find_array_expression_in_statement(stmt) assert result is None - def test_non_declaration_non_assignment(self): + def test_non_declaration_non_assignment(self) -> None: """Statement that is neither declaration nor assignment returns None.""" ast = parse("if (true) {}") stmt = ast['body'][0] result = StringRevealer._find_array_expression_in_statement(stmt) assert result is None - def test_variable_declaration_non_array_init(self): + def test_variable_declaration_non_array_init(self) -> None: """Variable declaration with non-array init returns None.""" ast = parse("var x = 42;") stmt = ast['body'][0] result = StringRevealer._find_array_expression_in_statement(stmt) assert result is None - def test_expression_statement_non_assignment(self): + def test_expression_statement_non_assignment(self) -> None: """ExpressionStatement that is not an assignment returns None.""" ast = parse("foo();") stmt = ast['body'][0] @@ -2158,7 +2160,7 @@ def test_expression_statement_non_assignment(self): class TestExtractArrayFromStatement: """Tests for _extract_array_from_statement ExpressionStatement path.""" - def test_array_from_assignment_expression_statement(self): + def test_array_from_assignment_expression_statement(self) -> None: """String array in an assignment expression (not var declaration).""" js = """ function _0xArr() { @@ -2190,11 +2192,11 @@ def test_array_from_assignment_expression_statement(self): class TestEvalNumericBinaryEdge: """Test _eval_numeric with binary expressions producing None children.""" - def test_binary_with_non_numeric_left(self): + def test_binary_with_non_numeric_left(self) -> None: node = parse_expr('"abc" + 1') assert _eval_numeric(node) is None - def test_binary_with_non_numeric_right(self): + def test_binary_with_non_numeric_right(self) -> None: node = parse_expr('1 + "abc"') assert _eval_numeric(node) is None @@ -2207,7 +2209,7 @@ def test_binary_with_non_numeric_right(self): class TestCollectObjectLiteralsEdgeCases: """Edge cases for _collect_object_literals.""" - def test_property_with_non_literal_key(self): + def test_property_with_non_literal_key(self) -> None: """Object with computed key -- esprima parses [x] as Identifier key with computed flag.""" # In esprima, {[x]: 42} still produces a Property with key as Identifier # but the property is marked computed. _collect_object_literals checks @@ -2218,14 +2220,14 @@ def test_property_with_non_literal_key(self): # Numeric key is not an identifier or string literal, so it should be skipped assert ('obj', 0) not in result - def test_property_with_no_key_or_value(self): + def test_property_with_no_key_or_value(self) -> None: """Shorthand property patterns.""" ast = parse('var obj = {a: 42, b: "hi"};') result = _collect_object_literals(ast) assert result[('obj', 'a')] == 42 assert result[('obj', 'b')] == 'hi' - def test_property_non_numeric_non_string_value(self): + def test_property_non_numeric_non_string_value(self) -> None: """Object property with non-literal value is ignored.""" ast = parse('var obj = {a: x};') result = _collect_object_literals(ast) @@ -2240,7 +2242,7 @@ def test_property_non_numeric_non_string_value(self): class TestDecoderTypeDetection: """Tests for base64/RC4 decoder type detection.""" - def test_base64_decoder_detected(self): + def test_base64_decoder_detected(self) -> None: """Decoder function containing base64 alphabet is detected as Base64.""" js = """ function _0xArr() { @@ -2268,7 +2270,7 @@ def test_base64_decoder_detected(self): class TestUpdateAstArray: """Tests for _update_ast_array via rotation that modifies the AST.""" - def test_rotation_updates_ast_array(self): + def test_rotation_updates_ast_array(self) -> None: """Rotation execution should update the AST array elements. Must use a wrapper in the rotation expression so _parse_parseInt_call matches. @@ -2316,7 +2318,7 @@ def test_rotation_updates_ast_array(self): class TestExtractRotationExpression: """Tests for _extract_rotation_expression with various loop types.""" - def test_rotation_with_for_loop(self): + def test_rotation_with_for_loop(self) -> None: """Rotation IIFE with for loop instead of while.""" js = """ function _0xArr() { @@ -2347,7 +2349,7 @@ def test_rotation_with_for_loop(self): assert changed is True assert '"500"' in code - def test_rotation_with_empty_func_body(self): + def test_rotation_with_empty_func_body(self) -> None: """Rotation IIFE with empty body produces no rotation expression.""" js = """ function _0xArr() { @@ -2368,7 +2370,7 @@ def test_rotation_with_empty_func_body(self): assert changed is True assert '"hello"' in code - def test_rotation_with_assignment_in_try(self): + def test_rotation_with_assignment_in_try(self) -> None: """Rotation where try block uses assignment expression instead of var.""" js = """ function _0xArr() { @@ -2409,7 +2411,7 @@ def test_rotation_with_assignment_in_try(self): class TestParseRotationOp: """Tests for _parse_rotation_op with various expression types.""" - def test_rotation_op_with_modulo(self): + def test_rotation_op_with_modulo(self) -> None: """Rotation expression with modulo operator.""" js = """ function _0xArr() { @@ -2440,7 +2442,7 @@ def test_rotation_op_with_modulo(self): assert changed is True # 10 % 3 = 1 = stop_value - def test_rotation_op_with_division(self): + def test_rotation_op_with_division(self) -> None: """Rotation expression with division operator.""" js = """ function _0xArr() { @@ -2470,7 +2472,7 @@ def test_rotation_op_with_division(self): code, changed = roundtrip(js, StringRevealer) assert changed is True - def test_rotation_non_parseint_call_ignored(self): + def test_rotation_non_parseint_call_ignored(self) -> None: """Rotation expression with non-parseInt call returns None from _parse_rotation_op.""" js = """ function _0xArr() { @@ -2511,7 +2513,7 @@ def test_rotation_non_parseint_call_ignored(self): class TestTryExecuteRotationCall: """Edge cases for _try_execute_rotation_call.""" - def test_rotation_callee_not_function_expression(self): + def test_rotation_callee_not_function_expression(self) -> None: """Rotation call whose callee is not a FunctionExpression is skipped.""" js = """ function _0xArr() { @@ -2531,7 +2533,7 @@ def test_rotation_callee_not_function_expression(self): assert changed is True assert '"hello"' in code - def test_rotation_wrong_arg_count(self): + def test_rotation_wrong_arg_count(self) -> None: """Rotation IIFE with != 2 arguments is skipped.""" js = """ function _0xArr() { @@ -2552,7 +2554,7 @@ def test_rotation_wrong_arg_count(self): assert changed is True assert '"hello"' in code - def test_rotation_first_arg_not_array_func(self): + def test_rotation_first_arg_not_array_func(self) -> None: """Rotation IIFE where first arg is not the array function name.""" js = """ function _0xArr() { @@ -2573,7 +2575,7 @@ def test_rotation_first_arg_not_array_func(self): assert changed is True assert '"hello"' in code - def test_rotation_non_numeric_stop_value(self): + def test_rotation_non_numeric_stop_value(self) -> None: """Rotation IIFE with non-numeric stop value is skipped.""" js = """ function _0xArr() { @@ -2610,7 +2612,7 @@ def test_rotation_non_numeric_stop_value(self): class TestExtractDecoderOffset: """Tests for _extract_decoder_offset edge cases.""" - def test_decoder_with_addition_offset(self): + def test_decoder_with_addition_offset(self) -> None: """Decoder with idx = idx + OFFSET.""" js = """ function _0xArr() { @@ -2630,7 +2632,7 @@ def test_decoder_with_addition_offset(self): # offset is +2, so _0xDec(0) -> arr[0+2] = 'foo' assert '"foo"' in code - def test_decoder_with_no_offset(self): + def test_decoder_with_no_offset(self) -> None: """Decoder without any param reassignment has offset 0.""" js = """ function _0xArr() { @@ -2657,7 +2659,7 @@ def test_decoder_with_no_offset(self): class TestResolveRotationArgEdgeCases: """Edge cases for _resolve_rotation_arg returning string values.""" - def test_rotation_arg_non_hex_string_literal(self): + def test_rotation_arg_non_hex_string_literal(self) -> None: """Non-hex, non-numeric string literal is returned as-is (for RC4 key).""" # We test this indirectly through the rotation with wrapper + key param js = """ @@ -2692,7 +2694,7 @@ def test_rotation_arg_non_hex_string_literal(self): assert changed is True assert '"500"' in code - def test_rotation_arg_member_expression_computed(self): + def test_rotation_arg_member_expression_computed(self) -> None: """MemberExpression with computed string key in rotation locals.""" js = """ function _0xArr() { @@ -2733,7 +2735,7 @@ def test_rotation_arg_member_expression_computed(self): class TestCollectRotationLocalsEdgeCases: """Edge cases for _collect_rotation_locals.""" - def test_rotation_locals_with_string_key(self): + def test_rotation_locals_with_string_key(self) -> None: """Object literal with string keys in rotation IIFE.""" ast = parse( """ @@ -2755,7 +2757,7 @@ def test_rotation_locals_with_string_key(self): assert result['J']['A'] == 0xB9 assert result['J']['B'] == 'key' - def test_rotation_locals_non_object_var_ignored(self): + def test_rotation_locals_non_object_var_ignored(self) -> None: """Non-object variable declarations in IIFE are ignored.""" ast = parse( """ @@ -2774,7 +2776,7 @@ def test_rotation_locals_non_object_var_ignored(self): result = StringRevealer._collect_rotation_locals(iife_func) assert result == {} - def test_rotation_locals_non_identifier_name_ignored(self): + def test_rotation_locals_non_identifier_name_ignored(self) -> None: """Var declaration with non-identifier pattern is ignored.""" ast = parse( """ @@ -2789,7 +2791,7 @@ def test_rotation_locals_non_identifier_name_ignored(self): assert 'J' in result assert result['J']['A'] == 1 - def test_rotation_locals_empty_object(self): + def test_rotation_locals_empty_object(self) -> None: """Empty object literal in IIFE is not added (no properties).""" ast = parse( """ @@ -2812,7 +2814,7 @@ def test_rotation_locals_empty_object(self): class TestRc4DecoderCreation: """Tests for RC4 decoder type being created via _create_base_decoder.""" - def test_rc4_decoder_detected_and_used(self): + def test_rc4_decoder_detected_and_used(self) -> None: """Decoder with base64 alphabet AND fromCharCode...^ pattern is detected as RC4.""" js = """ function _0xArr() { @@ -2847,14 +2849,14 @@ def _make_revealer(self, js): ast = parse(js) return StringRevealer(ast), ast - def test_parse_rotation_op_literal(self): + def test_parse_rotation_op_literal(self) -> None: """_parse_rotation_op handles a bare numeric literal.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('42') result = t._parse_rotation_op(node, {}, set()) assert result == {'op': 'literal', 'value': 42} - def test_parse_rotation_op_negate(self): + def test_parse_rotation_op_negate(self) -> None: """_parse_rotation_op handles unary negation.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('-42') @@ -2864,7 +2866,7 @@ def test_parse_rotation_op_negate(self): assert result['child']['op'] == 'literal' assert result['child']['value'] == 42 - def test_parse_rotation_op_binary(self): + def test_parse_rotation_op_binary(self) -> None: """_parse_rotation_op handles binary addition of literals.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('10 + 20') @@ -2873,7 +2875,7 @@ def test_parse_rotation_op_binary(self): assert result['op'] == 'binary' assert result['operator'] == '+' - def test_parse_rotation_op_call_with_wrapper(self): + def test_parse_rotation_op_call_with_wrapper(self) -> None: """_parse_rotation_op handles parseInt(wrapper(0)).""" t, ast = self._make_revealer('var x = 1;') wrapper = WrapperInfo('_0xWrap', param_index=0, wrapper_offset=0, func_node={}) @@ -2884,7 +2886,7 @@ def test_parse_rotation_op_call_with_wrapper(self): assert result['wrapper_name'] == '_0xWrap' assert result['args'] == [0] - def test_parse_rotation_op_call_with_decoder_alias(self): + def test_parse_rotation_op_call_with_decoder_alias(self) -> None: """_parse_rotation_op handles parseInt(alias(0)) where alias is in decoder_aliases.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('parseInt(_0xAlias(0))') @@ -2894,20 +2896,20 @@ def test_parse_rotation_op_call_with_decoder_alias(self): assert result['alias_name'] == '_0xAlias' assert result['args'] == [0] - def test_parse_rotation_op_non_dict_returns_none(self): + def test_parse_rotation_op_non_dict_returns_none(self) -> None: """_parse_rotation_op returns None for non-dict input.""" t, ast = self._make_revealer('var x = 1;') assert t._parse_rotation_op(None, {}, set()) is None assert t._parse_rotation_op('string', {}, set()) is None - def test_parse_rotation_op_unsupported_type_returns_none(self): + def test_parse_rotation_op_unsupported_type_returns_none(self) -> None: """_parse_rotation_op returns None for unsupported node types.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('x') # Identifier result = t._parse_rotation_op(node, {}, set()) assert result is None - def test_parse_rotation_op_negate_with_non_numeric_child(self): + def test_parse_rotation_op_negate_with_non_numeric_child(self) -> None: """_parse_rotation_op negate with non-parseable child returns None.""" t, ast = self._make_revealer('var x = 1;') # -x where x is an identifier, not resolvable @@ -2915,7 +2917,7 @@ def test_parse_rotation_op_negate_with_non_numeric_child(self): result = t._parse_rotation_op(node, {}, set()) assert result is None - def test_parse_rotation_op_binary_with_unparseable_child(self): + def test_parse_rotation_op_binary_with_unparseable_child(self) -> None: """_parse_rotation_op binary with unparseable left or right returns None.""" t, ast = self._make_revealer('var x = 1;') # x + 1 where x is identifier @@ -2923,42 +2925,42 @@ def test_parse_rotation_op_binary_with_unparseable_child(self): result = t._parse_rotation_op(node, {}, set()) assert result is None - def test_parse_parseInt_call_not_parseint(self): + def test_parse_parseInt_call_not_parseint(self) -> None: """_parse_parseInt_call returns None when callee is not parseInt.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('Math.floor(1)') result = t._parse_parseInt_call(node, {}, set()) assert result is None - def test_parse_parseInt_call_wrong_arg_count(self): + def test_parse_parseInt_call_wrong_arg_count(self) -> None: """_parse_parseInt_call returns None when parseInt has != 1 arg.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('parseInt(1, 10)') result = t._parse_parseInt_call(node, {}, set()) assert result is None - def test_parse_parseInt_call_inner_not_call(self): + def test_parse_parseInt_call_inner_not_call(self) -> None: """_parse_parseInt_call returns None when inner arg is not a call.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('parseInt(42)') result = t._parse_parseInt_call(node, {}, set()) assert result is None - def test_parse_parseInt_call_inner_callee_not_identifier(self): + def test_parse_parseInt_call_inner_callee_not_identifier(self) -> None: """_parse_parseInt_call returns None when inner callee is not identifier.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('parseInt(a.b(0))') result = t._parse_parseInt_call(node, {}, set()) assert result is None - def test_parse_parseInt_call_inner_unknown_function(self): + def test_parse_parseInt_call_inner_unknown_function(self) -> None: """_parse_parseInt_call returns None when inner function is not in wrappers/aliases.""" t, ast = self._make_revealer('var x = 1;') node = parse_expr('parseInt(unknownFunc(0))') result = t._parse_parseInt_call(node, {}, set()) assert result is None - def test_parse_parseInt_call_unresolvable_arg(self): + def test_parse_parseInt_call_unresolvable_arg(self) -> None: """_parse_parseInt_call returns None when inner arg can't be resolved.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} @@ -2967,35 +2969,35 @@ def test_parse_parseInt_call_unresolvable_arg(self): result = t._parse_parseInt_call(node, {'_0xW': wrapper}, set()) assert result is None - def test_resolve_rotation_arg_numeric(self): + def test_resolve_rotation_arg_numeric(self) -> None: """_resolve_rotation_arg resolves numeric literal.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} node = parse_expr('42') assert t._resolve_rotation_arg(node) == 42 - def test_resolve_rotation_arg_string_hex(self): + def test_resolve_rotation_arg_string_hex(self) -> None: """_resolve_rotation_arg resolves hex string literal.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} node = parse_expr('"0x1b"') assert t._resolve_rotation_arg(node) == 0x1B - def test_resolve_rotation_arg_string_decimal(self): + def test_resolve_rotation_arg_string_decimal(self) -> None: """_resolve_rotation_arg resolves decimal string literal.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} node = parse_expr('"42"') assert t._resolve_rotation_arg(node) == 42 - def test_resolve_rotation_arg_string_non_numeric(self): + def test_resolve_rotation_arg_string_non_numeric(self) -> None: """_resolve_rotation_arg resolves non-numeric string as-is.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} node = parse_expr('"myKey"') assert t._resolve_rotation_arg(node) == 'myKey' - def test_resolve_rotation_arg_member_identifier_prop(self): + def test_resolve_rotation_arg_member_identifier_prop(self) -> None: """_resolve_rotation_arg resolves J.A from rotation locals.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {'J': {'A': 42}} @@ -3007,7 +3009,7 @@ def test_resolve_rotation_arg_member_identifier_prop(self): } assert t._resolve_rotation_arg(node) == 42 - def test_resolve_rotation_arg_member_string_prop(self): + def test_resolve_rotation_arg_member_string_prop(self) -> None: """_resolve_rotation_arg resolves J['A'] from rotation locals.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {'J': {'A': 99}} @@ -3019,7 +3021,7 @@ def test_resolve_rotation_arg_member_string_prop(self): } assert t._resolve_rotation_arg(node) == 99 - def test_resolve_rotation_arg_member_unknown_object(self): + def test_resolve_rotation_arg_member_unknown_object(self) -> None: """_resolve_rotation_arg returns None for unknown object in member.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} @@ -3031,27 +3033,27 @@ def test_resolve_rotation_arg_member_unknown_object(self): } assert t._resolve_rotation_arg(node) is None - def test_resolve_rotation_arg_identifier_returns_none(self): + def test_resolve_rotation_arg_identifier_returns_none(self) -> None: """_resolve_rotation_arg returns None for bare identifier.""" t, ast = self._make_revealer('var x = 1;') t._rotation_locals = {} node = parse_expr('x') assert t._resolve_rotation_arg(node) is None - def test_apply_rotation_op_literal(self): + def test_apply_rotation_op_literal(self) -> None: """_apply_rotation_op evaluates a literal node.""" t, ast = self._make_revealer('var x = 1;') result = t._apply_rotation_op({'op': 'literal', 'value': 42}, {}, None) assert result == 42 - def test_apply_rotation_op_negate(self): + def test_apply_rotation_op_negate(self) -> None: """_apply_rotation_op evaluates a negate node.""" t, ast = self._make_revealer('var x = 1;') op = {'op': 'negate', 'child': {'op': 'literal', 'value': 10}} result = t._apply_rotation_op(op, {}, None) assert result == -10 - def test_apply_rotation_op_binary(self): + def test_apply_rotation_op_binary(self) -> None: """_apply_rotation_op evaluates a binary node.""" t, ast = self._make_revealer('var x = 1;') op = { @@ -3063,9 +3065,8 @@ def test_apply_rotation_op_binary(self): result = t._apply_rotation_op(op, {}, None) assert result == 30 - def test_apply_rotation_op_call_wrapper(self): + def test_apply_rotation_op_call_wrapper(self) -> None: """_apply_rotation_op evaluates a wrapper call op.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = self._make_revealer('var x = 1;') decoder = BasicStringDecoder(['100', 'hello', 'world'], 0) @@ -3074,9 +3075,8 @@ def test_apply_rotation_op_call_wrapper(self): result = t._apply_rotation_op(op, {'w': wrapper}, decoder) assert result == 100 - def test_apply_rotation_op_direct_decoder_call(self): + def test_apply_rotation_op_direct_decoder_call(self) -> None: """_apply_rotation_op evaluates a direct_decoder_call op.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = self._make_revealer('var x = 1;') decoder = BasicStringDecoder(['200', 'hello'], 0) @@ -3085,9 +3085,8 @@ def test_apply_rotation_op_direct_decoder_call(self): result = t._apply_rotation_op(op, {}, decoder, alias_decoder_map=alias_map) assert result == 200 - def test_apply_rotation_op_direct_decoder_call_with_key(self): + def test_apply_rotation_op_direct_decoder_call_with_key(self) -> None: """_apply_rotation_op direct_decoder_call with key arg.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = self._make_revealer('var x = 1;') decoder = BasicStringDecoder(['300', 'hello'], 0) @@ -3096,13 +3095,13 @@ def test_apply_rotation_op_direct_decoder_call_with_key(self): # BasicStringDecoder ignores the key, just returns by index assert result == 300 - def test_apply_rotation_op_unknown_op_raises(self): + def test_apply_rotation_op_unknown_op_raises(self) -> None: """_apply_rotation_op raises for unknown op.""" t, ast = self._make_revealer('var x = 1;') with pytest.raises(ValueError, match='Unknown op'): t._apply_rotation_op({'op': 'unknown_op'}, {}, None) - def test_apply_rotation_op_call_invalid_wrapper_args_raises(self): + def test_apply_rotation_op_call_invalid_wrapper_args_raises(self) -> None: """_apply_rotation_op raises when wrapper args are invalid.""" t, ast = self._make_revealer('var x = 1;') wrapper = WrapperInfo('w', param_index=5, wrapper_offset=0, func_node={}) @@ -3110,16 +3109,15 @@ def test_apply_rotation_op_call_invalid_wrapper_args_raises(self): with pytest.raises(ValueError, match='Invalid wrapper args'): t._apply_rotation_op(op, {'w': wrapper}, None) - def test_apply_rotation_op_direct_decoder_no_args_raises(self): + def test_apply_rotation_op_direct_decoder_no_args_raises(self) -> None: """_apply_rotation_op raises when direct_decoder_call has no args.""" t, ast = self._make_revealer('var x = 1;') op = {'op': 'direct_decoder_call', 'alias_name': 'x', 'args': []} with pytest.raises(ValueError, match='No args'): t._apply_rotation_op(op, {}, None) - def test_execute_rotation_basic(self): + def test_execute_rotation_basic(self) -> None: """_execute_rotation rotates array until expression matches stop.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = self._make_revealer('var x = 1;') string_array = ['hello', '42', 'world', 'foo', 'bar'] @@ -3131,9 +3129,8 @@ def test_execute_rotation_basic(self): assert result is True assert string_array[0] == '42' - def test_execute_rotation_clears_decoder_cache(self): + def test_execute_rotation_clears_decoder_cache(self) -> None: """_execute_rotation clears decoder caches on each rotation.""" - from pyjsclear.utils.string_decoders import Base64StringDecoder t, ast = self._make_revealer('var x = 1;') # Use Base64StringDecoder which has a _cache attribute @@ -3147,7 +3144,6 @@ def test_execute_rotation_clears_decoder_cache(self): # BasicStringDecoder returns string directly, Base64StringDecoder does base64_transform # But since Base64 won't give us clean ints, use BasicStringDecoder for the actual test # and just verify that _cache.clear() is called on Base64StringDecoder - from pyjsclear.utils.string_decoders import BasicStringDecoder primary = BasicStringDecoder(string_array, 0) op = {'op': 'call', 'wrapper_name': 'w', 'args': [0]} @@ -3164,11 +3160,8 @@ def test_execute_rotation_clears_decoder_cache(self): # Verify the Base64 decoder's cache was cleared during rotation assert len(decoder._cache) == 0 - def test_execute_rotation_with_alias_decoder_map(self): + def test_execute_rotation_with_alias_decoder_map(self) -> None: """_execute_rotation uses alias_decoder_map for clearing caches.""" - from pyjsclear.utils.string_decoders import Base64StringDecoder - from pyjsclear.utils.string_decoders import BasicStringDecoder - t, ast = self._make_revealer('var x = 1;') string_array = ['hello', '42', 'world', 'foo', 'bar'] primary = BasicStringDecoder(string_array, 0) @@ -3192,7 +3185,7 @@ def test_execute_rotation_with_alias_decoder_map(self): class TestSequenceExpressionRotationRemoval: """Tests for rotation inside SequenceExpression being properly removed.""" - def test_sequence_rotation_removal_with_wrapper(self): + def test_sequence_rotation_removal_with_wrapper(self) -> None: """Rotation in SequenceExpression is removed while keeping other expressions. This triggers lines 287-300 by having the rotation succeed inside a sequence. @@ -3241,7 +3234,7 @@ def test_sequence_rotation_removal_with_wrapper(self): class TestRotationSequenceExpression: """Test that rotation can be found inside a SequenceExpression.""" - def test_rotation_in_sequence_with_decoder_alias(self): + def test_rotation_in_sequence_with_decoder_alias(self) -> None: """Rotation IIFE inside SequenceExpression using decoder alias.""" js = """ function _0xArr() { @@ -3285,7 +3278,7 @@ def test_rotation_in_sequence_with_decoder_alias(self): class TestExtractDecoderOffsetDirect: """Direct tests for _extract_decoder_offset edge cases.""" - def test_offset_with_non_identifier_left(self): + def test_offset_with_non_identifier_left(self) -> None: """Assignment where left is not an identifier is ignored.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') # Parse a function with arr[0] = arr[0] - 5 (member expression on left) @@ -3301,7 +3294,7 @@ def test_offset_with_non_identifier_left(self): offset = t._extract_decoder_offset(func_node) assert offset == 0 # Default when no matching pattern found - def test_offset_non_binary_right_side(self): + def test_offset_non_binary_right_side(self) -> None: """Assignment where right side is not BinaryExpression is ignored.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') func_ast = parse( @@ -3316,7 +3309,7 @@ def test_offset_non_binary_right_side(self): offset = t._extract_decoder_offset(func_node) assert offset == 0 - def test_offset_binary_left_not_matching_param(self): + def test_offset_binary_left_not_matching_param(self) -> None: """Assignment where binary left doesn't match the assigned variable.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') func_ast = parse( @@ -3331,7 +3324,7 @@ def test_offset_binary_left_not_matching_param(self): offset = t._extract_decoder_offset(func_node) assert offset == 0 - def test_offset_unsupported_operator(self): + def test_offset_unsupported_operator(self) -> None: """Assignment with unsupported operator (*) is ignored.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') func_ast = parse( @@ -3355,28 +3348,28 @@ def test_offset_unsupported_operator(self): class TestStringArrayFromExpression: """Edge cases for _string_array_from_expression.""" - def test_array_with_non_string_element(self): + def test_array_with_non_string_element(self) -> None: """Array with a numeric element returns None.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') node = parse_expr('[1, 2, 3]') result = t._string_array_from_expression(node) assert result is None - def test_empty_array_returns_none(self): + def test_empty_array_returns_none(self) -> None: """Empty array returns None.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') node = parse_expr('[]') result = t._string_array_from_expression(node) assert result is None - def test_non_array_returns_none(self): + def test_non_array_returns_none(self) -> None: """Non-array node returns None.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') node = parse_expr('42') result = t._string_array_from_expression(node) assert result is None - def test_none_returns_none(self): + def test_none_returns_none(self) -> None: """None returns None.""" t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') assert t._string_array_from_expression(None) is None @@ -3390,7 +3383,7 @@ def test_none_returns_none(self): class TestFindStringArrayFunction: """Edge cases for _find_string_array_function.""" - def test_function_with_short_body(self): + def test_function_with_short_body(self) -> None: """Function with only 1 statement in body is skipped.""" t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3404,7 +3397,7 @@ def test_function_with_short_body(self): name, arr, idx = t._find_string_array_function(body) assert name is None - def test_function_without_name(self): + def test_function_without_name(self) -> None: """Function without a name is skipped.""" # FunctionDeclarations always have names in valid JS, but we test the guard js = """ @@ -3430,12 +3423,12 @@ def test_function_without_name(self): class TestEvalNumericUnaryArgNone: """Test _eval_numeric when unary argument evaluates to None.""" - def test_negate_identifier_returns_none(self): + def test_negate_identifier_returns_none(self) -> None: """Negating an identifier that can't be evaluated returns None.""" node = parse_expr('-x') assert _eval_numeric(node) is None - def test_positive_identifier_returns_none(self): + def test_positive_identifier_returns_none(self) -> None: """Positive sign on identifier returns None.""" node = parse_expr('+x') assert _eval_numeric(node) is None @@ -3449,7 +3442,7 @@ def test_positive_identifier_returns_none(self): class TestCollectObjectLiteralsPropertyType: """Test _collect_object_literals with non-Property type entries.""" - def test_spread_element_ignored(self): + def test_spread_element_ignored(self) -> None: """SpreadElement in object properties is skipped (type != 'Property').""" # We can't easily create this in valid JS that esprima parses, # but we can test with a normal object to verify Property type check works @@ -3458,7 +3451,7 @@ def test_spread_element_ignored(self): assert result[('obj', 'a')] == 1 assert result[('obj', 'b')] == 'two' - def test_property_with_no_value(self): + def test_property_with_no_value(self) -> None: """Property with missing key or value is skipped.""" # Manually construct an AST with a malformed property ast = { @@ -3494,7 +3487,7 @@ def test_property_with_no_value(self): class TestProcessDirectArraysInScope: """Test _process_direct_arrays_in_scope when binding is not found.""" - def test_array_in_nested_function_scopes(self): + def test_array_in_nested_function_scopes(self) -> None: """Array accessed in deeply nested function scopes.""" js = """ var arr = ["hello", "world"]; @@ -3509,7 +3502,7 @@ def test_array_in_nested_function_scopes(self): assert changed is True assert '"hello"' in code - def test_array_not_referenced_in_child_scope(self): + def test_array_not_referenced_in_child_scope(self) -> None: """Child scope that doesn't reference the array binding.""" js = """ var arr = ["hello", "world"]; @@ -3531,7 +3524,7 @@ def test_array_not_referenced_in_child_scope(self): class TestFindVarStringArrayEdge: """Edge cases for _find_var_string_array.""" - def test_non_array_init_skipped(self): + def test_non_array_init_skipped(self) -> None: """Var declaration with non-array init is skipped.""" js = """ var _0x = 42; @@ -3543,7 +3536,7 @@ def test_non_array_init_skipped(self): assert changed is True assert '"hello"' in code - def test_short_array_skipped(self): + def test_short_array_skipped(self) -> None: """Array with < 3 elements is skipped by _find_var_string_array.""" js = """ var _0xarr = ['a', 'b']; @@ -3562,7 +3555,7 @@ def test_short_array_skipped(self): class TestFindSimpleRotationEdge: """Edge cases for _find_simple_rotation.""" - def test_non_expression_statement_skipped(self): + def test_non_expression_statement_skipped(self) -> None: """Non-ExpressionStatement in body is skipped when looking for rotation.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -3574,7 +3567,7 @@ def test_non_expression_statement_skipped(self): assert changed is True assert '"hello"' in code # No rotation applied - def test_expression_statement_without_expression(self): + def test_expression_statement_without_expression(self) -> None: """Expression statement edge case.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -3595,7 +3588,7 @@ def test_expression_statement_without_expression(self): class TestFindVarDecoderEdge: """Edge cases for _find_var_decoder.""" - def test_var_declaration_non_function_init(self): + def test_var_declaration_non_function_init(self) -> None: """Var declaration where init is not FunctionExpression is skipped.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -3606,7 +3599,7 @@ def test_var_declaration_non_function_init(self): # Direct array strategy handles arr[0], but decoder pattern is not found assert isinstance(code, str) - def test_var_declaration_non_variable(self): + def test_var_declaration_non_variable(self) -> None: """Non-variable declaration statement is skipped when looking for decoder.""" js = """ var _0xarr = ['hello', 'world', 'foo', 'bar']; @@ -3627,20 +3620,20 @@ def test_var_declaration_non_variable(self): class TestTryReplaceArrayAccessEdge: """Edge cases for _try_replace_array_access.""" - def test_non_computed_member_not_replaced(self): + def test_non_computed_member_not_replaced(self) -> None: """arr.foo style access is not replaced.""" js = 'var arr = ["hello", "world"]; console.log(arr.length);' code, changed = roundtrip(js, StringRevealer) assert changed is False assert 'arr.length' in code - def test_member_property_not_numeric(self): + def test_member_property_not_numeric(self) -> None: """arr[x] where x is an identifier is not replaced.""" js = 'var arr = ["hello", "world"]; console.log(arr[x]);' code, changed = roundtrip(js, StringRevealer) assert changed is False - def test_ref_key_not_object(self): + def test_ref_key_not_object(self) -> None: """Reference where key is not 'object' is not replaced.""" # This happens when arr is used as a property value, not as the object of a member js = 'var arr = ["hello", "world"]; var x = {prop: arr};' @@ -3656,7 +3649,7 @@ def test_ref_key_not_object(self): class TestUpdateAstArrayEdge: """Edge case for _update_ast_array.""" - def test_update_ast_array_with_assignment_init(self): + def test_update_ast_array_with_assignment_init(self) -> None: """_update_ast_array when first statement is assignment (not var decl).""" t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3679,7 +3672,7 @@ def test_update_ast_array_with_assignment_init(self): assert elements[0]['value'] == 'world' assert elements[1]['value'] == 'hello' - def test_update_ast_array_empty_body(self): + def test_update_ast_array_empty_body(self) -> None: """_update_ast_array with empty function body does nothing.""" t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3701,27 +3694,24 @@ def test_update_ast_array_empty_body(self): class TestDecodeAndParseInt: """Tests for _decode_and_parse_int error paths.""" - def test_decode_and_parse_int_returns_nan(self): + def test_decode_and_parse_int_returns_nan(self) -> None: """_decode_and_parse_int raises when decoded string is not parseable as int.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') decoder = BasicStringDecoder(['hello'], 0) with pytest.raises(ValueError, match='NaN'): t._decode_and_parse_int(decoder, 0) - def test_decode_and_parse_int_none_returned(self): + def test_decode_and_parse_int_none_returned(self) -> None: """_decode_and_parse_int raises when decoder returns None (out of bounds).""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') decoder = BasicStringDecoder(['hello'], 0) with pytest.raises(ValueError, match='Decoder returned None'): t._decode_and_parse_int(decoder, 999) - def test_decode_and_parse_int_with_key(self): + def test_decode_and_parse_int_with_key(self) -> None: """_decode_and_parse_int passes key to decoder.get_string.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') decoder = BasicStringDecoder(['42'], 0) @@ -3733,9 +3723,8 @@ def test_decode_and_parse_int_with_key(self): class TestExecuteRotationEdge: """Edge cases for _execute_rotation.""" - def test_execute_rotation_returns_false_when_no_match(self): + def test_execute_rotation_returns_false_when_no_match(self) -> None: """_execute_rotation returns False when no match in 100001 iterations.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;') # Array where no rotation can produce parseInt matching stop_value 999999 @@ -3752,7 +3741,7 @@ def test_execute_rotation_returns_false_when_no_match(self): class TestWrapperReplacementEdge: """Edge cases for _replace_all_wrapper_calls.""" - def test_wrapper_call_unresolvable_index_value(self): + def test_wrapper_call_unresolvable_index_value(self) -> None: """Wrapper call where the index param can't be resolved.""" js = """ function _0xArr() { @@ -3777,7 +3766,7 @@ def test_wrapper_call_unresolvable_index_value(self): # _0xWrap(someVar) should not be replaced assert '_0xWrap(someVar)' in code - def test_wrapper_key_param_resolution(self): + def test_wrapper_key_param_resolution(self) -> None: """Wrapper with key_param_index resolves key from call args.""" js = """ function _0xArr() { @@ -3801,7 +3790,7 @@ def test_wrapper_key_param_resolution(self): assert '"hello"' in code assert '"world"' in code - def test_decoder_returns_non_string(self): + def test_decoder_returns_non_string(self) -> None: """When decoder returns None (out of bounds), the call is not replaced.""" js = """ function _0xArr() { @@ -3827,7 +3816,7 @@ def test_decoder_returns_non_string(self): class TestAnalyzeWrapperExprEdge: """Edge cases for _analyze_wrapper_expr.""" - def test_wrapper_with_key_param_from_second_arg(self): + def test_wrapper_with_key_param_from_second_arg(self) -> None: """Wrapper passing second param as key to decoder.""" js = """ function _0xArr() { @@ -3853,9 +3842,8 @@ def test_wrapper_with_key_param_from_second_arg(self): class TestFindAndExecuteRotationEdge: """Edge cases for _find_and_execute_rotation.""" - def test_rotation_not_found_returns_none(self): + def test_rotation_not_found_returns_none(self) -> None: """When no rotation IIFE exists, returns None.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3878,9 +3866,8 @@ def test_rotation_not_found_returns_none(self): result = t._find_and_execute_rotation(body, '_0xArr', ['hello'], decoder, {}, set()) assert result is None - def test_rotation_found_in_sequence_expression(self): + def test_rotation_found_in_sequence_expression(self) -> None: """Rotation IIFE inside a SequenceExpression is found and executed.""" - from pyjsclear.utils.string_decoders import BasicStringDecoder js = """ function _0xArr() { @@ -3924,7 +3911,7 @@ def test_rotation_found_in_sequence_expression(self): class TestExtractRotationExpressionEdge: """Edge cases for _extract_rotation_expression.""" - def test_no_loop_in_iife(self): + def test_no_loop_in_iife(self) -> None: """IIFE without a while/for loop returns None.""" t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3940,7 +3927,7 @@ def test_no_loop_in_iife(self): result = t._extract_rotation_expression(iife_func) assert result is None - def test_loop_without_try_statement(self): + def test_loop_without_try_statement(self) -> None: """Loop without a TryStatement returns None.""" t, ast = TestRotationInternalsDirect._make_revealer( None, @@ -3958,7 +3945,7 @@ def test_loop_without_try_statement(self): result = t._extract_rotation_expression(iife_func) assert result is None - def test_try_with_empty_block(self): + def test_try_with_empty_block(self) -> None: """Try block with empty body returns None.""" t, ast = TestRotationInternalsDirect._make_revealer( None, diff --git a/tests/unit/transforms/unreachable_code_test.py b/tests/unit/transforms/unreachable_code_test.py index 0d99459..042248a 100644 --- a/tests/unit/transforms/unreachable_code_test.py +++ b/tests/unit/transforms/unreachable_code_test.py @@ -7,41 +7,41 @@ class TestUnreachableCodeRemover: """Tests for removing unreachable statements after terminators.""" - def test_removes_after_return(self): + def test_removes_after_return(self) -> None: code = 'function f() { return 1; console.log("dead"); }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True assert 'dead' not in result assert 'return 1' in result - def test_removes_after_throw(self): + def test_removes_after_throw(self) -> None: code = 'function f() { throw new Error(); var x = 1; }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True assert 'var x' not in result - def test_removes_after_break(self): + def test_removes_after_break(self) -> None: code = 'for(;;) { break; console.log("dead"); }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True assert 'dead' not in result - def test_removes_after_continue(self): + def test_removes_after_continue(self) -> None: code = 'for(;;) { continue; console.log("dead"); }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True assert 'dead' not in result - def test_preserves_reachable_code(self): + def test_preserves_reachable_code(self) -> None: code = 'function f() { var x = 1; return x; }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is False - def test_no_terminator_returns_false(self): + def test_no_terminator_returns_false(self) -> None: result, changed = roundtrip('var x = 1; var y = 2;', UnreachableCodeRemover) assert changed is False - def test_removes_multiple_after_return(self): + def test_removes_multiple_after_return(self) -> None: code = 'function f() { return; var a = 1; var b = 2; var c = 3; }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True @@ -49,14 +49,14 @@ def test_removes_multiple_after_return(self): assert 'var b' not in result assert 'var c' not in result - def test_handles_nested_blocks(self): + def test_handles_nested_blocks(self) -> None: code = 'function f() { if (x) { return; var dead = 1; } var live = 2; }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is True assert 'dead' not in result assert 'live' in result - def test_return_at_end_no_change(self): + def test_return_at_end_no_change(self) -> None: code = 'function f() { var x = 1; return x; }' result, changed = roundtrip(code, UnreachableCodeRemover) assert changed is False diff --git a/tests/unit/transforms/unused_vars_test.py b/tests/unit/transforms/unused_vars_test.py index 89dcce5..f214cde 100644 --- a/tests/unit/transforms/unused_vars_test.py +++ b/tests/unit/transforms/unused_vars_test.py @@ -1,7 +1,6 @@ """Unit tests for UnusedVariableRemover transform.""" -import pytest - +from pyjsclear.parser import parse from pyjsclear.transforms.unused_vars import UnusedVariableRemover from tests.unit.conftest import normalize from tests.unit.conftest import roundtrip @@ -10,121 +9,121 @@ class TestUnusedVariableRemover: """Tests for removing unreferenced variables and functions.""" - def test_unreferenced_var_in_function_removed(self): + def test_unreferenced_var_in_function_removed(self) -> None: code = 'function f() { var x = 1; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'var x' not in result assert 'function f' in result - def test_referenced_var_in_function_kept(self): + def test_referenced_var_in_function_kept(self) -> None: code = 'function f() { var x = 1; return x; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x = 1' in result assert 'return x' in result - def test_global_0x_prefixed_var_removed(self): + def test_global_0x_prefixed_var_removed(self) -> None: code = 'var _0xabc = 1;' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert '_0xabc' not in result - def test_global_normal_var_kept(self): + def test_global_normal_var_kept(self) -> None: code = 'var foo = 1;' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var foo = 1' in result - def test_side_effect_call_expression_kept(self): + def test_side_effect_call_expression_kept(self) -> None: code = 'function f() { var x = foo(); }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x = foo()' in result - def test_side_effect_new_expression_kept(self): + def test_side_effect_new_expression_kept(self) -> None: code = 'function f() { var x = new Foo(); }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x = new Foo()' in result - def test_side_effect_assignment_expression_kept(self): + def test_side_effect_assignment_expression_kept(self) -> None: code = 'function f() { var x = (a = 1); }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x' in result - def test_side_effect_update_expression_kept(self): + def test_side_effect_update_expression_kept(self) -> None: code = 'function f() { var x = a++; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x' in result - def test_global_0x_function_declaration_removed(self): + def test_global_0x_function_declaration_removed(self) -> None: code = 'function _0xabc() {}' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert '_0xabc' not in result - def test_unreferenced_nested_function_removed(self): + def test_unreferenced_nested_function_removed(self) -> None: code = 'function f() { function g() {} }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'function g' not in result assert 'function f' in result - def test_param_kept_even_if_unreferenced(self): + def test_param_kept_even_if_unreferenced(self) -> None: code = 'function f(x) { return 1; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'function f(x)' in result - def test_multiple_declarators_remove_only_unreferenced(self): + def test_multiple_declarators_remove_only_unreferenced(self) -> None: code = 'function f() { var x = 1, y = 2; return x; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'x' in result assert 'y' not in result - def test_multiple_declarators_all_unreferenced(self): + def test_multiple_declarators_all_unreferenced(self) -> None: code = 'function f() { var x = 1, y = 2; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'var' not in result - def test_no_unused_returns_false(self): + def test_no_unused_returns_false(self) -> None: code = 'function f(x) { return x; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False - def test_var_with_no_init_removed(self): + def test_var_with_no_init_removed(self) -> None: code = 'function f() { var x; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'var x' not in result - def test_global_normal_function_kept(self): + def test_global_normal_function_kept(self) -> None: code = 'function foo() {}' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'function foo' in result - def test_rebuild_scope_flag(self): + def test_rebuild_scope_flag(self) -> None: assert UnusedVariableRemover.rebuild_scope is True - def test_nested_side_effect_in_binary_kept(self): + def test_nested_side_effect_in_binary_kept(self) -> None: code = 'function f() { var x = 1 + foo(); }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x' in result - def test_pure_init_object_removed(self): + def test_pure_init_object_removed(self) -> None: code = 'function f() { var x = {a: 1}; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True assert 'var x' not in result - def test_pure_init_array_removed(self): + def test_pure_init_array_removed(self) -> None: code = 'function f() { var x = [1, 2]; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True @@ -134,7 +133,7 @@ def test_pure_init_array_removed(self): class TestUnusedVariableRemoverEdgeCases: """Tests for uncovered edge cases in unused variable removal.""" - def test_side_effect_in_dict_child(self): + def test_side_effect_in_dict_child(self) -> None: """Line 111: _has_side_effects with dict child having side effect.""" code = 'function f() { var x = -foo(); }' result, changed = roundtrip(code, UnusedVariableRemover) @@ -142,7 +141,7 @@ def test_side_effect_in_dict_child(self): assert not changed assert 'var x' in result - def test_has_side_effects_none_child(self): + def test_has_side_effects_none_child(self) -> None: """Line 105: _has_side_effects with None child should continue.""" # A conditional expression has test, consequent, alternate — where alternate can be None-ish code = 'function f() { var x = true ? 1 : 2; }' @@ -151,7 +150,7 @@ def test_has_side_effects_none_child(self): assert changed is True assert 'var x' not in result - def test_has_side_effects_non_dict_node(self): + def test_has_side_effects_non_dict_node(self) -> None: """Line 95: _has_side_effects with non-dict node returns False.""" # This is tested internally when init is a literal (non-dict-like in some paths) code = 'function f() { var x = 42; }' @@ -159,7 +158,7 @@ def test_has_side_effects_non_dict_node(self): assert changed is True assert 'var x' not in result - def test_empty_declarations_in_batch_remove(self): + def test_empty_declarations_in_batch_remove(self) -> None: """Line 81: decls is empty/None — should return early.""" # Unusual case where VariableDeclaration has empty declarations # Just test that normal removal works with a straightforward case @@ -167,14 +166,14 @@ def test_empty_declarations_in_batch_remove(self): result, changed = roundtrip(code, UnusedVariableRemover) assert changed is True - def test_all_declarators_referenced_no_change(self): + def test_all_declarators_referenced_no_change(self) -> None: """Line 84: new_decls length equals decls length (no change).""" code = 'function f() { var x = 1, y = 2; return x + y; }' result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False assert 'var x' in result - def test_side_effect_in_array_element(self): + def test_side_effect_in_array_element(self) -> None: """Lines 107-108: Array child item with side effect detected.""" # TemplateLiteral expressions list can contain call expressions code = 'function f() { var x = [foo(), 1]; }' @@ -183,7 +182,7 @@ def test_side_effect_in_array_element(self): # But the init check is just on the top-level node type assert isinstance(changed, bool) - def test_conditional_with_side_effect(self): + def test_conditional_with_side_effect(self) -> None: """Side effect deep in conditional expression.""" code = 'function f() { var x = true ? foo() : 1; }' result, changed = roundtrip(code, UnusedVariableRemover) @@ -192,7 +191,7 @@ def test_conditional_with_side_effect(self): assert not changed assert 'var x' in result - def test_binding_node_not_dict(self): + def test_binding_node_not_dict(self) -> None: """Line 56: binding.node that is not a dict should be skipped.""" # This is a defensive check. We test it by directly invoking with a scope # that has a non-dict binding node. Since we can't easily construct that @@ -201,14 +200,14 @@ def test_binding_node_not_dict(self): result, changed = roundtrip(code, UnusedVariableRemover) assert changed is False - def test_has_side_effects_non_dict(self): + def test_has_side_effects_non_dict(self) -> None: """Line 95: _has_side_effects with non-dict returns False.""" remover = UnusedVariableRemover({'type': 'Program', 'body': []}) assert remover._has_side_effects(None) is False assert remover._has_side_effects(42) is False assert remover._has_side_effects('string') is False - def test_has_side_effects_none_child(self): + def test_has_side_effects_none_child(self) -> None: """Line 105: None child in side effect check should be skipped.""" remover = UnusedVariableRemover({'type': 'Program', 'body': []}) # A ConditionalExpression where alternate is None @@ -220,7 +219,7 @@ def test_has_side_effects_none_child(self): } assert remover._has_side_effects(node) is False - def test_has_side_effects_list_child_with_side_effect(self): + def test_has_side_effects_list_child_with_side_effect(self) -> None: """Lines 107-108: list child with side effect item returns True.""" remover = UnusedVariableRemover({'type': 'Program', 'body': []}) # SequenceExpression has 'expressions' which is a list @@ -233,7 +232,7 @@ def test_has_side_effects_list_child_with_side_effect(self): } assert remover._has_side_effects(node) is True - def test_has_side_effects_list_child_no_side_effect(self): + def test_has_side_effects_list_child_no_side_effect(self) -> None: """Lines 107-108: list child without side effect items returns False.""" remover = UnusedVariableRemover({'type': 'Program', 'body': []}) node = { @@ -245,7 +244,7 @@ def test_has_side_effects_list_child_no_side_effect(self): } assert remover._has_side_effects(node) is False - def test_template_literal_recurses(self): + def test_template_literal_recurses(self) -> None: """TemplateLiteral is not in _PURE_TYPES or _SIDE_EFFECT_TYPES, so it recurses.""" code = 'function f() { var x = `hello`; }' result, changed = roundtrip(code, UnusedVariableRemover) @@ -253,11 +252,8 @@ def test_template_literal_recurses(self): assert changed is True assert 'var x' not in result - def test_empty_decls_list(self): + def test_empty_decls_list(self) -> None: """Line 81: VariableDeclaration with empty declarations list.""" - from pyjsclear.parser import parse - from pyjsclear.traverser import traverse - ast = parse('function f() { var x = 1; }') # Manually clear declarations to trigger the empty check for node in ast['body']: diff --git a/tests/unit/transforms/variable_renamer_test.py b/tests/unit/transforms/variable_renamer_test.py index 7a45747..8e1e993 100644 --- a/tests/unit/transforms/variable_renamer_test.py +++ b/tests/unit/transforms/variable_renamer_test.py @@ -12,31 +12,31 @@ class TestBasicRenaming: """Tests for basic _0x identifier renaming.""" - def test_single_var_renamed(self): + def test_single_var_renamed(self) -> None: code = 'function f() { var _0x1234 = 1; return _0x1234; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0x1234' not in result - def test_multiple_vars_renamed(self): + def test_multiple_vars_renamed(self) -> None: code = 'function f() { var _0x1 = 1; var _0x2 = 2; return _0x1 + _0x2; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0x' not in result - def test_function_name_renamed(self): + def test_function_name_renamed(self) -> None: code = 'function _0xabc() { return 1; } _0xabc();' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0xabc' not in result - def test_param_renamed(self): + def test_param_renamed(self) -> None: code = 'function f(_0xdef) { return _0xdef; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0xdef' not in result - def test_const_let_renamed(self): + def test_const_let_renamed(self) -> None: code = 'function f() { const _0x1 = 1; let _0x2 = 2; return _0x1 + _0x2; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True @@ -46,25 +46,25 @@ def test_const_let_renamed(self): class TestHeuristicNaming: """Tests for heuristic-based name inference.""" - def test_require_fs_named(self): + def test_require_fs_named(self) -> None: code = 'function f() { const _0x1 = require("fs"); _0x1.readFileSync("a"); }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const fs = require("fs")' in result - def test_require_path_named(self): + def test_require_path_named(self) -> None: code = 'function f() { const _0x1 = require("path"); _0x1.join("a"); }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const path = require("path")' in result - def test_require_child_process_named(self): + def test_require_child_process_named(self) -> None: code = 'function f() { const _0x1 = require("child_process"); _0x1.spawn("a"); }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const cp = require("child_process")' in result - def test_require_dedupe(self): + def test_require_dedupe(self) -> None: """Multiple require("fs") in same scope get fs, fs2, fs3.""" code = ''' function f() { @@ -79,43 +79,43 @@ def test_require_dedupe(self): assert 'const fs =' in result assert 'const fs2 =' in result - def test_array_literal_named(self): + def test_array_literal_named(self) -> None: code = 'function f() { const _0x1 = []; _0x1.push(1); }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const arr = []' in result - def test_object_literal_named(self): + def test_object_literal_named(self) -> None: code = 'function f() { const _0x1 = {}; _0x1.foo = 1; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const obj = {}' in result or 'const obj =' in result - def test_buffer_from_named(self): + def test_buffer_from_named(self) -> None: code = 'function f() { const _0x1 = Buffer.from("abc"); return _0x1; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const buf = Buffer.from' in result - def test_json_parse_named(self): + def test_json_parse_named(self) -> None: code = 'function f(s) { const _0x1 = JSON.parse(s); return _0x1; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const data = JSON.parse' in result - def test_new_date_named(self): + def test_new_date_named(self) -> None: code = 'function f() { const _0x1 = new Date(); return _0x1; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'const date = new Date()' in result - def test_loop_counter_named_i(self): + def test_loop_counter_named_i(self) -> None: code = 'function f() { for (var _0x1 = 0; _0x1 < 10; _0x1++) { console.log(_0x1); } }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'var i = 0' in result or 'var i2 = 0' in result or 'let i = 0' in result - def test_usage_based_fs_naming(self): + def test_usage_based_fs_naming(self) -> None: """Variable used with .existsSync should be named fs even without require init.""" code = ''' function f(_0xabc) { @@ -127,7 +127,7 @@ def test_usage_based_fs_naming(self): assert changed is True assert 'function f(fs)' in result - def test_require_with_path_sanitized(self): + def test_require_with_path_sanitized(self) -> None: r"""require(".\lib\Foo.node") should not produce invalid identifiers.""" code = r'function f() { const _0x1 = require(".\\lib\\Foo.node"); return _0x1; }' result, changed = roundtrip(code, VariableRenamer) @@ -140,20 +140,20 @@ def test_require_with_path_sanitized(self): class TestPreservation: """Tests for names that should NOT be renamed.""" - def test_non_0x_preserved(self): + def test_non_0x_preserved(self) -> None: code = 'function f() { var foo = 1; return foo; }' result, changed = roundtrip(code, VariableRenamer) assert changed is False assert 'foo' in result - def test_mixed_names(self): + def test_mixed_names(self) -> None: code = 'function f() { var foo = 1; var _0x1234 = 2; return foo + _0x1234; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert 'foo' in result assert '_0x1234' not in result - def test_no_0x_returns_false(self): + def test_no_0x_returns_false(self) -> None: code = 'var x = 1;' result, changed = roundtrip(code, VariableRenamer) assert changed is False @@ -162,14 +162,14 @@ def test_no_0x_returns_false(self): class TestNoConflict: """Tests that renaming doesn't create naming conflicts.""" - def test_no_conflict_with_existing_names(self): + def test_no_conflict_with_existing_names(self) -> None: code = 'function f() { var a = 1; var _0x1234 = 2; return a + _0x1234; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0x1234' not in result assert 'var a = 1' in result - def test_require_name_conflict_resolved(self): + def test_require_name_conflict_resolved(self) -> None: """If 'fs' already exists, require("fs") should get 'fs2'.""" code = ''' function f() { @@ -186,18 +186,18 @@ def test_require_name_conflict_resolved(self): class TestNameGenerator: """Tests for the name generator function.""" - def test_generates_single_letters(self): + def test_generates_single_letters(self) -> None: gen = _name_generator(set()) names = [next(gen) for _ in range(26)] assert names == list('abcdefghijklmnopqrstuvwxyz') - def test_generates_two_letter_after_single(self): + def test_generates_two_letter_after_single(self) -> None: gen = _name_generator(set()) names = [next(gen) for _ in range(28)] assert names[26] == 'aa' assert names[27] == 'ab' - def test_skips_reserved(self): + def test_skips_reserved(self) -> None: gen = _name_generator({'a', 'c'}) assert next(gen) == 'b' assert next(gen) == 'd' @@ -206,7 +206,7 @@ def test_skips_reserved(self): class TestNestedScopes: """Tests for renaming across nested scopes.""" - def test_nested_functions(self): + def test_nested_functions(self) -> None: code = ''' function _0xabc1() { var _0x1 = 1; @@ -221,31 +221,31 @@ def test_nested_functions(self): assert changed is True assert '_0x' not in result - def test_arrow_function_params(self): + def test_arrow_function_params(self) -> None: code = 'var f = (_0xaaa, _0xbbb) => _0xaaa + _0xbbb;' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0x' not in result - def test_class_expression_name_renamed(self): + def test_class_expression_name_renamed(self) -> None: code = 'var C = class _0xabc { static m() { return _0xabc; } };' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0xabc' not in result - def test_rest_param_renamed(self): + def test_rest_param_renamed(self) -> None: code = 'function f(_0xaaa, ..._0xbbb) { return _0xbbb; }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0x' not in result - def test_catch_param_renamed(self): + def test_catch_param_renamed(self) -> None: code = 'function f() { try { x(); } catch (_0xabc) { console.log(_0xabc); } }' result, changed = roundtrip(code, VariableRenamer) assert changed is True assert '_0xabc' not in result - def test_destructuring_duplicate_names_fixed(self): + def test_destructuring_duplicate_names_fixed(self) -> None: """Obfuscators can produce const [a, a, a] = x; — fix duplicates.""" code = 'function f(_0xabc) { const [_0xabc, _0xabc, _0xabc] = _0xabc; }' result, changed = roundtrip(code, VariableRenamer) diff --git a/tests/unit/transforms/xor_string_decode_test.py b/tests/unit/transforms/xor_string_decode_test.py index 14beed7..a870c53 100644 --- a/tests/unit/transforms/xor_string_decode_test.py +++ b/tests/unit/transforms/xor_string_decode_test.py @@ -1,7 +1,5 @@ """Tests for the XorStringDecoder transform.""" -import pytest - from pyjsclear.transforms.xor_string_decode import XorStringDecoder from pyjsclear.transforms.xor_string_decode import _extract_numeric_array from pyjsclear.transforms.xor_string_decode import _xor_decode @@ -11,7 +9,7 @@ class TestXorDecode: """Tests for the _xor_decode helper.""" - def test_basic_xor(self): + def test_basic_xor(self) -> None: """XOR decode with known prefix and data.""" # Prefix: [1, 2, 3, 4], data XOR'd with prefix prefix = [1, 2, 3, 4] @@ -20,11 +18,11 @@ def test_basic_xor(self): result = _xor_decode(encoded) assert result == 'ABCD' - def test_too_short_returns_none(self): + def test_too_short_returns_none(self) -> None: assert _xor_decode([1, 2, 3]) is None assert _xor_decode([1, 2, 3, 4]) is None - def test_invalid_utf8_returns_none(self): + def test_invalid_utf8_returns_none(self) -> None: result = _xor_decode([0, 0, 0, 0, 0xFF, 0xFE]) assert result is None @@ -32,7 +30,7 @@ def test_invalid_utf8_returns_none(self): class TestExtractNumericArray: """Tests for _extract_numeric_array helper.""" - def test_valid_array(self): + def test_valid_array(self) -> None: node = { 'type': 'ArrayExpression', 'elements': [ @@ -42,18 +40,18 @@ def test_valid_array(self): } assert _extract_numeric_array(node) == [10, 20] - def test_non_array_returns_none(self): + def test_non_array_returns_none(self) -> None: assert _extract_numeric_array({'type': 'Literal', 'value': 1}) is None assert _extract_numeric_array(None) is None - def test_out_of_range_returns_none(self): + def test_out_of_range_returns_none(self) -> None: node = { 'type': 'ArrayExpression', 'elements': [{'type': 'Literal', 'value': 300, 'raw': '300'}], } assert _extract_numeric_array(node) is None - def test_non_numeric_element_returns_none(self): + def test_non_numeric_element_returns_none(self) -> None: node = { 'type': 'ArrayExpression', 'elements': [{'type': 'Literal', 'value': 'str', 'raw': '"str"'}], @@ -64,11 +62,11 @@ def test_non_numeric_element_returns_none(self): class TestXorStringDecoderTransform: """Tests for the full XorStringDecoder transform.""" - def test_no_decoder_returns_false(self): + def test_no_decoder_returns_false(self) -> None: result, changed = roundtrip('var x = 1;', XorStringDecoder) assert changed is False - def test_decoder_detected_and_inlined(self): + def test_decoder_detected_and_inlined(self) -> None: """Integration test: XOR decoder function + call site should be resolved.""" # Build a XOR-encoded byte array for "test" prefix = [0x10, 0x20, 0x30, 0x40] @@ -91,7 +89,7 @@ def test_decoder_detected_and_inlined(self): assert changed is True assert 'obj.test' in result or 'obj["test"]' in result or '"test"' in result - def test_standalone_identifier_replaced(self): + def test_standalone_identifier_replaced(self) -> None: """Standalone use of decoded var (e.g., require(_0xVar)) gets string literal.""" prefix = [0x10, 0x20, 0x30, 0x40] message = b'fs' @@ -113,7 +111,7 @@ def test_standalone_identifier_replaced(self): assert changed is True assert 'require("fs")' in result - def test_dead_declaration_removed(self): + def test_dead_declaration_removed(self) -> None: """After inlining, the var _0xResult = decoder(...) should be removed.""" prefix = [0x10, 0x20, 0x30, 0x40] message = b'test' @@ -136,7 +134,7 @@ def test_dead_declaration_removed(self): # _0xResult should be removed (no remaining references) assert '_0xResult' not in result - def test_non_valid_identifier_uses_string_literal(self): + def test_non_valid_identifier_uses_string_literal(self) -> None: """Decoded string with non-identifier chars stays as string literal.""" prefix = [0x10, 0x20, 0x30, 0x40] message = b'a-b' diff --git a/tests/unit/utils/ast_helpers_test.py b/tests/unit/utils/ast_helpers_test.py index de169f8..7fa16e3 100644 --- a/tests/unit/utils/ast_helpers_test.py +++ b/tests/unit/utils/ast_helpers_test.py @@ -3,8 +3,6 @@ import copy import math -import pytest - from pyjsclear.utils.ast_helpers import _CHILD_KEYS from pyjsclear.utils.ast_helpers import deep_copy from pyjsclear.utils.ast_helpers import get_child_keys @@ -32,29 +30,29 @@ class TestDeepCopy: - def test_basic_node(self): + def test_basic_node(self) -> None: node = {'type': 'Literal', 'value': 42, 'raw': '42'} result = deep_copy(node) assert result == node assert result is not node - def test_nested_node_independence(self): + def test_nested_node_independence(self) -> None: inner = {'type': 'Identifier', 'name': 'x'} outer = {'type': 'ExpressionStatement', 'expression': inner} result = deep_copy(outer) result['expression']['name'] = 'y' assert inner['name'] == 'x' - def test_list_children(self): + def test_list_children(self) -> None: node = {'type': 'BlockStatement', 'body': [{'type': 'EmptyStatement'}]} result = deep_copy(node) result['body'].append({'type': 'EmptyStatement'}) assert len(node['body']) == 1 - def test_none_passthrough(self): + def test_none_passthrough(self) -> None: assert deep_copy(None) is None - def test_empty_dict(self): + def test_empty_dict(self) -> None: assert deep_copy({}) == {} @@ -64,74 +62,74 @@ def test_empty_dict(self): class TestIsLiteral: - def test_string_literal(self): + def test_string_literal(self) -> None: assert is_literal({'type': 'Literal', 'value': 'hello', 'raw': '"hello"'}) - def test_numeric_literal(self): + def test_numeric_literal(self) -> None: assert is_literal({'type': 'Literal', 'value': 42, 'raw': '42'}) - def test_not_a_dict(self): + def test_not_a_dict(self) -> None: assert not is_literal('Literal') assert not is_literal(42) assert not is_literal(None) assert not is_literal([]) - def test_wrong_type(self): + def test_wrong_type(self) -> None: assert not is_literal({'type': 'Identifier', 'name': 'x'}) - def test_missing_type(self): + def test_missing_type(self) -> None: assert not is_literal({'value': 42}) - def test_empty_dict(self): + def test_empty_dict(self) -> None: assert not is_literal({}) class TestIsIdentifier: - def test_basic(self): + def test_basic(self) -> None: assert is_identifier({'type': 'Identifier', 'name': 'foo'}) - def test_not_identifier(self): + def test_not_identifier(self) -> None: assert not is_identifier({'type': 'Literal', 'value': 1}) - def test_non_dict(self): + def test_non_dict(self) -> None: assert not is_identifier('Identifier') assert not is_identifier(None) class TestIsStringLiteral: - def test_string(self): + def test_string(self) -> None: assert is_string_literal({'type': 'Literal', 'value': 'hello'}) - def test_empty_string(self): + def test_empty_string(self) -> None: assert is_string_literal({'type': 'Literal', 'value': ''}) - def test_number_not_string(self): + def test_number_not_string(self) -> None: assert not is_string_literal({'type': 'Literal', 'value': 42}) - def test_bool_not_string(self): + def test_bool_not_string(self) -> None: assert not is_string_literal({'type': 'Literal', 'value': True}) - def test_none_not_string(self): + def test_none_not_string(self) -> None: assert not is_string_literal({'type': 'Literal', 'value': None, 'raw': 'null'}) class TestIsNumericLiteral: - def test_int(self): + def test_int(self) -> None: assert is_numeric_literal({'type': 'Literal', 'value': 42}) - def test_float(self): + def test_float(self) -> None: assert is_numeric_literal({'type': 'Literal', 'value': 3.14}) - def test_zero(self): + def test_zero(self) -> None: assert is_numeric_literal({'type': 'Literal', 'value': 0}) - def test_negative(self): + def test_negative(self) -> None: assert is_numeric_literal({'type': 'Literal', 'value': -1}) - def test_string_not_numeric(self): + def test_string_not_numeric(self) -> None: assert not is_numeric_literal({'type': 'Literal', 'value': '42'}) - def test_bool_not_numeric(self): + def test_bool_not_numeric(self) -> None: # In Python bool is subclass of int, but isinstance(True, bool) is True # and the function checks for (int, float). Since bool IS int, this returns True. # However, the check order matters: is_boolean_literal checks bool first. @@ -140,45 +138,45 @@ def test_bool_not_numeric(self): class TestIsBooleanLiteral: - def test_true(self): + def test_true(self) -> None: assert is_boolean_literal({'type': 'Literal', 'value': True}) - def test_false(self): + def test_false(self) -> None: assert is_boolean_literal({'type': 'Literal', 'value': False}) - def test_int_not_bool(self): + def test_int_not_bool(self) -> None: assert not is_boolean_literal({'type': 'Literal', 'value': 1}) - def test_string_not_bool(self): + def test_string_not_bool(self) -> None: assert not is_boolean_literal({'type': 'Literal', 'value': 'true'}) class TestIsNullLiteral: - def test_null(self): + def test_null(self) -> None: assert is_null_literal({'type': 'Literal', 'value': None, 'raw': 'null'}) - def test_none_without_raw(self): + def test_none_without_raw(self) -> None: # Must have raw == 'null' to be considered null literal assert not is_null_literal({'type': 'Literal', 'value': None}) - def test_none_wrong_raw(self): + def test_none_wrong_raw(self) -> None: assert not is_null_literal({'type': 'Literal', 'value': None, 'raw': 'undefined'}) - def test_non_literal(self): + def test_non_literal(self) -> None: assert not is_null_literal({'type': 'Identifier', 'name': 'null'}) class TestIsUndefined: - def test_undefined_identifier(self): + def test_undefined_identifier(self) -> None: assert is_undefined({'type': 'Identifier', 'name': 'undefined'}) - def test_other_identifier(self): + def test_other_identifier(self) -> None: assert not is_undefined({'type': 'Identifier', 'name': 'null'}) - def test_literal_not_undefined(self): + def test_literal_not_undefined(self) -> None: assert not is_undefined({'type': 'Literal', 'value': None, 'raw': 'null'}) - def test_non_dict(self): + def test_non_dict(self) -> None: assert not is_undefined(None) @@ -188,33 +186,33 @@ def test_non_dict(self): class TestGetLiteralValue: - def test_string_value(self): + def test_string_value(self) -> None: val, ok = get_literal_value({'type': 'Literal', 'value': 'hello'}) assert ok is True assert val == 'hello' - def test_numeric_value(self): + def test_numeric_value(self) -> None: val, ok = get_literal_value({'type': 'Literal', 'value': 42}) assert ok is True assert val == 42 - def test_none_value_null(self): + def test_none_value_null(self) -> None: val, ok = get_literal_value({'type': 'Literal', 'value': None, 'raw': 'null'}) assert ok is True assert val is None - def test_non_literal(self): + def test_non_literal(self) -> None: val, ok = get_literal_value({'type': 'Identifier', 'name': 'x'}) assert ok is False assert val is None - def test_literal_missing_value_key(self): + def test_literal_missing_value_key(self) -> None: # A Literal node without a 'value' key; .get returns None val, ok = get_literal_value({'type': 'Literal', 'raw': '42'}) assert ok is True assert val is None - def test_bool_value(self): + def test_bool_value(self) -> None: val, ok = get_literal_value({'type': 'Literal', 'value': True}) assert ok is True assert val is True @@ -226,43 +224,43 @@ def test_bool_value(self): class TestMakeLiteral: - def test_integer(self): + def test_integer(self) -> None: node = make_literal(42) assert node == {'type': 'Literal', 'value': 42, 'raw': '42'} - def test_float_whole_number(self): + def test_float_whole_number(self) -> None: # Whole floats are rendered as int strings node = make_literal(3.0) assert node['raw'] == '3' - def test_float_fractional(self): + def test_float_fractional(self) -> None: node = make_literal(3.14) assert node['raw'] == '3.14' - def test_negative_zero_float(self): + def test_negative_zero_float(self) -> None: # -0.0 in Python: str(-0.0) is '-0.0', and -0.0 == 0 is True # but str(-0.0).startswith('-') triggers the else branch node = make_literal(-0.0) assert node['raw'] == '-0.0' - def test_boolean_true(self): + def test_boolean_true(self) -> None: node = make_literal(True) assert node == {'type': 'Literal', 'value': True, 'raw': 'true'} - def test_boolean_false(self): + def test_boolean_false(self) -> None: node = make_literal(False) assert node == {'type': 'Literal', 'value': False, 'raw': 'false'} - def test_null(self): + def test_null(self) -> None: node = make_literal(None) assert node == {'type': 'Literal', 'value': None, 'raw': 'null'} - def test_simple_string(self): + def test_simple_string(self) -> None: node = make_literal('hello') assert node['value'] == 'hello' assert node['raw'] == '"hello"' - def test_string_with_double_quotes(self): + def test_string_with_double_quotes(self) -> None: """Bug #2: String containing double quotes. repr('say "hi"') gives The raw value must properly escape inner double quotes.""" node = make_literal('say "hi"') @@ -271,15 +269,15 @@ def test_string_with_double_quotes(self): assert raw.endswith('"') assert raw == '"say \\"hi\\""' - def test_custom_raw(self): + def test_custom_raw(self) -> None: node = make_literal(42, raw='0x2A') assert node == {'type': 'Literal', 'value': 42, 'raw': '0x2A'} - def test_negative_int(self): + def test_negative_int(self) -> None: node = make_literal(-5) assert node['raw'] == '-5' - def test_large_float(self): + def test_large_float(self) -> None: node = make_literal(1e10) # 1e10 == int(1e10) and not negative zero, so raw = str(int(1e10)) assert node['raw'] == '10000000000' @@ -290,7 +288,7 @@ def test_large_float(self): # then checks if raw starts with '"' and if not, re-wraps. # This section documents the current behavior for edge cases. - def test_bug2_string_with_single_quotes(self): + def test_bug2_string_with_single_quotes(self) -> None: """Bug #2: Strings containing single quotes. repr() would use double quotes for the Python repr, so replacing ' with " produces unexpected results. @@ -305,7 +303,7 @@ def test_bug2_string_with_single_quotes(self): assert raw.startswith('"') assert raw.endswith('"') - def test_bug2_backslash_string(self): + def test_bug2_backslash_string(self) -> None: """Bug #2: Strings with backslashes. repr() produces escape sequences which then get the quote replacement applied. @@ -320,7 +318,7 @@ def test_bug2_backslash_string(self): # but repr gives us Python escaping, not JS escaping assert '\\\\' in raw # double-escaped backslash from repr - def test_bug2_unicode_string(self): + def test_bug2_unicode_string(self) -> None: """Bug #2: Unicode strings. repr() may produce \\uXXXX or the literal char depending on the character. For printable unicode, repr includes it literally.""" node = make_literal('\u00e9') # e-acute @@ -328,7 +326,7 @@ def test_bug2_unicode_string(self): assert raw.startswith('"') assert raw.endswith('"') - def test_bug2_newline_string(self): + def test_bug2_newline_string(self) -> None: """Bug #2: Strings with newlines. repr() produces \\n which is valid JS too, so this case happens to work.""" node = make_literal('line1\nline2') @@ -337,13 +335,13 @@ def test_bug2_newline_string(self): assert raw.endswith('"') assert '\\n' in raw - def test_bug2_tab_string(self): + def test_bug2_tab_string(self) -> None: """Bug #2: Strings with tabs. repr() produces \\t which is valid JS.""" node = make_literal('a\tb') raw = node['raw'] assert '\\t' in raw - def test_bug2_string_with_both_quote_types(self): + def test_bug2_string_with_both_quote_types(self) -> None: """Bug #2: String containing both single and double quotes. repr() for a string with both quotes uses single-quote wrapping and escapes the single quotes. The replace("'", '"') then converts ALL @@ -353,12 +351,12 @@ def test_bug2_string_with_both_quote_types(self): assert raw.startswith('"') assert raw.endswith('"') - def test_bug2_empty_string(self): + def test_bug2_empty_string(self) -> None: """Bug #2: Empty string should produce '""'.""" node = make_literal('') assert node['raw'] == '""' - def test_bug2_null_byte_string(self): + def test_bug2_null_byte_string(self) -> None: """Bug #2: String with null byte. repr() produces \\x00 which is NOT valid JS (JS uses \\0 or \\u0000). This documents the current behavior.""" node = make_literal('a\x00b') @@ -374,10 +372,10 @@ def test_bug2_null_byte_string(self): class TestMakeIdentifier: - def test_basic(self): + def test_basic(self) -> None: assert make_identifier('foo') == {'type': 'Identifier', 'name': 'foo'} - def test_underscore(self): + def test_underscore(self) -> None: assert make_identifier('_') == {'type': 'Identifier', 'name': '_'} @@ -387,7 +385,7 @@ def test_underscore(self): class TestMakeExpressionStatement: - def test_wraps_expression(self): + def test_wraps_expression(self) -> None: expr = {'type': 'Literal', 'value': 42, 'raw': '42'} result = make_expression_statement(expr) assert result['type'] == 'ExpressionStatement' @@ -400,11 +398,11 @@ def test_wraps_expression(self): class TestMakeBlockStatement: - def test_empty_body(self): + def test_empty_body(self) -> None: result = make_block_statement([]) assert result == {'type': 'BlockStatement', 'body': []} - def test_with_statements(self): + def test_with_statements(self) -> None: stmts = [{'type': 'EmptyStatement'}, {'type': 'EmptyStatement'}] result = make_block_statement(stmts) assert result['body'] is stmts @@ -417,7 +415,7 @@ def test_with_statements(self): class TestMakeVarDeclaration: - def test_var_no_init(self): + def test_var_no_init(self) -> None: result = make_var_declaration('x') assert result['type'] == 'VariableDeclaration' assert result['kind'] == 'var' @@ -427,13 +425,13 @@ def test_var_no_init(self): assert decl['id'] == {'type': 'Identifier', 'name': 'x'} assert decl['init'] is None - def test_let_with_init(self): + def test_let_with_init(self) -> None: init = {'type': 'Literal', 'value': 5, 'raw': '5'} result = make_var_declaration('y', init=init, kind='let') assert result['kind'] == 'let' assert result['declarations'][0]['init'] is init - def test_const(self): + def test_const(self) -> None: result = make_var_declaration('Z', kind='const') assert result['kind'] == 'const' @@ -444,40 +442,40 @@ def test_const(self): class TestIsValidIdentifier: - def test_simple(self): + def test_simple(self) -> None: assert is_valid_identifier('foo') - def test_underscore_prefix(self): + def test_underscore_prefix(self) -> None: assert is_valid_identifier('_private') - def test_dollar_prefix(self): + def test_dollar_prefix(self) -> None: assert is_valid_identifier('$jquery') - def test_digits_allowed_after_first(self): + def test_digits_allowed_after_first(self) -> None: assert is_valid_identifier('x1') - def test_starts_with_digit(self): + def test_starts_with_digit(self) -> None: assert not is_valid_identifier('1abc') - def test_empty_string(self): + def test_empty_string(self) -> None: assert not is_valid_identifier('') - def test_none(self): + def test_none(self) -> None: assert not is_valid_identifier(None) - def test_non_string(self): + def test_non_string(self) -> None: assert not is_valid_identifier(42) - def test_hyphen(self): + def test_hyphen(self) -> None: assert not is_valid_identifier('foo-bar') - def test_single_dollar(self): + def test_single_dollar(self) -> None: assert is_valid_identifier('$') - def test_single_underscore(self): + def test_single_underscore(self) -> None: assert is_valid_identifier('_') - def test_reserved_words_pass_regex(self): + def test_reserved_words_pass_regex(self) -> None: # The function only does regex check, not reserved word check assert is_valid_identifier('if') assert is_valid_identifier('return') @@ -490,25 +488,25 @@ def test_reserved_words_pass_regex(self): class TestChildKeys: - def test_known_node_type(self): + def test_known_node_type(self) -> None: node = {'type': 'BinaryExpression', 'left': {}, 'right': {}, 'operator': '+'} assert get_child_keys(node) == ('left', 'right') - def test_literal_has_no_children(self): + def test_literal_has_no_children(self) -> None: assert get_child_keys({'type': 'Literal', 'value': 1}) == () - def test_identifier_has_no_children(self): + def test_identifier_has_no_children(self) -> None: assert get_child_keys({'type': 'Identifier', 'name': 'x'}) == () - def test_non_dict_returns_empty(self): + def test_non_dict_returns_empty(self) -> None: assert get_child_keys('not a node') == () assert get_child_keys(None) == () assert get_child_keys(42) == () - def test_missing_type_returns_empty(self): + def test_missing_type_returns_empty(self) -> None: assert get_child_keys({'value': 42}) == () - def test_unknown_node_type_fallback(self): + def test_unknown_node_type_fallback(self) -> None: # Unknown type falls back to heuristic: keys with dict/list values not in _SKIP_KEYS node = {'type': 'UnknownThing', 'body': [], 'extra': {}, 'name': 'test'} keys = get_child_keys(node) @@ -517,13 +515,13 @@ def test_unknown_node_type_fallback(self): # 'name' is in _SKIP_KEYS so should not appear assert 'name' not in keys - def test_fallback_skips_type_key(self): + def test_fallback_skips_type_key(self) -> None: node = {'type': 'CustomNode', 'child': {'type': 'Literal'}} keys = get_child_keys(node) assert 'type' not in keys assert 'child' in keys - def test_all_known_types_in_child_keys(self): + def test_all_known_types_in_child_keys(self) -> None: # Sanity check: every value in _CHILD_KEYS is a tuple for node_type, keys in _CHILD_KEYS.items(): assert isinstance(keys, tuple), f'{node_type} keys is not a tuple' @@ -535,7 +533,7 @@ def test_all_known_types_in_child_keys(self): class TestReplaceIdentifiers: - def test_simple_replacement(self): + def test_simple_replacement(self) -> None: node = { 'type': 'BinaryExpression', 'operator': '+', @@ -547,7 +545,7 @@ def test_simple_replacement(self): assert node['left'] == {'type': 'Literal', 'value': 1, 'raw': '1'} assert node['right'] == {'type': 'Identifier', 'name': 'b'} - def test_replacement_is_deep_copied(self): + def test_replacement_is_deep_copied(self) -> None: replacement = {'type': 'Literal', 'value': 1, 'raw': '1'} node = { 'type': 'ExpressionStatement', @@ -558,7 +556,7 @@ def test_replacement_is_deep_copied(self): replacement['value'] = 999 assert node['expression']['value'] == 1 - def test_skips_non_computed_member_property(self): + def test_skips_non_computed_member_property(self) -> None: # obj.foo -- 'foo' should NOT be replaced even if in param_map node = { 'type': 'MemberExpression', @@ -576,7 +574,7 @@ def test_skips_non_computed_member_property(self): # property should NOT be replaced (non-computed) assert node['property']['name'] == 'foo' - def test_replaces_computed_member_property(self): + def test_replaces_computed_member_property(self) -> None: # obj[foo] -- 'foo' SHOULD be replaced node = { 'type': 'MemberExpression', @@ -588,7 +586,7 @@ def test_replaces_computed_member_property(self): replace_identifiers(node, param_map) assert node['property'] == {'type': 'Literal', 'value': 'bar', 'raw': '"bar"'} - def test_replaces_in_array_children(self): + def test_replaces_in_array_children(self) -> None: node = { 'type': 'ArrayExpression', 'elements': [ @@ -602,7 +600,7 @@ def test_replaces_in_array_children(self): assert node['elements'][0] == {'type': 'Literal', 'value': 1, 'raw': '1'} assert node['elements'][1] == {'type': 'Identifier', 'name': 'b'} - def test_recursive_replacement(self): + def test_recursive_replacement(self) -> None: node = { 'type': 'ExpressionStatement', 'expression': { @@ -616,18 +614,18 @@ def test_recursive_replacement(self): replace_identifiers(node, param_map) assert node['expression']['left'] == {'type': 'Literal', 'value': 10, 'raw': '10'} - def test_non_dict_input_noop(self): + def test_non_dict_input_noop(self) -> None: # Should not raise replace_identifiers(None, {'x': {'type': 'Literal'}}) replace_identifiers('string', {'x': {'type': 'Literal'}}) replace_identifiers([], {'x': {'type': 'Literal'}}) - def test_no_type_key_noop(self): + def test_no_type_key_noop(self) -> None: node = {'value': 42} replace_identifiers(node, {'value': {'type': 'Literal'}}) assert node == {'value': 42} - def test_child_is_none_skipped(self): + def test_child_is_none_skipped(self) -> None: node = { 'type': 'ReturnStatement', 'argument': None, @@ -635,7 +633,7 @@ def test_child_is_none_skipped(self): # Should not raise replace_identifiers(node, {'x': {'type': 'Literal'}}) - def test_nested_list_with_non_identifier_dicts(self): + def test_nested_list_with_non_identifier_dicts(self) -> None: # Items in a list that are dicts with 'type' but not Identifier should recurse node = { 'type': 'ArrayExpression', @@ -659,69 +657,69 @@ def test_nested_list_with_non_identifier_dicts(self): class TestNodesEqual: - def test_identical_literals(self): + def test_identical_literals(self) -> None: a = {'type': 'Literal', 'value': 42, 'raw': '42'} b = {'type': 'Literal', 'value': 42, 'raw': '42'} assert nodes_equal(a, b) - def test_different_values(self): + def test_different_values(self) -> None: a = {'type': 'Literal', 'value': 42, 'raw': '42'} b = {'type': 'Literal', 'value': 43, 'raw': '43'} assert not nodes_equal(a, b) - def test_ignores_position_info(self): + def test_ignores_position_info(self) -> None: a = {'type': 'Literal', 'value': 1, 'raw': '1', 'start': 0, 'end': 1} b = {'type': 'Literal', 'value': 1, 'raw': '1', 'start': 50, 'end': 51} assert nodes_equal(a, b) - def test_ignores_loc(self): + def test_ignores_loc(self) -> None: a = {'type': 'Identifier', 'name': 'x', 'loc': {'start': {'line': 1}}} b = {'type': 'Identifier', 'name': 'x', 'loc': {'start': {'line': 5}}} assert nodes_equal(a, b) - def test_ignores_range(self): + def test_ignores_range(self) -> None: a = {'type': 'Identifier', 'name': 'x', 'range': [0, 1]} b = {'type': 'Identifier', 'name': 'x', 'range': [10, 11]} assert nodes_equal(a, b) - def test_different_types(self): + def test_different_types(self) -> None: a = {'type': 'Literal', 'value': 1} b = {'type': 'Identifier', 'name': '1'} assert not nodes_equal(a, b) - def test_list_equality(self): + def test_list_equality(self) -> None: a = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}] b = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}] assert nodes_equal(a, b) - def test_list_different_length(self): + def test_list_different_length(self) -> None: a = [{'type': 'Literal', 'value': 1}] b = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}] assert not nodes_equal(a, b) - def test_list_different_order(self): + def test_list_different_order(self) -> None: a = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}] b = [{'type': 'Literal', 'value': 2}, {'type': 'Literal', 'value': 1}] assert not nodes_equal(a, b) - def test_scalar_equality(self): + def test_scalar_equality(self) -> None: assert nodes_equal(42, 42) assert nodes_equal('hello', 'hello') assert not nodes_equal(42, 43) assert not nodes_equal('a', 'b') - def test_type_mismatch(self): + def test_type_mismatch(self) -> None: # type(a) != type(b) returns False assert not nodes_equal(42, '42') assert not nodes_equal({}, []) assert not nodes_equal(None, {}) - def test_dict_extra_key(self): + def test_dict_extra_key(self) -> None: a = {'type': 'Literal', 'value': 1} b = {'type': 'Literal', 'value': 1, 'extra': True} assert not nodes_equal(a, b) - def test_nested_structures(self): + def test_nested_structures(self) -> None: a = { 'type': 'BinaryExpression', 'operator': '+', @@ -733,7 +731,7 @@ def test_nested_structures(self): b['right']['value'] = 3 assert not nodes_equal(a, b) - def test_nested_with_position_ignored(self): + def test_nested_with_position_ignored(self) -> None: a = { 'type': 'BinaryExpression', 'operator': '+', @@ -748,15 +746,15 @@ def test_nested_with_position_ignored(self): } assert nodes_equal(a, b) - def test_empty_dicts(self): + def test_empty_dicts(self) -> None: assert nodes_equal({}, {}) - def test_empty_lists(self): + def test_empty_lists(self) -> None: assert nodes_equal([], []) - def test_none_values(self): + def test_none_values(self) -> None: assert nodes_equal(None, None) - def test_bool_equality(self): + def test_bool_equality(self) -> None: assert nodes_equal(True, True) assert not nodes_equal(True, False) diff --git a/tests/unit/utils/string_decoders_test.py b/tests/unit/utils/string_decoders_test.py index 50e96f1..1ddc8b9 100644 --- a/tests/unit/utils/string_decoders_test.py +++ b/tests/unit/utils/string_decoders_test.py @@ -16,19 +16,19 @@ class TestDecoderType: - def test_basic_value(self): + def test_basic_value(self) -> None: assert DecoderType.BASIC.value == 'basic' - def test_base64_value(self): + def test_base64_value(self) -> None: assert DecoderType.BASE_64.value == 'base64' - def test_rc4_value(self): + def test_rc4_value(self) -> None: assert DecoderType.RC4.value == 'rc4' - def test_enum_members(self): + def test_enum_members(self) -> None: assert set(DecoderType) == {DecoderType.BASIC, DecoderType.BASE_64, DecoderType.RC4} - def test_lookup_by_value(self): + def test_lookup_by_value(self) -> None: assert DecoderType('basic') is DecoderType.BASIC assert DecoderType('base64') is DecoderType.BASE_64 assert DecoderType('rc4') is DecoderType.RC4 @@ -40,44 +40,44 @@ def test_lookup_by_value(self): class TestBase64Transform: - def test_empty_string(self): + def test_empty_string(self) -> None: assert base64_transform('') == '' - def test_non_standard_alphabet(self): + def test_non_standard_alphabet(self) -> None: """Standard base64 'aGVsbG8=' decodes to 'hello', but the custom alphabet reverses case mapping so the result must differ.""" result = base64_transform('aGVsbG8=') assert result != 'hello' - def test_decode_known_4char_group(self): + def test_decode_known_4char_group(self) -> None: # 'abcd' -> indices 0,1,2,3 in custom alphabet -> bytes [0, 16, 131] assert [ord(c) for c in base64_transform('abcd')] == [0, 16, 131] - def test_decode_uppercase_group(self): + def test_decode_uppercase_group(self) -> None: # 'ABCD' -> indices 26,27,28,29 -> bytes [105, 183, 29] result = base64_transform('ABCD') assert [ord(c) for c in result] == [105, 183, 29] - def test_decode_mixed_case(self): + def test_decode_mixed_case(self) -> None: assert [ord(c) for c in base64_transform('aBcD')] == [1, 176, 157] - def test_decode_two_groups(self): + def test_decode_two_groups(self) -> None: # 8 chars = two 4-char groups = 6 decoded bytes assert [ord(c) for c in base64_transform('abcdefgh')] == [0, 16, 131, 16, 81, 135] - def test_padding_double_equals(self): + def test_padding_double_equals(self) -> None: # 'ab==' has one meaningful 6-bit pair assert [ord(c) for c in base64_transform('ab==')] == [0, 32, 64] - def test_padding_single_equals(self): + def test_padding_single_equals(self) -> None: assert [ord(c) for c in base64_transform('abc=')] == [0, 16, 192] - def test_invalid_chars_are_skipped(self): + def test_invalid_chars_are_skipped(self) -> None: """Characters not in the alphabet should be silently ignored.""" # '$' and '!' are not in the alphabet assert base64_transform('$ab!cd') == base64_transform('abcd') - def test_returns_string(self): + def test_returns_string(self) -> None: result = base64_transform('abcd') assert isinstance(result, str) @@ -88,24 +88,24 @@ def test_returns_string(self): class TestStringDecoder: - def test_get_string_raises_not_implemented(self): + def test_get_string_raises_not_implemented(self) -> None: decoder = StringDecoder(['a', 'b'], 0) with pytest.raises(NotImplementedError): decoder.get_string(0) - def test_get_string_for_rotation_raises_on_first_call(self): + def test_get_string_for_rotation_raises_on_first_call(self) -> None: decoder = BasicStringDecoder(['hello'], 0) with pytest.raises(RuntimeError, match='First call'): decoder.get_string_for_rotation(0) - def test_get_string_for_rotation_works_on_second_call(self): + def test_get_string_for_rotation_works_on_second_call(self) -> None: decoder = BasicStringDecoder(['hello', 'world'], 0) with pytest.raises(RuntimeError): decoder.get_string_for_rotation(0) # Second call should succeed assert decoder.get_string_for_rotation(0) == 'hello' - def test_get_string_for_rotation_passes_args(self): + def test_get_string_for_rotation_passes_args(self) -> None: decoder = BasicStringDecoder(['hello', 'world'], 0) # Exhaust first-call guard with pytest.raises(RuntimeError): @@ -113,7 +113,7 @@ def test_get_string_for_rotation_passes_args(self): # Second call with index 1 assert decoder.get_string_for_rotation(1) == 'world' - def test_is_first_call_flag(self): + def test_is_first_call_flag(self) -> None: decoder = BasicStringDecoder(['x'], 0) assert decoder.is_first_call is True with pytest.raises(RuntimeError): @@ -127,58 +127,58 @@ def test_is_first_call_flag(self): class TestBasicStringDecoder: - def test_type_property(self): + def test_type_property(self) -> None: decoder = BasicStringDecoder(['a'], 0) assert decoder.type == DecoderType.BASIC - def test_zero_offset(self): + def test_zero_offset(self) -> None: arr = ['alpha', 'beta', 'gamma'] decoder = BasicStringDecoder(arr, 0) assert decoder.get_string(0) == 'alpha' assert decoder.get_string(1) == 'beta' assert decoder.get_string(2) == 'gamma' - def test_positive_offset(self): + def test_positive_offset(self) -> None: arr = ['a', 'b', 'c', 'd'] decoder = BasicStringDecoder(arr, 2) # index 0 -> array[0+2] = 'c' assert decoder.get_string(0) == 'c' assert decoder.get_string(1) == 'd' - def test_negative_offset(self): + def test_negative_offset(self) -> None: arr = ['a', 'b', 'c', 'd'] decoder = BasicStringDecoder(arr, -2) # index 2 -> array[2-2] = 'a' assert decoder.get_string(2) == 'a' assert decoder.get_string(3) == 'b' - def test_out_of_bounds_positive(self): + def test_out_of_bounds_positive(self) -> None: arr = ['a', 'b'] decoder = BasicStringDecoder(arr, 0) assert decoder.get_string(5) is None - def test_out_of_bounds_negative(self): + def test_out_of_bounds_negative(self) -> None: arr = ['a', 'b'] decoder = BasicStringDecoder(arr, -5) assert decoder.get_string(0) is None - def test_exact_boundary(self): + def test_exact_boundary(self) -> None: arr = ['a', 'b', 'c'] decoder = BasicStringDecoder(arr, 0) assert decoder.get_string(2) == 'c' # last valid index assert decoder.get_string(3) is None # one past end - def test_negative_index_yields_negative_array_index(self): + def test_negative_index_yields_negative_array_index(self) -> None: arr = ['a', 'b'] decoder = BasicStringDecoder(arr, 0) # index -1 -> array[-1] which is < 0 -> returns None assert decoder.get_string(-1) is None - def test_empty_array(self): + def test_empty_array(self) -> None: decoder = BasicStringDecoder([], 0) assert decoder.get_string(0) is None - def test_extra_args_are_ignored(self): + def test_extra_args_are_ignored(self) -> None: decoder = BasicStringDecoder(['x'], 0) assert decoder.get_string(0, 'extra', 'args') == 'x' @@ -189,24 +189,24 @@ def test_extra_args_are_ignored(self): class TestBase64StringDecoder: - def test_type_property(self): + def test_type_property(self) -> None: decoder = Base64StringDecoder(['x'], 0) assert decoder.type == DecoderType.BASE_64 - def test_decodes_value(self): + def test_decodes_value(self) -> None: decoder = Base64StringDecoder(['abcd'], 0) assert decoder.get_string(0) == base64_transform('abcd') - def test_with_offset(self): + def test_with_offset(self) -> None: decoder = Base64StringDecoder(['SKIP', 'abcd'], 1) # index 0 -> array[0+1] = 'abcd' assert decoder.get_string(0) == base64_transform('abcd') - def test_out_of_bounds_returns_none(self): + def test_out_of_bounds_returns_none(self) -> None: decoder = Base64StringDecoder(['abcd'], 0) assert decoder.get_string(5) is None - def test_caching(self): + def test_caching(self) -> None: decoder = Base64StringDecoder(['abcd'], 0) result1 = decoder.get_string(0) result2 = decoder.get_string(0) @@ -215,7 +215,7 @@ def test_caching(self): assert 0 in decoder._cache assert decoder._cache[0] == result1 - def test_multiple_indices(self): + def test_multiple_indices(self) -> None: decoder = Base64StringDecoder(['abcd', 'ABCD'], 0) r0 = decoder.get_string(0) r1 = decoder.get_string(1) @@ -223,11 +223,11 @@ def test_multiple_indices(self): assert r1 == base64_transform('ABCD') assert r0 != r1 - def test_empty_encoded_string(self): + def test_empty_encoded_string(self) -> None: decoder = Base64StringDecoder([''], 0) assert decoder.get_string(0) == '' - def test_get_string_for_rotation(self): + def test_get_string_for_rotation(self) -> None: decoder = Base64StringDecoder(['abcd'], 0) with pytest.raises(RuntimeError): decoder.get_string_for_rotation(0) @@ -241,36 +241,36 @@ def test_get_string_for_rotation(self): class TestRc4StringDecoder: - def test_type_property(self): + def test_type_property(self) -> None: decoder = Rc4StringDecoder(['x'], 0) assert decoder.type == DecoderType.RC4 - def test_key_none_returns_none(self): + def test_key_none_returns_none(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) assert decoder.get_string(0) is None - def test_key_none_explicit(self): + def test_key_none_explicit(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) assert decoder.get_string(0, key=None) is None - def test_out_of_bounds_returns_none(self): + def test_out_of_bounds_returns_none(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) assert decoder.get_string(5, key='k') is None - def test_decodes_with_key(self): + def test_decodes_with_key(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) result = decoder.get_string(0, key='k') assert result is not None assert isinstance(result, str) assert len(result) == 3 - def test_different_keys_give_different_results(self): + def test_different_keys_give_different_results(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) r1 = decoder.get_string(0, key='testkey') r2 = decoder.get_string(0, key='otherkey') assert r1 != r2 - def test_caching_with_same_key(self): + def test_caching_with_same_key(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) r1 = decoder.get_string(0, key='mykey') r2 = decoder.get_string(0, key='mykey') @@ -278,7 +278,7 @@ def test_caching_with_same_key(self): assert r1 is r2 # same cached object assert (0, 'mykey') in decoder._cache - def test_cache_keyed_by_index_and_key(self): + def test_cache_keyed_by_index_and_key(self) -> None: decoder = Rc4StringDecoder(['abcd', 'ABCD'], 0) r1 = decoder.get_string(0, key='k') r2 = decoder.get_string(1, key='k') @@ -286,25 +286,25 @@ def test_cache_keyed_by_index_and_key(self): assert (1, 'k') in decoder._cache assert r1 != r2 - def test_with_offset(self): + def test_with_offset(self) -> None: decoder = Rc4StringDecoder(['SKIP', 'abcd'], 1) result = decoder.get_string(0, key='k') assert result is not None - def test_same_input_same_key_deterministic(self): + def test_same_input_same_key_deterministic(self) -> None: """Same input and key always produce the same output.""" dec1 = Rc4StringDecoder(['abcd'], 0) dec2 = Rc4StringDecoder(['abcd'], 0) assert dec1.get_string(0, key='k') == dec2.get_string(0, key='k') - def test_get_string_for_rotation(self): + def test_get_string_for_rotation(self) -> None: decoder = Rc4StringDecoder(['abcd'], 0) with pytest.raises(RuntimeError): decoder.get_string_for_rotation(0, key='k') result = decoder.get_string_for_rotation(0, key='k') assert result is not None - def test_empty_array(self): + def test_empty_array(self) -> None: decoder = Rc4StringDecoder([], 0) assert decoder.get_string(0, key='k') is None @@ -315,7 +315,7 @@ def test_empty_array(self): class TestCrossDecoder: - def test_basic_does_not_decode(self): + def test_basic_does_not_decode(self) -> None: """BasicStringDecoder returns the raw string, not decoded.""" arr = ['abcd'] basic = BasicStringDecoder(arr, 0) From 3dc0e26b10a42498d361400d0ba2501bec40b55d Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 09:37:11 +0200 Subject: [PATCH 3/8] Add depth-limited hybrid traversal to recover 16% perf regression The iterative-only traversal caused a 16% regression due to Python tuple allocation overhead. This adds recursive fast paths for enter-only visitors (all 46 call sites) and scope tree building, with automatic fallback to iterative traversal at depth > 500 to prevent stack overflow on deep ASTs. Benchmark: 2.95s mean (was 3.67s iterative-only, 3.17s original recursive). Co-Authored-By: Claude Opus 4.6 --- pyjsclear/scope.py | 474 ++++++++++++++++++++++------------- pyjsclear/traverser.py | 374 +++++++++++++++++++++------ tests/unit/scope_test.py | 92 +++++++ tests/unit/traverser_test.py | 68 +++++ 4 files changed, 762 insertions(+), 246 deletions(-) diff --git a/pyjsclear/scope.py b/pyjsclear/scope.py index 7b97a3a..d72fd55 100644 --- a/pyjsclear/scope.py +++ b/pyjsclear/scope.py @@ -7,6 +7,15 @@ from .utils.ast_helpers import get_child_keys +# Local aliases for hot-path performance +_isinstance = isinstance +_dict = dict +_list = list + +# Maximum recursion depth before falling back to iterative traversal. +_MAX_RECURSIVE_DEPTH = 500 + + class Binding: """Represents a variable binding in a scope.""" @@ -50,7 +59,7 @@ def add_binding(self, name: str, node: dict, kind: str) -> Binding: self.bindings[name] = binding return binding - def get_binding(self, name: str) -> Binding | None: + def get_binding(self, name: str) -> 'Binding | None': """Look up a binding, walking up the scope chain.""" if name in self.bindings: return self.bindings[name] @@ -58,11 +67,11 @@ def get_binding(self, name: str) -> Binding | None: return self.parent.get_binding(name) return None - def get_own_binding(self, name: str) -> Binding | None: + def get_own_binding(self, name: str) -> 'Binding | None': return self.bindings.get(name) -def _nearest_function_scope(scope: Scope | None) -> Scope | None: +def _nearest_function_scope(scope: 'Scope | None') -> 'Scope | None': """Walk up to the nearest function (or root) scope.""" while scope and not scope.is_function: scope = scope.parent @@ -87,208 +96,279 @@ def _is_non_reference_identifier(parent: dict | None, parent_key: str | None) -> return False -def _recurse_into_children( - node: dict, child_keys_map: dict, callback: Callable[[dict], Any] +def _collect_pattern_names( + pattern: dict | None, + scope: 'Scope', + kind: str, + declaration: dict, ) -> None: - """Walk child nodes, calling callback(child_node) for each dict with 'type'.""" + """Collect binding names from destructuring patterns.""" + if not _isinstance(pattern, _dict): + return + match pattern.get('type', ''): + case 'ArrayPattern': + for element in pattern.get('elements', []): + if not element: + continue + if element.get('type') == 'Identifier': + scope.add_binding(element['name'], declaration, kind) + else: + _collect_pattern_names(element, scope, kind, declaration) + case 'ObjectPattern': + for property_node in pattern.get('properties', []): + value_node = property_node.get('value', property_node.get('argument')) + if not value_node: + continue + if value_node.get('type') == 'Identifier': + scope.add_binding(value_node['name'], declaration, kind) + else: + _collect_pattern_names(value_node, scope, kind, declaration) + case 'RestElement': + argument_node = pattern.get('argument') + if argument_node and argument_node.get('type') == 'Identifier': + scope.add_binding(argument_node['name'], declaration, kind) + case 'AssignmentPattern': + left = pattern.get('left') + if left and left.get('type') == 'Identifier': + scope.add_binding(left['name'], declaration, kind) + + +def _push_children_to_stack( + node: dict, + scope: 'Scope', + stack: list, + child_keys_map: dict, +) -> None: + """Push child nodes onto a stack in reversed order for left-to-right processing.""" node_type = node.get('type') child_keys = child_keys_map.get(node_type) if child_keys is None: child_keys = get_child_keys(node) - for key in child_keys: + for key in reversed(child_keys): child = node.get(key) if child is None: continue - if isinstance(child, list): - for item in child: - if isinstance(item, dict) and 'type' in item: - callback(item) - elif isinstance(child, dict) and 'type' in child: - callback(child) + if _isinstance(child, _list): + for index in range(len(child) - 1, -1, -1): + item = child[index] + if _isinstance(item, _dict) and 'type' in item: + stack.append((item, scope)) + elif _isinstance(child, _dict) and 'type' in child: + stack.append((child, scope)) + + +def _process_declaration_node( + node: dict, + node_type: str, + scope: 'Scope', + node_scope: dict[int, 'Scope'], + all_scopes: list['Scope'], + push_target: list, + push_children_fn: Callable, +) -> None: + """Process a single node for declaration collection. + + push_target: a list to append (node, scope) tuples to. + push_children_fn: callable(node, scope, push_target) to push child nodes. + """ + if node_type in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'): + new_scope = Scope(scope, node, is_function=True) + node_scope[id(node)] = new_scope + all_scopes.append(new_scope) + + if node_type == 'FunctionDeclaration' and node.get('id'): + scope.add_binding(node['id']['name'], node, 'function') + elif node_type == 'FunctionExpression' and node.get('id'): + new_scope.add_binding(node['id']['name'], node, 'function') + + for param in node.get('params', []): + param_type = param.get('type') + if param_type == 'Identifier': + new_scope.add_binding(param['name'], param, 'param') + elif param_type == 'AssignmentPattern': + left = param.get('left', {}) + if left.get('type') == 'Identifier': + new_scope.add_binding(left['name'], param, 'param') + elif param_type == 'RestElement': + argument = param.get('argument') + if argument and argument.get('type') == 'Identifier': + new_scope.add_binding(argument['name'], param, 'param') + + body = node.get('body') + if not body: + return + if _isinstance(body, _dict) and body.get('type') == 'BlockStatement': + node_scope[id(body)] = new_scope + statements = body.get('body', []) + for index in range(len(statements) - 1, -1, -1): + push_target.append((statements[index], new_scope)) + else: + push_target.append((body, new_scope)) + + elif node_type in ('ClassExpression', 'ClassDeclaration'): + class_id = node.get('id') + inner_scope = scope + if class_id and class_id.get('type') == 'Identifier': + name = class_id['name'] + if node_type == 'ClassDeclaration': + scope.add_binding(name, node, 'function') + else: + inner_scope = Scope(scope, node) + node_scope[id(node)] = inner_scope + all_scopes.append(inner_scope) + inner_scope.add_binding(name, node, 'function') + superclass = node.get('superClass') + body = node.get('body') + if body: + push_target.append((body, inner_scope)) + if superclass: + push_target.append((superclass, scope)) + + elif node_type == 'VariableDeclaration': + kind = node.get('kind', 'var') + target_scope = (_nearest_function_scope(scope) or scope) if kind == 'var' else scope + declarations = node.get('declarations', []) + inits_to_push = [] + for declaration in declarations: + declaration_id = declaration.get('id') + if declaration_id and declaration_id.get('type') == 'Identifier': + target_scope.add_binding(declaration_id['name'], declaration, kind) + _collect_pattern_names(declaration_id, target_scope, kind, declaration) + init = declaration.get('init') + if init: + inits_to_push.append((init, scope)) + for index in range(len(inits_to_push) - 1, -1, -1): + push_target.append(inits_to_push[index]) + + elif node_type == 'BlockStatement' and id(node) not in node_scope: + new_scope = Scope(scope, node) + node_scope[id(node)] = new_scope + all_scopes.append(new_scope) + statements = node.get('body', []) + for index in range(len(statements) - 1, -1, -1): + push_target.append((statements[index], new_scope)) + + elif node_type == 'ForStatement': + new_scope = Scope(scope, node) + node_scope[id(node)] = new_scope + all_scopes.append(new_scope) + if node.get('body'): + push_target.append((node['body'], new_scope)) + if node.get('init'): + push_target.append((node['init'], new_scope)) + + elif node_type == 'CatchClause': + catch_body = node.get('body') + if catch_body and catch_body.get('type') == 'BlockStatement': + catch_scope = Scope(scope, catch_body) + node_scope[id(catch_body)] = catch_scope + all_scopes.append(catch_scope) + param = node.get('param') + if param and param.get('type') == 'Identifier': + catch_scope.add_binding(param['name'], param, 'param') + statements = catch_body.get('body', []) + for index in range(len(statements) - 1, -1, -1): + push_target.append((statements[index], catch_scope)) + + else: + push_children_fn(node, scope, push_target) def build_scope_tree(ast: dict) -> tuple[Scope, dict[int, Scope]]: """Build a scope tree from an AST, collecting bindings and references. Returns the root Scope and a dict mapping node id -> Scope. + Uses recursive traversal with automatic fallback to iterative for deep subtrees. """ root_scope = Scope(None, ast, is_function=True) - # Maps id(node) -> scope for function/block scope nodes node_scope: dict[int, Scope] = {id(ast): root_scope} - # We need to collect all declarations first, then references all_scopes: list[Scope] = [root_scope] - def _get_scope_for(node: dict, current_scope: Scope) -> Scope: - """Get or create the scope for a node.""" - node_id = id(node) - if node_id in node_scope: - return node_scope[node_id] - return current_scope - _child_keys_map = _CHILD_KEYS + _get_child_keys = get_child_keys + _max_depth = _MAX_RECURSIVE_DEPTH + + # ---- Pass 1: Collect declarations (recursive with iterative fallback) ---- - def _collect_declarations(node: dict, scope: Scope) -> None: - """Walk the AST collecting variable declarations into scopes.""" - if not isinstance(node, dict): + def _push_children(node: dict, scope: Scope, target_list: list) -> None: + """Push child nodes onto a list.""" + node_type = node.get('type') + child_keys = _child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(node) + for key in reversed(child_keys): + child = node.get(key) + if child is None: + continue + if _isinstance(child, _list): + for index in range(len(child) - 1, -1, -1): + item = child[index] + if _isinstance(item, _dict) and 'type' in item: + target_list.append((item, scope)) + elif _isinstance(child, _dict) and 'type' in child: + target_list.append((child, scope)) + + def _visit_declaration(node: dict, scope: Scope, depth: int) -> None: + if not _isinstance(node, _dict): return node_type = node.get('type') if node_type is None: return - match node_type: - case 'FunctionDeclaration' | 'FunctionExpression' | 'ArrowFunctionExpression': - new_scope = Scope(scope, node, is_function=True) - node_scope[id(node)] = new_scope - all_scopes.append(new_scope) - - # Function name goes in outer scope (for declarations) or inner (for expressions) - if node_type == 'FunctionDeclaration' and node.get('id'): - scope.add_binding(node['id']['name'], node, 'function') - elif node_type == 'FunctionExpression' and node.get('id'): - new_scope.add_binding(node['id']['name'], node, 'function') - - # Params go in function scope - for param in node.get('params', []): - match param.get('type'): - case 'Identifier': - new_scope.add_binding(param['name'], param, 'param') - case 'AssignmentPattern' if param.get('left', {}).get('type') == 'Identifier': - new_scope.add_binding(param['left']['name'], param, 'param') - case 'RestElement': - arg = param.get('argument') - if arg and arg.get('type') == 'Identifier': - new_scope.add_binding(arg['name'], param, 'param') - - # Body - use the new scope - body = node.get('body') - if not body: - return - if isinstance(body, dict) and body.get('type') == 'BlockStatement': - node_scope[id(body)] = new_scope - for statement in body.get('body', []): - _collect_declarations(statement, new_scope) - else: - _collect_declarations(body, new_scope) - - case 'ClassExpression' | 'ClassDeclaration': - class_id = node.get('id') - inner_scope = scope - if class_id and class_id.get('type') == 'Identifier': - name = class_id['name'] - if node_type == 'ClassDeclaration': - scope.add_binding(name, node, 'function') - else: - inner_scope = Scope(scope, node) - node_scope[id(node)] = inner_scope - all_scopes.append(inner_scope) - inner_scope.add_binding(name, node, 'function') - body = node.get('body') - if body: - _collect_declarations(body, inner_scope) - superclass = node.get('superClass') - if superclass: - _collect_declarations(superclass, scope) - - case 'VariableDeclaration': - kind = node.get('kind', 'var') - target_scope = (_nearest_function_scope(scope) or scope) if kind == 'var' else scope - for declaration in node.get('declarations', []): - declaration_id = declaration.get('id') - if declaration_id and declaration_id.get('type') == 'Identifier': - target_scope.add_binding(declaration_id['name'], declaration, kind) - _collect_pattern_names(declaration_id, target_scope, kind, declaration) - init = declaration.get('init') - if init: - _collect_declarations(init, scope) - - case 'BlockStatement' if id(node) not in node_scope: - new_scope = Scope(scope, node) - node_scope[id(node)] = new_scope - all_scopes.append(new_scope) - for statement in node.get('body', []): - _collect_declarations(statement, new_scope) - - case 'ForStatement': - new_scope = Scope(scope, node) - node_scope[id(node)] = new_scope - all_scopes.append(new_scope) - if node.get('init'): - _collect_declarations(node['init'], new_scope) - if node.get('body'): - _collect_declarations(node['body'], new_scope) - - case 'CatchClause': - catch_body = node.get('body') - if catch_body and catch_body.get('type') == 'BlockStatement': - catch_scope = Scope(scope, catch_body) - node_scope[id(catch_body)] = catch_scope - all_scopes.append(catch_scope) - param = node.get('param') - if param and param.get('type') == 'Identifier': - catch_scope.add_binding(param['name'], param, 'param') - for statement in catch_body.get('body', []): - _collect_declarations(statement, catch_scope) - - case _: - _recurse_into_children( - node, _child_keys_map, lambda child_node: _collect_declarations(child_node, scope) - ) - - def _collect_pattern_names(pattern: dict | None, scope: Scope, kind: str, declaration: dict) -> None: - """Collect binding names from destructuring patterns.""" - if not isinstance(pattern, dict): + if depth > _max_depth: + # Fall back to iterative for this subtree + _collect_declarations_iterative_from(node, scope) return - match pattern.get('type', ''): - case 'ArrayPattern': - for element in pattern.get('elements', []): - if not element: - continue - if element.get('type') == 'Identifier': - scope.add_binding(element['name'], declaration, kind) - else: - _collect_pattern_names(element, scope, kind, declaration) - case 'ObjectPattern': - for property_node in pattern.get('properties', []): - value_node = property_node.get('value', property_node.get('argument')) - if not value_node: - continue - if value_node.get('type') == 'Identifier': - scope.add_binding(value_node['name'], declaration, kind) - else: - _collect_pattern_names(value_node, scope, kind, declaration) - case 'RestElement': - argument_node = pattern.get('argument') - if argument_node and argument_node.get('type') == 'Identifier': - scope.add_binding(argument_node['name'], declaration, kind) - case 'AssignmentPattern': - left = pattern.get('left') - if left and left.get('type') == 'Identifier': - scope.add_binding(left['name'], declaration, kind) - - _collect_declarations(ast, root_scope) - - # Second pass: collect references and assignments - def _collect_references( + + # Collect children into a local list, then recurse + children: list = [] + _process_declaration_node(node, node_type, scope, node_scope, all_scopes, children, _push_children) + next_depth = depth + 1 + # Children were appended in stack order (reversed), so iterate + # in reverse to get left-to-right processing order. + for index in range(len(children) - 1, -1, -1): + _visit_declaration(children[index][0], children[index][1], next_depth) + + def _collect_declarations_iterative_from(start_node: dict, start_scope: Scope) -> None: + """Run iterative declaration collection starting from a specific node/scope.""" + decl_stack = [(start_node, start_scope)] + while decl_stack: + node, scope = decl_stack.pop() + if not _isinstance(node, _dict): + continue + node_type = node.get('type') + if node_type is None: + continue + _process_declaration_node( + node, node_type, scope, node_scope, all_scopes, decl_stack, _push_children + ) + + _visit_declaration(ast, root_scope, 0) + + # ---- Pass 2: Collect references (recursive with iterative fallback) ---- + + def _visit_reference( node: dict, scope: Scope, - parent: dict | None = None, - parent_key: str | None = None, - parent_index: int | None = None, + parent: dict | None, + parent_key: str | None, + parent_index: int | None, + depth: int, ) -> None: - if not isinstance(node, dict): + if not _isinstance(node, _dict): return node_type = node.get('type') if node_type is None: return - # Look up scope for this node - scope = _get_scope_for(node, scope) + node_id = id(node) + if node_id in node_scope: + scope = node_scope[node_id] if node_type == 'Identifier': name = node.get('name', '') if _is_non_reference_identifier(parent, parent_key): return - binding = scope.get_binding(name) if not binding: return @@ -299,22 +379,66 @@ def _collect_references( binding.assignments.append(parent) return - # Recurse — can't use _recurse_into_children here because we need - # per-child (key, index) args for reference tracking + if depth > _max_depth: + _collect_references_iterative_from(node, scope) + return + child_keys = _child_keys_map.get(node_type) if child_keys is None: - child_keys = get_child_keys(node) + child_keys = _get_child_keys(node) + next_depth = depth + 1 for key in child_keys: child = node.get(key) if child is None: continue - if isinstance(child, list): + if _isinstance(child, _list): for child_index, item in enumerate(child): - if isinstance(item, dict) and 'type' in item: - _collect_references(item, scope, node, key, child_index) - elif isinstance(child, dict) and 'type' in child: - _collect_references(child, scope, node, key, None) - - _collect_references(ast, root_scope) + if _isinstance(item, _dict) and 'type' in item: + _visit_reference(item, scope, node, key, child_index, next_depth) + elif _isinstance(child, _dict) and 'type' in child: + _visit_reference(child, scope, node, key, None, next_depth) + + def _collect_references_iterative_from(start_node: dict, start_scope: Scope) -> None: + """Run iterative reference collection starting from a specific node/scope.""" + ref_stack = [(start_node, start_scope, None, None, None)] + while ref_stack: + node, scope, parent, parent_key, parent_index = ref_stack.pop() + if not _isinstance(node, _dict): + continue + node_type = node.get('type') + if node_type is None: + continue + node_id = id(node) + if node_id in node_scope: + scope = node_scope[node_id] + if node_type == 'Identifier': + name = node.get('name', '') + if _is_non_reference_identifier(parent, parent_key): + continue + binding = scope.get_binding(name) + if not binding: + continue + binding.references.append((node, parent, parent_key, parent_index)) + if parent and parent.get('type') == 'AssignmentExpression' and parent_key == 'left': + binding.assignments.append(parent) + elif parent and parent.get('type') == 'UpdateExpression': + binding.assignments.append(parent) + continue + child_keys = _child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(node) + for key in reversed(child_keys): + child = node.get(key) + if child is None: + continue + if _isinstance(child, _list): + for index in range(len(child) - 1, -1, -1): + item = child[index] + if _isinstance(item, _dict) and 'type' in item: + ref_stack.append((item, scope, node, key, index)) + elif _isinstance(child, _dict) and 'type' in child: + ref_stack.append((child, scope, node, key, None)) + + _visit_reference(ast, root_scope, None, None, None, 0) return root_scope, node_scope diff --git a/pyjsclear/traverser.py b/pyjsclear/traverser.py index dbd272d..afa5a69 100644 --- a/pyjsclear/traverser.py +++ b/pyjsclear/traverser.py @@ -17,6 +17,203 @@ _list = list _isinstance = isinstance +# Maximum recursion depth before falling back to iterative traversal. +# CPython default recursion limit is ~1000; we switch well before that. +_MAX_RECURSIVE_DEPTH = 500 + +# Stack frame opcodes for iterative traverse +_OP_ENTER = 0 +_OP_EXIT = 1 +_OP_LIST_START = 2 +_OP_LIST_RESUME = 3 + + +def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable | None) -> None: + """Iterative stack-based traverse. Handles both enter and exit callbacks.""" + child_keys_map = _CHILD_KEYS + _REMOVE = REMOVE + _SKIP = SKIP + _get_child_keys = get_child_keys + + stack = [(_OP_ENTER, node, None, None, None)] + stack_pop = stack.pop + stack_append = stack.append + + while stack: + frame = stack_pop() + op = frame[0] + + if op == _OP_ENTER: + current_node = frame[1] + parent = frame[2] + key = frame[3] + index = frame[4] + + node_type = current_node.get('type') + if node_type is None: + continue + + if enter_fn: + result = enter_fn(current_node, parent, key, index) + if result is _REMOVE: + if parent is not None: + if index is not None: + parent[key].pop(index) + else: + parent[key] = None + continue + if result is _SKIP: + if exit_fn: + exit_result = exit_fn(current_node, parent, key, index) + if exit_result is _REMOVE: + if parent is not None: + if index is not None: + parent[key].pop(index) + else: + parent[key] = None + elif _isinstance(exit_result, _dict) and 'type' in exit_result: + if parent is not None: + if index is not None: + parent[key][index] = exit_result + else: + parent[key] = exit_result + continue + if _isinstance(result, _dict) and 'type' in result: + current_node = result + if parent is not None: + if index is not None: + parent[key][index] = current_node + else: + parent[key] = current_node + node_type = current_node.get('type') + + if exit_fn: + stack_append((_OP_EXIT, current_node, parent, key, index)) + + child_keys = child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(current_node) + + for index in range(len(child_keys) - 1, -1, -1): + child_key = child_keys[index] + child = current_node.get(child_key) + if child is None: + continue + if _isinstance(child, _list): + stack_append((_OP_LIST_START, current_node, child_key, 0, None)) + elif _isinstance(child, _dict) and 'type' in child: + stack_append((_OP_ENTER, child, current_node, child_key, None)) + + elif op == _OP_EXIT: + current_node = frame[1] + parent = frame[2] + key = frame[3] + index = frame[4] + result = exit_fn(current_node, parent, key, index) + if result is _REMOVE: + if parent is not None: + if index is not None: + parent[key].pop(index) + else: + parent[key] = None + elif _isinstance(result, _dict) and 'type' in result: + if parent is not None: + if index is not None: + parent[key][index] = result + else: + parent[key] = result + + elif op == _OP_LIST_START: + parent_node = frame[1] + child_key = frame[2] + idx = frame[3] + child_list = parent_node[child_key] + if idx >= len(child_list): + continue + item = child_list[idx] + if _isinstance(item, _dict) and 'type' in item: + stack_append((_OP_LIST_RESUME, parent_node, child_key, idx, len(child_list))) + stack_append((_OP_ENTER, item, parent_node, child_key, idx)) + else: + stack_append((_OP_LIST_START, parent_node, child_key, idx + 1, None)) + + else: # _OP_LIST_RESUME + parent_node = frame[1] + child_key = frame[2] + idx = frame[3] + pre_len = frame[4] + child_list = parent_node[child_key] + current_len = len(child_list) + if current_len < pre_len: + next_idx = idx + else: + next_idx = idx + 1 + if next_idx < current_len: + stack_append((_OP_LIST_START, parent_node, child_key, next_idx, None)) + + +def _traverse_enter_only(node: dict, enter_fn: Callable) -> None: + """Recursive enter-only traverse with depth-limited fallback to iterative.""" + child_keys_map = _CHILD_KEYS + _REMOVE = REMOVE + _SKIP = SKIP + _get_child_keys = get_child_keys + _max_depth = _MAX_RECURSIVE_DEPTH + + def _visit(current_node: dict, parent: dict | None, key: str | None, index: int | None, depth: int) -> None: + node_type = current_node.get('type') + if node_type is None: + return + + result = enter_fn(current_node, parent, key, index) + if result is _REMOVE: + if parent is not None: + if index is not None: + parent[key].pop(index) + else: + parent[key] = None + return + if result is _SKIP: + return + if _isinstance(result, _dict) and 'type' in result: + current_node = result + if parent is not None: + if index is not None: + parent[key][index] = current_node + else: + parent[key] = current_node + node_type = current_node.get('type') + + # Depth check: fall back to iterative for deep subtrees + if depth > _max_depth: + _traverse_iterative(current_node, enter_fn, None) + return + + child_keys = child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(current_node) + + next_depth = depth + 1 + for child_key in child_keys: + child = current_node.get(child_key) + if child is None: + continue + if _isinstance(child, _list): + list_index = 0 + while list_index < len(child): + item = child[list_index] + if _isinstance(item, _dict) and 'type' in item: + pre_len = len(child) + _visit(item, current_node, child_key, list_index, next_depth) + # If item was removed, list shrunk - stay at same index + if len(child) < pre_len: + continue + list_index += 1 + elif _isinstance(child, _dict) and 'type' in child: + _visit(child, current_node, child_key, None, next_depth) + + _visit(node, None, None, None, 0) + def traverse(node: dict, visitor: dict | object) -> None: """Traverse an ESTree AST calling visitor callbacks. @@ -27,6 +224,10 @@ def traverse(node: dict, visitor: dict | object) -> None: - REMOVE: remove this node from parent - SKIP: (enter only) skip traversing children - a dict (node): replace this node with the returned node + + Uses recursive traversal for enter-only visitors (fast path) with + automatic fallback to iterative for deep subtrees. Uses iterative + traversal when an exit callback is present. """ if _isinstance(visitor, _dict): enter_fn = visitor.get('enter') @@ -35,103 +236,97 @@ def traverse(node: dict, visitor: dict | object) -> None: enter_fn = getattr(visitor, 'enter', None) exit_fn = getattr(visitor, 'exit', None) + if exit_fn is None and enter_fn is not None: + _traverse_enter_only(node, enter_fn) + else: + _traverse_iterative(node, enter_fn, exit_fn) + + +def _simple_traverse_iterative(node: dict, callback: Callable) -> None: + """Iterative stack-based simple traversal.""" child_keys_map = _CHILD_KEYS - _REMOVE = REMOVE - _SKIP = SKIP + _get_child_keys = get_child_keys + + stack = [(node, None)] + stack_pop = stack.pop + stack_append = stack.append - def _visit(current_node: dict, parent: dict | None, key: str | None, index: int | None) -> Any: + while stack: + current_node, parent = stack_pop() node_type = current_node.get('type') if node_type is None: - return current_node - - # Enter - if enter_fn: - result = enter_fn(current_node, parent, key, index) - if result is _REMOVE: - return _REMOVE - if result is _SKIP: - if not exit_fn: - return current_node - exit_result = exit_fn(current_node, parent, key, index) - if exit_result is _REMOVE: - return _REMOVE - if _isinstance(exit_result, _dict) and 'type' in exit_result: - return exit_result - return current_node - if _isinstance(result, _dict) and 'type' in result: - current_node = result - if parent is not None: - if index is not None: - parent[key][index] = current_node - else: - parent[key] = current_node - - # Visit children - child_keys = child_keys_map.get(current_node.get('type')) + continue + callback(current_node, parent) + child_keys = child_keys_map.get(node_type) if child_keys is None: - child_keys = get_child_keys(current_node) - for child_key in child_keys: - child = current_node.get(child_key) + child_keys = _get_child_keys(current_node) + for key in reversed(child_keys): + child = current_node.get(key) if child is None: continue if _isinstance(child, _list): - child_index = 0 - while child_index < len(child): - item = child[child_index] + for list_index in range(len(child) - 1, -1, -1): + item = child[list_index] if _isinstance(item, _dict) and 'type' in item: - result = _visit(item, current_node, child_key, child_index) - if result is _REMOVE: - child.pop(child_index) - continue - elif result is not item: - child[child_index] = result - child_index += 1 + stack_append((item, current_node)) elif _isinstance(child, _dict) and 'type' in child: - result = _visit(child, current_node, child_key, None) - if result is _REMOVE: - current_node[child_key] = None - elif result is not child: - current_node[child_key] = result + stack_append((child, current_node)) - # Exit - if exit_fn: - result = exit_fn(current_node, parent, key, index) - if result is _REMOVE: - return _REMOVE - if _isinstance(result, _dict) and 'type' in result: - return result - return current_node - - _visit(node, None, None, None) - - -def simple_traverse(node: dict, callback: Callable) -> None: - """Simple traversal that calls callback(node, parent) for every node. - No replacement support - just visiting. - """ +def _simple_traverse_recursive(node: dict, callback: Callable) -> None: + """Recursive simple traversal with depth-limited fallback to iterative.""" child_keys_map = _CHILD_KEYS + _get_child_keys = get_child_keys + _max_depth = _MAX_RECURSIVE_DEPTH - def _visit(current_node: dict, parent: dict | None) -> None: + def _visit(current_node: dict, parent: dict | None, depth: int) -> None: node_type = current_node.get('type') if node_type is None: return callback(current_node, parent) + + if depth > _max_depth: + # Fall back to iterative for this subtree's children + child_keys = child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(current_node) + for key in child_keys: + child = current_node.get(key) + if child is None: + continue + if _isinstance(child, _list): + for item in child: + if _isinstance(item, _dict) and 'type' in item: + _simple_traverse_iterative(item, callback) + elif _isinstance(child, _dict) and 'type' in child: + _simple_traverse_iterative(child, callback) + return + child_keys = child_keys_map.get(node_type) if child_keys is None: - child_keys = get_child_keys(current_node) - for child_key in child_keys: - child = current_node.get(child_key) + child_keys = _get_child_keys(current_node) + next_depth = depth + 1 + for key in child_keys: + child = current_node.get(key) if child is None: continue if _isinstance(child, _list): for item in child: if _isinstance(item, _dict) and 'type' in item: - _visit(item, current_node) + _visit(item, current_node, next_depth) elif _isinstance(child, _dict) and 'type' in child: - _visit(child, current_node) + _visit(child, current_node, next_depth) - _visit(node, None) + _visit(node, None, 0) + + +def simple_traverse(node: dict, callback: Callable) -> None: + """Simple traversal that calls callback(node, parent) for every node. + No replacement support - just visiting. + + Uses recursive traversal with automatic fallback to iterative for deep subtrees. + """ + _simple_traverse_recursive(node, callback) def collect_nodes(ast: dict, node_type: str) -> list[dict]: @@ -146,6 +341,40 @@ def collect_callback(node: dict, parent: dict | None) -> None: return collected +def build_parent_map(ast: dict) -> dict: + """Build a map from id(node) -> (parent, key, index) for all nodes in the AST. + + This allows O(1) parent lookups instead of O(n) find_parent() calls. + """ + parent_map = {} + child_keys_map = _CHILD_KEYS + _get_child_keys = get_child_keys + + stack = [(ast, None, None, None)] + while stack: + current_node, parent, key, index = stack.pop() + node_type = current_node.get('type') + if node_type is None: + continue + parent_map[id(current_node)] = (parent, key, index) + child_keys = child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(current_node) + for child_key in child_keys: + child = current_node.get(child_key) + if child is None: + continue + if _isinstance(child, _list): + for list_index in range(len(child) - 1, -1, -1): + item = child[list_index] + if _isinstance(item, _dict) and 'type' in item: + stack.append((item, current_node, child_key, list_index)) + elif _isinstance(child, _dict) and 'type' in child: + stack.append((child, current_node, child_key, None)) + + return parent_map + + class _FoundParent(Exception): """Raised to short-circuit find_parent search.""" @@ -156,7 +385,10 @@ def __init__(self, value: tuple) -> None: def find_parent(ast: dict, target_node: dict) -> tuple | None: - """Find the parent of a node in the AST. Returns (parent, key, index) or None.""" + """Find the parent of a node in the AST. Returns (parent, key, index) or None. + + For multiple lookups, consider using build_parent_map() instead. + """ def _visit(node: dict) -> None: if not isinstance(node, dict) or 'type' not in node: diff --git a/tests/unit/scope_test.py b/tests/unit/scope_test.py index 67ea10b..c17aee2 100644 --- a/tests/unit/scope_test.py +++ b/tests/unit/scope_test.py @@ -695,3 +695,95 @@ def test_unbound_identifier_reference(self): root_scope, _ = build_scope_tree(ast) # 'x' is not declared, so no binding should exist assert root_scope.get_own_binding('x') is None + + +# --------------------------------------------------------------------------- +# Deep AST test (exercises iterative fallback at depth > 500) +# --------------------------------------------------------------------------- + + +class TestDeepASTScope: + """Verify that build_scope_tree handles ASTs deeper than 500.""" + + def test_deep_nested_functions(self): + """Build deeply nested function scopes and verify bindings resolve.""" + # Build a chain: Program -> func0 body -> func1 body -> ... -> funcN body -> var x = 1; + var_decl = { + 'type': 'VariableDeclaration', + 'kind': 'var', + 'declarations': [ + { + 'type': 'VariableDeclarator', + 'id': {'type': 'Identifier', 'name': 'x'}, + 'init': {'type': 'Literal', 'value': 1, 'raw': '1'}, + } + ], + } + # Reference to x + ref_stmt = { + 'type': 'ExpressionStatement', + 'expression': {'type': 'Identifier', 'name': 'x'}, + } + node = {'type': 'BlockStatement', 'body': [var_decl, ref_stmt]} + depth = 600 + for i in range(depth): + node = { + 'type': 'FunctionDeclaration', + 'id': {'type': 'Identifier', 'name': f'f{i}'}, + 'params': [], + 'body': node, + } + ast = {'type': 'Program', 'sourceType': 'script', 'body': [node]} + + root_scope, node_scope = build_scope_tree(ast) + # Should have created many scopes without stack overflow + assert len(node_scope) > depth + # The innermost x binding should be resolvable + # Walk to the deepest function scope + scope = root_scope + for _ in range(depth): + assert len(scope.children) >= 1 + scope = scope.children[0] + x_binding = scope.get_own_binding('x') + assert x_binding is not None + assert x_binding.kind == 'var' + + def test_deep_block_statements(self): + """Build deeply nested block statements and verify scope creation.""" + # Innermost has a let binding + inner = { + 'type': 'BlockStatement', + 'body': [ + { + 'type': 'VariableDeclaration', + 'kind': 'let', + 'declarations': [ + { + 'type': 'VariableDeclarator', + 'id': {'type': 'Identifier', 'name': 'deep'}, + 'init': {'type': 'Literal', 'value': 42, 'raw': '42'}, + } + ], + } + ], + } + node = inner + for _ in range(600): + node = {'type': 'BlockStatement', 'body': [node]} + ast = {'type': 'Program', 'sourceType': 'script', 'body': [node]} + + root_scope, node_scope = build_scope_tree(ast) + # Should complete without stack overflow + assert root_scope is not None + # The 'deep' binding should exist somewhere in the scope tree + found = False + + def _check(scope): + nonlocal found + if scope.get_own_binding('deep'): + found = True + for child in scope.children: + _check(child) + + _check(root_scope) + assert found diff --git a/tests/unit/traverser_test.py b/tests/unit/traverser_test.py index 4660df0..00b0102 100644 --- a/tests/unit/traverser_test.py +++ b/tests/unit/traverser_test.py @@ -647,3 +647,71 @@ def test_traverse_unknown_node_type(self): assert 'Program' in visited assert 'UnknownCustomNode' in visited assert 'Identifier' in visited + + +# =========================================================================== +# Deep AST tests (exercises iterative fallback at depth > 500) +# =========================================================================== + + +def _make_deep_ast(depth): + """Build an AST of nested IfStatements to the given depth. + + Structure: Program -> nested IfStatements each containing a BlockStatement + with a single child, down to a Literal leaf. + """ + leaf = {'type': 'Literal', 'value': 42, 'raw': '42'} + node = leaf + for _ in range(depth): + node = { + 'type': 'IfStatement', + 'test': {'type': 'Literal', 'value': True, 'raw': 'true'}, + 'consequent': { + 'type': 'BlockStatement', + 'body': [{'type': 'ExpressionStatement', 'expression': node}], + }, + 'alternate': None, + } + return {'type': 'Program', 'sourceType': 'script', 'body': [node]} + + +class TestDeepASTTraversal: + """Verify that traverse() and simple_traverse() handle ASTs deeper than 500.""" + + def test_traverse_deep_ast(self): + ast = _make_deep_ast(600) + types = [] + traverse(ast, {'enter': lambda n, p, k, i: types.append(n['type'])}) + assert 'Program' in types + assert 'Literal' in types + # Should have visited all IfStatements + assert types.count('IfStatement') == 600 + + def test_simple_traverse_deep_ast(self): + ast = _make_deep_ast(600) + types = [] + simple_traverse(ast, lambda n, p: types.append(n['type'])) + assert 'Program' in types + assert 'Literal' in types + assert types.count('IfStatement') == 600 + + def test_traverse_deep_ast_with_exit(self): + """Exit-callback path uses fully iterative traversal.""" + ast = _make_deep_ast(600) + exit_types = [] + traverse(ast, {'exit': lambda n, p, k, i: exit_types.append(n['type'])}) + assert 'Program' in exit_types + assert 'Literal' in exit_types + + def test_traverse_deep_ast_removal(self): + """REMOVE works correctly across the recursive/iterative boundary.""" + ast = _make_deep_ast(600) + removed_count = [] + + def enter(node, parent, key, index): + if node['type'] == 'Literal' and node.get('value') == 42: + removed_count.append(1) + return REMOVE + + traverse(ast, {'enter': enter}) + assert len(removed_count) > 0 From 072658b176fa689583949e5aa3f5a81277c6a2a1 Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 09:49:24 +0200 Subject: [PATCH 4/8] Cache scope tree and parent map across transform passes Avoid redundant build_scope_tree() calls by caching the result in the deobfuscator loop and passing it to scope-using transforms. The cache is invalidated when any transform modifies the AST. Also add a lazy parent map (build_parent_map) to the base Transform class, replacing O(n) find_parent() tree walks with O(1) lookups. Co-Authored-By: Claude Opus 4.6 --- .claude/settings.local.json | 55 +++++++++++++++++++ .gitignore | 4 +- pyjsclear/deobfuscator.py | 39 ++++++++++++- pyjsclear/transforms/base.py | 22 ++++++++ pyjsclear/transforms/class_static_resolver.py | 6 +- pyjsclear/transforms/cleanup.py | 10 +++- pyjsclear/transforms/constant_prop.py | 5 +- pyjsclear/transforms/object_simplifier.py | 13 +++-- pyjsclear/transforms/proxy_functions.py | 5 +- pyjsclear/transforms/reassignment.py | 5 +- pyjsclear/transforms/single_use_vars.py | 8 ++- pyjsclear/transforms/string_revealer.py | 9 ++- pyjsclear/transforms/unused_vars.py | 5 +- pyjsclear/transforms/variable_renamer.py | 5 +- 14 files changed, 166 insertions(+), 25 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..2468b6e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,55 @@ +{ + "permissions": { + "allow": [ + "Bash(.venv/bin/python:*)", + "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python:*)", + "Bash(git:*)", + "Bash(gh run:*)", + "Bash(for dir in /Users/itamar/work/PyJSClear/tests/resources /Users/itamar/work/PyJSClear/tests/resources/jsimplifier /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/dataset /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/experiments)", + "Bash(do echo \"=== $dir ===\")", + "Read(//Users/itamar/work/PyJSClear/**)", + "Bash(done)", + "Bash(gh api:*)", + "Bash(python -m pytest tests/unit/ -x -q 2>&1 | tail -20)", + "Bash(uv run:*)", + "Bash(/tmp/match_candidates.txt:*)", + "Read(//tmp/**)", + "Bash(git stash:*)", + "Bash(python -m pytest tests/test_regression.py --collect-only 2>&1 | tail -5)", + "Bash(python3 -m pytest tests/test_regression.py --collect-only 2>&1 | tail -5)", + "Bash(python -m pytest tests/ -x -q 2>&1 | tail -20)", + "Bash(/Users/itamar/work/PyJSClear/.venv/bin/pytest tests/ -x -q 2>&1 | tail -20)", + "Bash(python -m pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", + "Bash(python3 -m pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", + "Bash(.venv/bin/pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", + "Bash(.venv/bin/pytest tests/ -x --tb=short 2>&1 | tail -30)", + "Bash(gh pr:*)", + "Bash(find /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/node_modules -maxdepth 2 -name \"LICENSE*\" | xargs -I {} sh -c 'echo \"=== {} ===\" && head -5 {}' 2>/dev/null | head -100)", + "WebFetch(domain:raw.githubusercontent.com)", + "WebSearch", + "WebFetch(domain:zenodo.org)", + "Bash(.venv/bin/pytest tests/ -x -q 2>&1 | tail -20)", + "Bash(python3 -m pytest tests/ -x -q 2>&1 | tail -20)", + "Bash(source .venv/bin/activate && pytest tests/ -x -q 2>&1 | tail -20)", + "Bash(python3 --version && python3 -c \"import pyjsclear\" 2>&1)", + "WebFetch(domain:pypi.org)", + "Bash(ls /Users/itamar/work/PyJSClear/pyjsclear/transforms/*.py | xargs basename -a)", + "WebFetch(domain:github.com)", + "WebFetch(domain:segmentfault.com)", + "Bash(which pytest:*)", + "Bash(ls /Users/itamar/work/PyJSClear/.venv/bin/python* 2>/dev/null; /Users/itamar/work/PyJSClear/.venv/bin/python3 -m pytest --version 2>&1)", + "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python3.12 -m pytest tests/ -x 2>&1)", + "Bash(python3.12 --version 2>&1 || python3.13 --version 2>&1 || python3.11 --version 2>&1 || python3.14 --version 2>&1)", + "Bash(python3.13 -m venv .venv --clear && .venv/bin/pip install -e \".[dev]\" 2>&1 | tail -5)", + "Bash(.venv/bin/pytest tests/ -x 2>&1)", + "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python3.13 -m pytest tests/ -x 2>&1)", + "WebFetch(domain:warehouse.pypa.io)", + "WebFetch(domain:docs.pypi.org)", + "Bash(.venv/bin/pip install:*)" + ] + }, + "sandbox": { + "enabled": true, + "autoAllowBashIfSandboxed": true + } +} diff --git a/.gitignore b/.gitignore index ecafa45..f03e3f3 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,6 @@ pyjsclear/**/*.c /uv.lock /tests/resources/jsimplifier /tests/resources/obfuscated-javascript-dataset -/.gitignore \ No newline at end of file +/.gitignore +/.nodeenv +/.emdash.json \ No newline at end of file diff --git a/pyjsclear/deobfuscator.py b/pyjsclear/deobfuscator.py index 8e645fb..d8876b4 100644 --- a/pyjsclear/deobfuscator.py +++ b/pyjsclear/deobfuscator.py @@ -2,6 +2,9 @@ from .generator import generate from .parser import parse +from .scope import build_scope_tree +from .transforms.aa_decode import aa_decode +from .transforms.aa_decode import is_aa_encoded from .transforms.anti_tamper import AntiTamperRemover from .transforms.class_static_resolver import ClassStaticResolver from .transforms.class_string_decoder import ClassStringDecoder @@ -26,8 +29,6 @@ from .transforms.hex_escapes import HexEscapes from .transforms.hex_escapes import decode_hex_escapes_source from .transforms.hex_numerics import HexNumerics -from .transforms.aa_decode import aa_decode -from .transforms.aa_decode import is_aa_encoded from .transforms.jj_decode import is_jj_encoded from .transforms.jj_decode import jj_decode from .transforms.jsfuck_decode import is_jsfuck @@ -53,6 +54,22 @@ from .traverser import simple_traverse +# Transforms that use build_scope_tree and benefit from cached scope +_SCOPE_TRANSFORMS = frozenset( + { + ConstantProp, + SingleUseVarInliner, + ReassignmentRemover, + ProxyFunctionInliner, + UnusedVariableRemover, + ObjectSimplifier, + StringRevealer, + VariableRenamer, + VarToConst, + LetToConst, + } +) + # StringRevealer runs first to handle string arrays before other transforms # modify the wrapper function structure. # Remaining transforms follow obfuscator-io-deobfuscator order. @@ -259,6 +276,12 @@ def _run_ast_transforms(self, ast: dict, code_size: int = 0) -> bool: # Track which transforms are no longer productive skip_transforms = set() + # Cache scope tree across transforms — only rebuild when a transform + # that modifies bindings returns changed=True + scope_tree = None + node_scope = None + scope_dirty = True # Start dirty to build on first use + # Multi-pass transform loop any_transform_changed = False for iteration in range(max_iterations): @@ -267,13 +290,23 @@ def _run_ast_transforms(self, ast: dict, code_size: int = 0) -> bool: if transform_class in skip_transforms: continue try: - transform = transform_class(ast) + # Build scope tree lazily when needed by a scope-using transform + if transform_class in _SCOPE_TRANSFORMS and scope_dirty: + scope_tree, node_scope = build_scope_tree(ast) + scope_dirty = False + + if transform_class in _SCOPE_TRANSFORMS: + transform = transform_class(ast, scope_tree=scope_tree, node_scope=node_scope) + else: + transform = transform_class(ast) result = transform.execute() except Exception: continue if result: modified = True any_transform_changed = True + # Any AST change invalidates the cached scope tree + scope_dirty = True elif iteration > 0: # Skip transforms that haven't changed anything after the first pass skip_transforms.add(transform_class) diff --git a/pyjsclear/transforms/base.py b/pyjsclear/transforms/base.py index 594b264..8a9d00d 100644 --- a/pyjsclear/transforms/base.py +++ b/pyjsclear/transforms/base.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING +from ..traverser import build_parent_map + if TYPE_CHECKING: from ..scope import Scope @@ -24,6 +26,7 @@ def __init__( self.scope_tree = scope_tree self.node_scope = node_scope self._changed = False + self._parent_map = None def execute(self) -> bool: """Execute the transform. Returns True if the AST was modified.""" @@ -34,3 +37,22 @@ def set_changed(self) -> None: def has_changed(self) -> bool: return self._changed + + def get_parent_map(self): + """Lazily build and return a parent map for the AST. + + Returns dict mapping id(node) -> (parent, key, index). + Call invalidate_parent_map() after AST modifications. + """ + if self._parent_map is None: + self._parent_map = build_parent_map(self.ast) + return self._parent_map + + def invalidate_parent_map(self): + """Invalidate the cached parent map after AST modifications.""" + self._parent_map = None + + def find_parent(self, target_node): + """Find the parent of a node using the parent map. Returns (parent, key, index) or None.""" + pm = self.get_parent_map() + return pm.get(id(target_node)) diff --git a/pyjsclear/transforms/class_static_resolver.py b/pyjsclear/transforms/class_static_resolver.py index 75cb1ce..62db501 100644 --- a/pyjsclear/transforms/class_static_resolver.py +++ b/pyjsclear/transforms/class_static_resolver.py @@ -11,7 +11,6 @@ ... C.id(expr) ... → ... expr ... """ -from ..traverser import find_parent from ..traverser import simple_traverse from ..traverser import traverse from ..utils.ast_helpers import deep_copy @@ -180,7 +179,7 @@ def _is_identity_function(self, func_node: dict) -> bool: def _try_inline_identity(self, member_expr: dict, method_node: dict) -> None: """Inline Class.identity(arg) → arg.""" - result = find_parent(self.ast, member_expr) + result = self.find_parent(member_expr) if not result: return call_parent, call_key, call_index = result @@ -191,7 +190,7 @@ def _try_inline_identity(self, member_expr: dict, method_node: dict) -> None: return replacement = deep_copy(args[0]) # Replace the CallExpression with the argument - grandparent_result = find_parent(self.ast, call_parent) + grandparent_result = self.find_parent(call_parent) if not grandparent_result: return grandparent, grandparent_key, grandparent_index = grandparent_result @@ -204,3 +203,4 @@ def _replace_in_parent(self, target: dict, replacement: dict, parent: dict, key: parent[key][index] = replacement else: parent[key] = replacement + self.invalidate_parent_map() diff --git a/pyjsclear/transforms/cleanup.py b/pyjsclear/transforms/cleanup.py index 7294c74..3219dcd 100644 --- a/pyjsclear/transforms/cleanup.py +++ b/pyjsclear/transforms/cleanup.py @@ -162,7 +162,10 @@ class LetToConst(Transform): """ def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) safe_declarators: set[int] = set() self._collect_let_const_candidates(scope_tree, safe_declarators) @@ -212,7 +215,10 @@ class VarToConst(Transform): """ def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) safe_declarators: set[int] = set() self._collect_const_candidates(scope_tree, safe_declarators, in_function=True) diff --git a/pyjsclear/transforms/constant_prop.py b/pyjsclear/transforms/constant_prop.py index 943135d..a6fcab9 100644 --- a/pyjsclear/transforms/constant_prop.py +++ b/pyjsclear/transforms/constant_prop.py @@ -31,7 +31,10 @@ class ConstantProp(Transform): rebuild_scope = True def execute(self) -> bool: - scope_tree, node_scope = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree, node_scope = self.scope_tree, self.node_scope + else: + scope_tree, node_scope = build_scope_tree(self.ast) replacements = dict(self._iter_constant_bindings(scope_tree)) if not replacements: diff --git a/pyjsclear/transforms/object_simplifier.py b/pyjsclear/transforms/object_simplifier.py index 2e3febf..7da9432 100644 --- a/pyjsclear/transforms/object_simplifier.py +++ b/pyjsclear/transforms/object_simplifier.py @@ -5,7 +5,6 @@ """ from ..scope import build_scope_tree -from ..traverser import find_parent from ..utils.ast_helpers import deep_copy from ..utils.ast_helpers import is_literal from ..utils.ast_helpers import is_string_literal @@ -19,7 +18,10 @@ class ObjectSimplifier(Transform): rebuild_scope = True def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) self._process_scope(scope_tree) return self.has_changed() @@ -92,7 +94,7 @@ def _has_property_assignment(self, binding) -> bool: for reference_node, reference_parent, ref_key, ref_index in binding.references: if not (reference_parent and reference_parent.get('type') == 'MemberExpression' and ref_key == 'object'): continue - member_expression_parent_info = find_parent(self.ast, reference_parent) + member_expression_parent_info = self.find_parent(reference_parent) if not member_expression_parent_info: continue parent, key, _ = member_expression_parent_info @@ -102,7 +104,7 @@ def _has_property_assignment(self, binding) -> bool: def _try_inline_function_call(self, member_expression, function_value) -> None: """Try to inline a function call at a MemberExpression site.""" - member_expression_parent_info = find_parent(self.ast, member_expression) + member_expression_parent_info = self.find_parent(member_expression) if not member_expression_parent_info: return parent, key, _ = member_expression_parent_info @@ -156,13 +158,14 @@ def _get_member_prop_name(self, member_expression) -> str | None: def _replace_node(self, target, replacement) -> bool: """Replace target node in the AST. Returns True if replaced.""" - result = find_parent(self.ast, target) + result = self.find_parent(target) if result: parent, key, index = result if index is not None: parent[key][index] = replacement else: parent[key] = replacement + self.invalidate_parent_map() return True return False diff --git a/pyjsclear/transforms/proxy_functions.py b/pyjsclear/transforms/proxy_functions.py index 2732419..a34d5f9 100644 --- a/pyjsclear/transforms/proxy_functions.py +++ b/pyjsclear/transforms/proxy_functions.py @@ -25,7 +25,10 @@ class ProxyFunctionInliner(Transform): rebuild_scope = True def execute(self): - scope_tree, node_scope = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree, node_scope = self.scope_tree, self.node_scope + else: + scope_tree, node_scope = build_scope_tree(self.ast) # Find proxy functions proxy_functions = {} # name -> (func_node, scope, binding) diff --git a/pyjsclear/transforms/reassignment.py b/pyjsclear/transforms/reassignment.py index 5352dc8..e84afbc 100644 --- a/pyjsclear/transforms/reassignment.py +++ b/pyjsclear/transforms/reassignment.py @@ -59,7 +59,10 @@ class ReassignmentRemover(Transform): rebuild_scope = True def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) self._process_scope(scope_tree) self._inline_assignment_aliases(scope_tree) return self.has_changed() diff --git a/pyjsclear/transforms/single_use_vars.py b/pyjsclear/transforms/single_use_vars.py index 7d28b17..30984ea 100644 --- a/pyjsclear/transforms/single_use_vars.py +++ b/pyjsclear/transforms/single_use_vars.py @@ -21,7 +21,6 @@ from ..scope import build_scope_tree from ..traverser import REMOVE -from ..traverser import find_parent from ..traverser import simple_traverse from ..traverser import traverse from ..utils.ast_helpers import deep_copy @@ -53,7 +52,10 @@ class SingleUseVarInliner(Transform): _MAX_INIT_NODES = 15 def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) inlined = self._process_scope(scope_tree) if not inlined: return False @@ -133,7 +135,7 @@ def _is_mutated_member_object(self, ref_parent: dict | None, ref_key: str | None if ref_key != 'object': return False # Now check if this MemberExpression is an assignment target - parent_info = find_parent(self.ast, ref_parent) + parent_info = self.find_parent(ref_parent) if not parent_info: return False grandparent, grandparent_key, _ = parent_info diff --git a/pyjsclear/transforms/string_revealer.py b/pyjsclear/transforms/string_revealer.py index 7d9c99b..eed74f5 100644 --- a/pyjsclear/transforms/string_revealer.py +++ b/pyjsclear/transforms/string_revealer.py @@ -7,7 +7,6 @@ from ..generator import generate from ..scope import build_scope_tree from ..traverser import REMOVE -from ..traverser import find_parent from ..traverser import simple_traverse from ..traverser import traverse from ..utils.ast_helpers import is_identifier @@ -208,7 +207,10 @@ class StringRevealer(Transform): _rotation_locals = {} def execute(self) -> bool: - scope_tree, node_scope = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree, node_scope = self.scope_tree, self.node_scope + else: + scope_tree, node_scope = build_scope_tree(self.ast) # Strategy 1: Direct string array declarations (var arr = ["a","b","c"]) self._process_direct_arrays(scope_tree) @@ -1252,13 +1254,14 @@ def _process_direct_arrays_in_scope(self, scope: Any, name: str, string_array: l def _replace_node_in_ast(self, target: dict, replacement: dict) -> None: """Replace a node in the AST with a replacement.""" - result = find_parent(self.ast, target) + result = self.find_parent(target) if result: parent, key, index = result if index is not None: parent[key][index] = replacement else: parent[key] = replacement + self.invalidate_parent_map() # ================================================================ # Strategy 3: Simple static array unpacking diff --git a/pyjsclear/transforms/unused_vars.py b/pyjsclear/transforms/unused_vars.py index 830ca89..24622e0 100644 --- a/pyjsclear/transforms/unused_vars.py +++ b/pyjsclear/transforms/unused_vars.py @@ -34,7 +34,10 @@ class UnusedVariableRemover(Transform): rebuild_scope = True def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) declarators_to_remove: set[int] = set() functions_to_remove: set[int] = set() self._collect_unused(scope_tree, declarators_to_remove, functions_to_remove) diff --git a/pyjsclear/transforms/variable_renamer.py b/pyjsclear/transforms/variable_renamer.py index 27584ff..e85fc6b 100644 --- a/pyjsclear/transforms/variable_renamer.py +++ b/pyjsclear/transforms/variable_renamer.py @@ -359,7 +359,10 @@ class VariableRenamer(Transform): rebuild_scope = True def execute(self) -> bool: - scope_tree, _ = build_scope_tree(self.ast) + if self.scope_tree is not None: + scope_tree = self.scope_tree + else: + scope_tree, _ = build_scope_tree(self.ast) # Collect all non-obfuscated names across the entire tree to avoid conflicts reserved = set(_JS_RESERVED) From c3a4b2b5cde14e1dcbb44a73fd175731672e4742 Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 09:51:05 +0200 Subject: [PATCH 5/8] Remove .gitignore from tracking (it ignores itself) Co-Authored-By: Claude Opus 4.6 --- .gitignore | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index f03e3f3..0000000 --- a/.gitignore +++ /dev/null @@ -1,23 +0,0 @@ -__pycache__/ -*.pyc -*.pyo -*.egg-info/ -dist/ -build/ -*.so -*.pyd -pyjsclear/**/*.c -.eggs/ -.DS_Store -/.venv/ -/.venv-linux/ -/.idea/ -/compare_*.py -/audit_context_check.py -/run_comparison.py -/uv.lock -/tests/resources/jsimplifier -/tests/resources/obfuscated-javascript-dataset -/.gitignore -/.nodeenv -/.emdash.json \ No newline at end of file From 8df36676adb72fa0467ac16fb7bb41880b3e4474 Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 10:08:11 +0200 Subject: [PATCH 6/8] =?UTF-8?q?Reduce=20traversal=20overhead=20~11%=20via?= =?UTF-8?q?=20isinstance=E2=86=92type=20is,=20dict['type'],=20and=20parent?= =?UTF-8?q?=20map=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace isinstance(x, dict/list) with type(x) is dict/list in hot paths of traverser.py and scope.py (pointer comparison vs MRO lookup) - Convert node.get('type') to node['type'] where callers guarantee key exists - Cache len() in _traverse_enter_only list iteration loop - Fix parent map thrashing in class_static_resolver: build once before traversal, pass parent info from enter callback, invalidate once after Benchmark: 3.03s → 2.71s mean on sample.js. Output is byte-identical. Co-Authored-By: Claude Opus 4.6 --- pyjsclear/scope.py | 149 ++++++++++++------ pyjsclear/transforms/class_static_resolver.py | 32 ++-- pyjsclear/traverser.py | 92 +++++------ 3 files changed, 172 insertions(+), 101 deletions(-) diff --git a/pyjsclear/scope.py b/pyjsclear/scope.py index d72fd55..bdb6769 100644 --- a/pyjsclear/scope.py +++ b/pyjsclear/scope.py @@ -8,7 +8,7 @@ # Local aliases for hot-path performance -_isinstance = isinstance +_type = type _dict = dict _list = list @@ -103,7 +103,7 @@ def _collect_pattern_names( declaration: dict, ) -> None: """Collect binding names from destructuring patterns.""" - if not _isinstance(pattern, _dict): + if not isinstance(pattern, _dict): return match pattern.get('type', ''): case 'ArrayPattern': @@ -133,28 +133,39 @@ def _collect_pattern_names( scope.add_binding(left['name'], declaration, kind) -def _push_children_to_stack( - node: dict, - scope: 'Scope', - stack: list, - child_keys_map: dict, -) -> None: - """Push child nodes onto a stack in reversed order for left-to-right processing.""" - node_type = node.get('type') - child_keys = child_keys_map.get(node_type) - if child_keys is None: - child_keys = get_child_keys(node) - for key in reversed(child_keys): - child = node.get(key) - if child is None: +def _collect_declarations_iterative(ast: dict, root_scope: 'Scope', node_scope: dict, all_scopes: list) -> None: + """Iterative Pass 1: Collect declarations.""" + _child_keys_map = _CHILD_KEYS + _get_child_keys = get_child_keys + + def _push_children(node: dict, scope: 'Scope', stack: list) -> None: + node_type = node.get('type') + child_keys = _child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(node) + for key in reversed(child_keys): + child = node.get(key) + if child is None: + continue + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: + stack.append((item, scope)) + elif _type(child) is _dict and 'type' in child: + stack.append((child, scope)) + + decl_stack = [(ast, root_scope)] + + while decl_stack: + node, scope = decl_stack.pop() + + if not _type(node) is _dict: + continue + node_type = node.get('type') + if node_type is None: continue - if _isinstance(child, _list): - for index in range(len(child) - 1, -1, -1): - item = child[index] - if _isinstance(item, _dict) and 'type' in item: - stack.append((item, scope)) - elif _isinstance(child, _dict) and 'type' in child: - stack.append((child, scope)) + _process_declaration_node(node, node_type, scope, node_scope, all_scopes, decl_stack, _push_children) def _process_declaration_node( @@ -197,7 +208,7 @@ def _process_declaration_node( body = node.get('body') if not body: return - if _isinstance(body, _dict) and body.get('type') == 'BlockStatement': + if _type(body) is _dict and body.get('type') == 'BlockStatement': node_scope[id(body)] = new_scope statements = body.get('body', []) for index in range(len(statements) - 1, -1, -1): @@ -274,6 +285,56 @@ def _process_declaration_node( push_children_fn(node, scope, push_target) +def _collect_references_iterative(ast: dict, root_scope: 'Scope', node_scope: dict) -> None: + """Iterative Pass 2: Collect references and assignments.""" + _child_keys_map = _CHILD_KEYS + _get_child_keys = get_child_keys + + ref_stack = [(ast, root_scope, None, None, None)] + + while ref_stack: + node, scope, parent, parent_key, parent_index = ref_stack.pop() + + if not _type(node) is _dict: + continue + node_type = node.get('type') + if node_type is None: + continue + + node_id = id(node) + if node_id in node_scope: + scope = node_scope[node_id] + + if node_type == 'Identifier': + name = node.get('name', '') + if _is_non_reference_identifier(parent, parent_key): + continue + binding = scope.get_binding(name) + if not binding: + continue + binding.references.append((node, parent, parent_key, parent_index)) + if parent and parent.get('type') == 'AssignmentExpression' and parent_key == 'left': + binding.assignments.append(parent) + elif parent and parent.get('type') == 'UpdateExpression': + binding.assignments.append(parent) + continue + + child_keys = _child_keys_map.get(node_type) + if child_keys is None: + child_keys = _get_child_keys(node) + for key in reversed(child_keys): + child = node.get(key) + if child is None: + continue + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: + ref_stack.append((item, scope, node, key, i)) + elif _type(child) is _dict and 'type' in child: + ref_stack.append((child, scope, node, key, None)) + + def build_scope_tree(ast: dict) -> tuple[Scope, dict[int, Scope]]: """Build a scope tree from an AST, collecting bindings and references. @@ -292,7 +353,7 @@ def build_scope_tree(ast: dict) -> tuple[Scope, dict[int, Scope]]: def _push_children(node: dict, scope: Scope, target_list: list) -> None: """Push child nodes onto a list.""" - node_type = node.get('type') + node_type = node['type'] child_keys = _child_keys_map.get(node_type) if child_keys is None: child_keys = _get_child_keys(node) @@ -300,16 +361,16 @@ def _push_children(node: dict, scope: Scope, target_list: list) -> None: child = node.get(key) if child is None: continue - if _isinstance(child, _list): - for index in range(len(child) - 1, -1, -1): - item = child[index] - if _isinstance(item, _dict) and 'type' in item: + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: target_list.append((item, scope)) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: target_list.append((child, scope)) def _visit_declaration(node: dict, scope: Scope, depth: int) -> None: - if not _isinstance(node, _dict): + if not _type(node) is _dict: return node_type = node.get('type') if node_type is None: @@ -334,7 +395,7 @@ def _collect_declarations_iterative_from(start_node: dict, start_scope: Scope) - decl_stack = [(start_node, start_scope)] while decl_stack: node, scope = decl_stack.pop() - if not _isinstance(node, _dict): + if not _type(node) is _dict: continue node_type = node.get('type') if node_type is None: @@ -355,9 +416,9 @@ def _visit_reference( parent_index: int | None, depth: int, ) -> None: - if not _isinstance(node, _dict): + if not _type(node) is _dict: return - node_type = node.get('type') + node_type = node.get('type') # not all dicts have 'type' if node_type is None: return @@ -391,11 +452,11 @@ def _visit_reference( child = node.get(key) if child is None: continue - if _isinstance(child, _list): + if _type(child) is _list: for child_index, item in enumerate(child): - if _isinstance(item, _dict) and 'type' in item: + if _type(item) is _dict and 'type' in item: _visit_reference(item, scope, node, key, child_index, next_depth) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: _visit_reference(child, scope, node, key, None, next_depth) def _collect_references_iterative_from(start_node: dict, start_scope: Scope) -> None: @@ -403,7 +464,7 @@ def _collect_references_iterative_from(start_node: dict, start_scope: Scope) -> ref_stack = [(start_node, start_scope, None, None, None)] while ref_stack: node, scope, parent, parent_key, parent_index = ref_stack.pop() - if not _isinstance(node, _dict): + if not _type(node) is _dict: continue node_type = node.get('type') if node_type is None: @@ -431,12 +492,12 @@ def _collect_references_iterative_from(start_node: dict, start_scope: Scope) -> child = node.get(key) if child is None: continue - if _isinstance(child, _list): - for index in range(len(child) - 1, -1, -1): - item = child[index] - if _isinstance(item, _dict) and 'type' in item: - ref_stack.append((item, scope, node, key, index)) - elif _isinstance(child, _dict) and 'type' in child: + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: + ref_stack.append((item, scope, node, key, i)) + elif _type(child) is _dict and 'type' in child: ref_stack.append((child, scope, node, key, None)) _visit_reference(ast, root_scope, None, None, None, 0) diff --git a/pyjsclear/transforms/class_static_resolver.py b/pyjsclear/transforms/class_static_resolver.py index 62db501..4080a1e 100644 --- a/pyjsclear/transforms/class_static_resolver.py +++ b/pyjsclear/transforms/class_static_resolver.py @@ -112,6 +112,9 @@ def collect_static_props(node, parent): return False # Step 4: Replace accesses + # Build parent map once before traversal + self.get_parent_map() + def enter(node, parent, key, index): if node.get('type') != 'MemberExpression': return @@ -140,9 +143,11 @@ def enter(node, parent, key, index): # Try identity method inlining if pair in static_methods: - self._try_inline_identity(node, static_methods[pair]) + self._try_inline_identity(node, static_methods[pair], parent, key, index) traverse(self.ast, {'enter': enter}) + # Invalidate parent map once after all replacements + self.invalidate_parent_map() return self.has_changed() def _get_prop_name(self, member_expr: dict) -> str | None: @@ -177,24 +182,26 @@ def _is_identity_function(self, func_node: dict) -> bool: return False return return_argument['name'] == param['name'] - def _try_inline_identity(self, member_expr: dict, method_node: dict) -> None: - """Inline Class.identity(arg) → arg.""" - result = self.find_parent(member_expr) - if not result: - return - call_parent, call_key, call_index = result - if not call_parent or call_parent.get('type') != 'CallExpression' or call_key != 'callee': + def _try_inline_identity(self, member_expr: dict, method_node: dict, parent: dict | None, key: str | None, index: int | None) -> None: + """Inline Class.identity(arg) → arg. + + parent/key/index refer to the MemberExpression's parent (from the enter callback). + The MemberExpression should be the callee of a CallExpression. + """ + # parent is the CallExpression that contains this MemberExpression as callee + if not parent or parent.get('type') != 'CallExpression' or key != 'callee': return - args = call_parent.get('arguments', []) + args = parent.get('arguments', []) if len(args) != 1: return replacement = deep_copy(args[0]) - # Replace the CallExpression with the argument - grandparent_result = self.find_parent(call_parent) + # Find grandparent of the CallExpression using cached parent map + pm = self.get_parent_map() + grandparent_result = pm.get(id(parent)) if not grandparent_result: return grandparent, grandparent_key, grandparent_index = grandparent_result - self._replace_in_parent(call_parent, replacement, grandparent, grandparent_key, grandparent_index) + self._replace_in_parent(parent, replacement, grandparent, grandparent_key, grandparent_index) self.set_changed() def _replace_in_parent(self, target: dict, replacement: dict, parent: dict, key: str, index: int | None) -> None: @@ -203,4 +210,3 @@ def _replace_in_parent(self, target: dict, replacement: dict, parent: dict, key: parent[key][index] = replacement else: parent[key] = replacement - self.invalidate_parent_map() diff --git a/pyjsclear/traverser.py b/pyjsclear/traverser.py index afa5a69..583425a 100644 --- a/pyjsclear/traverser.py +++ b/pyjsclear/traverser.py @@ -15,7 +15,7 @@ # Local aliases for hot-path performance (~15% faster traversal) _dict = dict _list = list -_isinstance = isinstance +_type = type # Maximum recursion depth before falling back to iterative traversal. # CPython default recursion limit is ~1000; we switch well before that. @@ -71,14 +71,14 @@ def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable parent[key].pop(index) else: parent[key] = None - elif _isinstance(exit_result, _dict) and 'type' in exit_result: + elif _type(exit_result) is _dict and 'type' in exit_result: if parent is not None: if index is not None: parent[key][index] = exit_result else: parent[key] = exit_result continue - if _isinstance(result, _dict) and 'type' in result: + if _type(result) is _dict and 'type' in result: current_node = result if parent is not None: if index is not None: @@ -99,9 +99,9 @@ def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable child = current_node.get(child_key) if child is None: continue - if _isinstance(child, _list): + if _type(child) is _list: stack_append((_OP_LIST_START, current_node, child_key, 0, None)) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: stack_append((_OP_ENTER, child, current_node, child_key, None)) elif op == _OP_EXIT: @@ -116,7 +116,7 @@ def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable parent[key].pop(index) else: parent[key] = None - elif _isinstance(result, _dict) and 'type' in result: + elif _type(result) is _dict and 'type' in result: if parent is not None: if index is not None: parent[key][index] = result @@ -131,7 +131,7 @@ def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable if idx >= len(child_list): continue item = child_list[idx] - if _isinstance(item, _dict) and 'type' in item: + if _type(item) is _dict and 'type' in item: stack_append((_OP_LIST_RESUME, parent_node, child_key, idx, len(child_list))) stack_append((_OP_ENTER, item, parent_node, child_key, idx)) else: @@ -161,7 +161,7 @@ def _traverse_enter_only(node: dict, enter_fn: Callable) -> None: _max_depth = _MAX_RECURSIVE_DEPTH def _visit(current_node: dict, parent: dict | None, key: str | None, index: int | None, depth: int) -> None: - node_type = current_node.get('type') + node_type = current_node['type'] if node_type is None: return @@ -175,14 +175,14 @@ def _visit(current_node: dict, parent: dict | None, key: str | None, index: int return if result is _SKIP: return - if _isinstance(result, _dict) and 'type' in result: + if _type(result) is _dict and 'type' in result: current_node = result if parent is not None: if index is not None: parent[key][index] = current_node else: parent[key] = current_node - node_type = current_node.get('type') + node_type = current_node['type'] # Depth check: fall back to iterative for deep subtrees if depth > _max_depth: @@ -198,21 +198,24 @@ def _visit(current_node: dict, parent: dict | None, key: str | None, index: int child = current_node.get(child_key) if child is None: continue - if _isinstance(child, _list): - list_index = 0 - while list_index < len(child): - item = child[list_index] - if _isinstance(item, _dict) and 'type' in item: - pre_len = len(child) - _visit(item, current_node, child_key, list_index, next_depth) - # If item was removed, list shrunk - stay at same index - if len(child) < pre_len: + if _type(child) is _list: + child_len = len(child) + i = 0 + while i < child_len: + item = child[i] + if _type(item) is _dict and 'type' in item: + _visit(item, current_node, child_key, i, next_depth) + new_len = len(child) + if new_len < child_len: + child_len = new_len continue - list_index += 1 - elif _isinstance(child, _dict) and 'type' in child: + child_len = new_len + i += 1 + elif _type(child) is _dict and 'type' in child: _visit(child, current_node, child_key, None, next_depth) - _visit(node, None, None, None, 0) + if _type(node) is _dict and 'type' in node: + _visit(node, None, None, None, 0) def traverse(node: dict, visitor: dict | object) -> None: @@ -229,7 +232,7 @@ def traverse(node: dict, visitor: dict | object) -> None: automatic fallback to iterative for deep subtrees. Uses iterative traversal when an exit callback is present. """ - if _isinstance(visitor, _dict): + if isinstance(visitor, _dict): enter_fn = visitor.get('enter') exit_fn = visitor.get('exit') else: @@ -253,7 +256,7 @@ def _simple_traverse_iterative(node: dict, callback: Callable) -> None: while stack: current_node, parent = stack_pop() - node_type = current_node.get('type') + node_type = current_node['type'] if node_type is None: continue callback(current_node, parent) @@ -264,12 +267,12 @@ def _simple_traverse_iterative(node: dict, callback: Callable) -> None: child = current_node.get(key) if child is None: continue - if _isinstance(child, _list): - for list_index in range(len(child) - 1, -1, -1): - item = child[list_index] - if _isinstance(item, _dict) and 'type' in item: + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: stack_append((item, current_node)) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: stack_append((child, current_node)) @@ -280,7 +283,7 @@ def _simple_traverse_recursive(node: dict, callback: Callable) -> None: _max_depth = _MAX_RECURSIVE_DEPTH def _visit(current_node: dict, parent: dict | None, depth: int) -> None: - node_type = current_node.get('type') + node_type = current_node['type'] if node_type is None: return callback(current_node, parent) @@ -294,11 +297,11 @@ def _visit(current_node: dict, parent: dict | None, depth: int) -> None: child = current_node.get(key) if child is None: continue - if _isinstance(child, _list): + if _type(child) is _list: for item in child: - if _isinstance(item, _dict) and 'type' in item: + if _type(item) is _dict and 'type' in item: _simple_traverse_iterative(item, callback) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: _simple_traverse_iterative(child, callback) return @@ -310,14 +313,15 @@ def _visit(current_node: dict, parent: dict | None, depth: int) -> None: child = current_node.get(key) if child is None: continue - if _isinstance(child, _list): + if _type(child) is _list: for item in child: - if _isinstance(item, _dict) and 'type' in item: + if _type(item) is _dict and 'type' in item: _visit(item, current_node, next_depth) - elif _isinstance(child, _dict) and 'type' in child: + elif _type(child) is _dict and 'type' in child: _visit(child, current_node, next_depth) - _visit(node, None, 0) + if _type(node) is _dict and 'type' in node: + _visit(node, None, 0) def simple_traverse(node: dict, callback: Callable) -> None: @@ -353,7 +357,7 @@ def build_parent_map(ast: dict) -> dict: stack = [(ast, None, None, None)] while stack: current_node, parent, key, index = stack.pop() - node_type = current_node.get('type') + node_type = current_node['type'] if node_type is None: continue parent_map[id(current_node)] = (parent, key, index) @@ -364,12 +368,12 @@ def build_parent_map(ast: dict) -> dict: child = current_node.get(child_key) if child is None: continue - if _isinstance(child, _list): - for list_index in range(len(child) - 1, -1, -1): - item = child[list_index] - if _isinstance(item, _dict) and 'type' in item: - stack.append((item, current_node, child_key, list_index)) - elif _isinstance(child, _dict) and 'type' in child: + if _type(child) is _list: + for i in range(len(child) - 1, -1, -1): + item = child[i] + if _type(item) is _dict and 'type' in item: + stack.append((item, current_node, child_key, i)) + elif _type(child) is _dict and 'type' in child: stack.append((child, current_node, child_key, None)) return parent_map From 4cc3a9ca948ed4e1f9bfc83db2b7b170b6adb27a Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 10:18:35 +0200 Subject: [PATCH 7/8] Bump version to 0.1.4 Co-Authored-By: Claude Opus 4.6 --- pyjsclear/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyjsclear/__init__.py b/pyjsclear/__init__.py index 4e0b478..af4e3c4 100644 --- a/pyjsclear/__init__.py +++ b/pyjsclear/__init__.py @@ -8,7 +8,7 @@ from .deobfuscator import Deobfuscator -__version__ = '0.1.3' +__version__ = '0.1.4' def deobfuscate(code: str, max_iterations: int = 50) -> str: From 093677819cae3f2e10919b6bc0ecb835cafc80f8 Mon Sep 17 00:00:00 2001 From: Itamar Gafni Date: Fri, 13 Mar 2026 11:54:33 +0200 Subject: [PATCH 8/8] delete .claude/settings.local.json --- .claude/settings.local.json | 55 ------------------------------------- 1 file changed, 55 deletions(-) delete mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 2468b6e..0000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,55 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(.venv/bin/python:*)", - "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python:*)", - "Bash(git:*)", - "Bash(gh run:*)", - "Bash(for dir in /Users/itamar/work/PyJSClear/tests/resources /Users/itamar/work/PyJSClear/tests/resources/jsimplifier /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/dataset /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/experiments)", - "Bash(do echo \"=== $dir ===\")", - "Read(//Users/itamar/work/PyJSClear/**)", - "Bash(done)", - "Bash(gh api:*)", - "Bash(python -m pytest tests/unit/ -x -q 2>&1 | tail -20)", - "Bash(uv run:*)", - "Bash(/tmp/match_candidates.txt:*)", - "Read(//tmp/**)", - "Bash(git stash:*)", - "Bash(python -m pytest tests/test_regression.py --collect-only 2>&1 | tail -5)", - "Bash(python3 -m pytest tests/test_regression.py --collect-only 2>&1 | tail -5)", - "Bash(python -m pytest tests/ -x -q 2>&1 | tail -20)", - "Bash(/Users/itamar/work/PyJSClear/.venv/bin/pytest tests/ -x -q 2>&1 | tail -20)", - "Bash(python -m pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", - "Bash(python3 -m pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", - "Bash(.venv/bin/pytest tests/unit/transforms/jsfuck_decode_test.py tests/unit/transforms/jj_decode_test.py tests/unit/transforms/deobfuscator_prepasses_test.py -v 2>&1 | tail -60)", - "Bash(.venv/bin/pytest tests/ -x --tb=short 2>&1 | tail -30)", - "Bash(gh pr:*)", - "Bash(find /Users/itamar/work/PyJSClear/tests/resources/jsimplifier/node_modules -maxdepth 2 -name \"LICENSE*\" | xargs -I {} sh -c 'echo \"=== {} ===\" && head -5 {}' 2>/dev/null | head -100)", - "WebFetch(domain:raw.githubusercontent.com)", - "WebSearch", - "WebFetch(domain:zenodo.org)", - "Bash(.venv/bin/pytest tests/ -x -q 2>&1 | tail -20)", - "Bash(python3 -m pytest tests/ -x -q 2>&1 | tail -20)", - "Bash(source .venv/bin/activate && pytest tests/ -x -q 2>&1 | tail -20)", - "Bash(python3 --version && python3 -c \"import pyjsclear\" 2>&1)", - "WebFetch(domain:pypi.org)", - "Bash(ls /Users/itamar/work/PyJSClear/pyjsclear/transforms/*.py | xargs basename -a)", - "WebFetch(domain:github.com)", - "WebFetch(domain:segmentfault.com)", - "Bash(which pytest:*)", - "Bash(ls /Users/itamar/work/PyJSClear/.venv/bin/python* 2>/dev/null; /Users/itamar/work/PyJSClear/.venv/bin/python3 -m pytest --version 2>&1)", - "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python3.12 -m pytest tests/ -x 2>&1)", - "Bash(python3.12 --version 2>&1 || python3.13 --version 2>&1 || python3.11 --version 2>&1 || python3.14 --version 2>&1)", - "Bash(python3.13 -m venv .venv --clear && .venv/bin/pip install -e \".[dev]\" 2>&1 | tail -5)", - "Bash(.venv/bin/pytest tests/ -x 2>&1)", - "Bash(/Users/itamar/work/PyJSClear/.venv/bin/python3.13 -m pytest tests/ -x 2>&1)", - "WebFetch(domain:warehouse.pypa.io)", - "WebFetch(domain:docs.pypi.org)", - "Bash(.venv/bin/pip install:*)" - ] - }, - "sandbox": { - "enabled": true, - "autoAllowBashIfSandboxed": true - } -}