diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml
index 08ff6e4..9f1ef52 100644
--- a/.github/workflows/fuzz.yml
+++ b/.github/workflows/fuzz.yml
@@ -3,6 +3,8 @@ name: Fuzz Testing
on:
push:
branches: [main]
+ pull_request:
+ branches: [main]
jobs:
fuzz:
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index ecafa45..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,21 +0,0 @@
-__pycache__/
-*.pyc
-*.pyo
-*.egg-info/
-dist/
-build/
-*.so
-*.pyd
-pyjsclear/**/*.c
-.eggs/
-.DS_Store
-/.venv/
-/.venv-linux/
-/.idea/
-/compare_*.py
-/audit_context_check.py
-/run_comparison.py
-/uv.lock
-/tests/resources/jsimplifier
-/tests/resources/obfuscated-javascript-dataset
-/.gitignore
\ No newline at end of file
diff --git a/pyjsclear/__init__.py b/pyjsclear/__init__.py
index 39945b4..af4e3c4 100644
--- a/pyjsclear/__init__.py
+++ b/pyjsclear/__init__.py
@@ -8,10 +8,10 @@
from .deobfuscator import Deobfuscator
-__version__ = '0.1.3'
+__version__ = '0.1.4'
-def deobfuscate(code, max_iterations=50):
+def deobfuscate(code: str, max_iterations: int = 50) -> str:
"""Deobfuscate JavaScript code. Returns cleaned source.
Args:
@@ -24,7 +24,7 @@ def deobfuscate(code, max_iterations=50):
return Deobfuscator(code, max_iterations=max_iterations).execute()
-def deobfuscate_file(input_path, output_path=None, max_iterations=50):
+def deobfuscate_file(input_path: str, output_path: str | None = None, max_iterations: int = 50) -> str | bool:
"""Deobfuscate a JavaScript file.
Args:
@@ -35,13 +35,13 @@ def deobfuscate_file(input_path, output_path=None, max_iterations=50):
Returns:
True if content changed (when output_path given), or the deobfuscated string.
"""
- with open(input_path, 'r', errors='replace') as f:
- code = f.read()
+ with open(input_path, 'r', errors='replace') as input_file:
+ code = input_file.read()
result = deobfuscate(code, max_iterations=max_iterations)
if output_path:
- with open(output_path, 'w') as f:
- f.write(result)
+ with open(output_path, 'w') as output_file:
+ output_file.write(result)
return result != code
return result
diff --git a/pyjsclear/__main__.py b/pyjsclear/__main__.py
index faf2262..1671410 100644
--- a/pyjsclear/__main__.py
+++ b/pyjsclear/__main__.py
@@ -6,7 +6,7 @@
from . import deobfuscate
-def main():
+def main() -> None:
parser = argparse.ArgumentParser(description='Deobfuscate JavaScript files.')
parser.add_argument('input', help='Input JS file (use - for stdin)')
parser.add_argument('-o', '--output', help='Output file (default: stdout)')
@@ -29,8 +29,9 @@ def main():
if args.output:
with open(args.output, 'w') as output_file:
output_file.write(result)
- else:
- sys.stdout.write(result)
+ return
+
+ sys.stdout.write(result)
if __name__ == '__main__':
diff --git a/pyjsclear/deobfuscator.py b/pyjsclear/deobfuscator.py
index 8636cad..d8876b4 100644
--- a/pyjsclear/deobfuscator.py
+++ b/pyjsclear/deobfuscator.py
@@ -2,6 +2,9 @@
from .generator import generate
from .parser import parse
+from .scope import build_scope_tree
+from .transforms.aa_decode import aa_decode
+from .transforms.aa_decode import is_aa_encoded
from .transforms.anti_tamper import AntiTamperRemover
from .transforms.class_static_resolver import ClassStaticResolver
from .transforms.class_string_decoder import ClassStringDecoder
@@ -26,8 +29,6 @@
from .transforms.hex_escapes import HexEscapes
from .transforms.hex_escapes import decode_hex_escapes_source
from .transforms.hex_numerics import HexNumerics
-from .transforms.aa_decode import aa_decode
-from .transforms.aa_decode import is_aa_encoded
from .transforms.jj_decode import is_jj_encoded
from .transforms.jj_decode import jj_decode
from .transforms.jsfuck_decode import is_jsfuck
@@ -53,6 +54,22 @@
from .traverser import simple_traverse
+# Transforms constructed with scope_tree=/node_scope= so they can reuse the cached scope tree
+_SCOPE_TRANSFORMS = frozenset(
+ {
+ ConstantProp,
+ SingleUseVarInliner,
+ ReassignmentRemover,
+ ProxyFunctionInliner,
+ UnusedVariableRemover,
+ ObjectSimplifier,
+ StringRevealer,
+ VariableRenamer,
+ VarToConst,
+ LetToConst,
+ }
+)
+
# StringRevealer runs first to handle string arrays before other transforms
# modify the wrapper function structure.
# Remaining transforms follow obfuscator-io-deobfuscator order.
@@ -106,26 +123,26 @@
_NODE_COUNT_LIMIT = 50_000 # Skip ControlFlowRecoverer above this
-def _count_nodes(ast):
+def _count_nodes(ast: dict) -> int:
"""Count total AST nodes."""
count = 0
- def cb(node, parent):
+ def increment_count(node: dict, parent: dict | None) -> None:
nonlocal count
count += 1
- simple_traverse(ast, cb)
+ simple_traverse(ast, increment_count)
return count
class Deobfuscator:
"""Multi-pass JavaScript deobfuscator."""
- def __init__(self, code, max_iterations=50):
+ def __init__(self, code: str, max_iterations: int = 50) -> None:
self.original_code = code
self.max_iterations = max_iterations
- def _run_pre_passes(self, code):
+ def _run_pre_passes(self, code: str) -> str | None:
"""Run encoding detection and eval unpacking pre-passes.
Returns decoded code if an encoding/packing was detected and decoded,
@@ -160,7 +177,7 @@ def _run_pre_passes(self, code):
# Maximum number of outer re-parse cycles (generate → re-parse → re-transform)
_MAX_OUTER_CYCLES = 5
- def execute(self):
+ def execute(self) -> str:
"""Run all transforms and return cleaned source."""
code = self.original_code
@@ -238,7 +255,7 @@ def execute(self):
# but also recursive. Return best result so far.
return previous_code
- def _run_ast_transforms(self, ast, code_size=0):
+ def _run_ast_transforms(self, ast: dict, code_size: int = 0) -> bool:
"""Run all AST transform passes. Returns True if any transform changed the AST."""
node_count = _count_nodes(ast) if code_size > _LARGE_FILE_SIZE else 0
@@ -259,26 +276,42 @@ def _run_ast_transforms(self, ast, code_size=0):
# Track which transforms are no longer productive
skip_transforms = set()
+ # Cache the scope tree across transforms — it is rebuilt lazily, only when a
+ # scope-using transform runs after the AST changed (or a transform failed)
+ scope_tree = None
+ node_scope = None
+ scope_dirty = True # Start dirty to build on first use
+
# Multi-pass transform loop
any_transform_changed = False
- for i in range(max_iterations):
+ for iteration in range(max_iterations):
modified = False
for transform_class in transform_classes:
if transform_class in skip_transforms:
continue
try:
- transform = transform_class(ast)
+ # Build scope tree lazily when needed by a scope-using transform
+ if transform_class in _SCOPE_TRANSFORMS and scope_dirty:
+ scope_tree, node_scope = build_scope_tree(ast)
+ scope_dirty = False
+
+ if transform_class in _SCOPE_TRANSFORMS:
+ transform = transform_class(ast, scope_tree=scope_tree, node_scope=node_scope)
+ else:
+ transform = transform_class(ast)
result = transform.execute()
except Exception:
+ # A failed transform may have mutated the AST part-way; drop the cached scope
+ scope_dirty = True
continue
if result:
modified = True
any_transform_changed = True
- else:
- # If a transform didn't change anything after the first pass,
- # skip it in subsequent iterations
- if i > 0:
- skip_transforms.add(transform_class)
+ # Any AST change invalidates the cached scope tree
+ scope_dirty = True
+ elif iteration > 0:
+ # Skip transforms that haven't changed anything after the first pass
+ skip_transforms.add(transform_class)
if not modified:
break
diff --git a/pyjsclear/generator.py b/pyjsclear/generator.py
index 71f0053..c680fd9 100644
--- a/pyjsclear/generator.py
+++ b/pyjsclear/generator.py
@@ -1,4 +1,5 @@
"""ESTree AST to JavaScript code generator."""
+from __future__ import annotations
# Operator precedence (higher = binds tighter)
_PRECEDENCE = {
@@ -63,7 +64,7 @@
)
-def generate(node, indent=0):
+def generate(node: dict | None, indent: int = 0) -> str:
"""Generate JavaScript source from an ESTree AST node."""
if node is None:
return ''
@@ -77,11 +78,11 @@ def generate(node, indent=0):
return f'/* unknown: {node_type} */'
-def _indent_str(level):
+def _indent_str(level: int) -> str:
return ' ' * level
-def _is_directive(stmt):
+def _is_directive(stmt: dict) -> bool:
"""Check if a statement is a string-literal directive (like 'use strict')."""
return (
stmt.get('type') == 'ExpressionStatement'
@@ -91,10 +92,10 @@ def _is_directive(stmt):
)
-def _gen_program(node, indent):
+def _gen_program(node: dict, indent: int) -> str:
parts = []
body = node.get('body', [])
- for i, stmt in enumerate(body):
+ for index, stmt in enumerate(body):
if stmt is None:
continue
if stmt.get('type') == 'EmptyStatement':
@@ -102,12 +103,12 @@ def _gen_program(node, indent):
statement_code = _gen_stmt(stmt, indent)
if statement_code.strip():
parts.append(statement_code)
- if _is_directive(stmt) and i + 1 < len(body):
+ if _is_directive(stmt) and index + 1 < len(body):
parts.append('')
return '\n'.join(parts)
-def _gen_stmt(node, indent):
+def _gen_stmt(node: dict | None, indent: int) -> str:
"""Generate a statement with indentation."""
if node is None:
return ''
@@ -121,20 +122,20 @@ def _gen_stmt(node, indent):
return prefix + code + ';'
-def _gen_block(node, indent):
+def _gen_block(node: dict, indent: int) -> str:
if not node.get('body'):
return '{}'
lines = ['{']
body = node.get('body', [])
- for i, stmt in enumerate(body):
+ for index, stmt in enumerate(body):
lines.append(_gen_stmt(stmt, indent + 1))
- if _is_directive(stmt) and i + 1 < len(body):
+ if _is_directive(stmt) and index + 1 < len(body):
lines.append('')
lines.append(_indent_str(indent) + '}')
return '\n'.join(lines)
-def _gen_var_declaration(node, indent):
+def _gen_var_declaration(node: dict, indent: int) -> str:
kind = node.get('kind', 'var')
declarations = []
for declaration in node.get('declarations', []):
@@ -147,10 +148,10 @@ def _gen_var_declaration(node, indent):
return f'{kind} {", ".join(declarations)}'
-def _gen_function(node, indent, is_expression=False):
+def _gen_function(node: dict, indent: int, is_expression: bool = False) -> str:
"""Shared generator for FunctionDeclaration and FunctionExpression."""
name = generate(node['id'], indent) if node.get('id') else ''
- params = ', '.join(generate(p, indent) for p in node.get('params', []))
+ params = ', '.join(generate(param, indent) for param in node.get('params', []))
async_prefix = 'async ' if node.get('async') else ''
gen_prefix = '*' if node.get('generator') else ''
body = generate(node['body'], indent)
@@ -160,18 +161,18 @@ def _gen_function(node, indent, is_expression=False):
return f'{async_prefix}function{gen_prefix} ({params}) {body}'
-def _gen_function_decl(node, indent):
+def _gen_function_decl(node: dict, indent: int) -> str:
return _gen_function(node, indent)
-def _gen_function_expr(node, indent):
+def _gen_function_expr(node: dict, indent: int) -> str:
return _gen_function(node, indent, is_expression=True)
-def _gen_arrow(node, indent):
+def _gen_arrow(node: dict, indent: int) -> str:
params = node.get('params', [])
async_prefix = 'async ' if node.get('async') else ''
- param_str = '(' + ', '.join(generate(p, indent) for p in params) + ')'
+ param_str = '(' + ', '.join(generate(param, indent) for param in params) + ')'
body = node.get('body', {})
body_str = generate(body, indent)
# Wrap object literal in parens to avoid ambiguity with block
@@ -180,14 +181,14 @@ def _gen_arrow(node, indent):
return f'{async_prefix}{param_str} => {body_str}'
-def _gen_return(node, indent):
+def _gen_return(node: dict, indent: int) -> str:
argument = node.get('argument')
if argument:
return f'return {generate(argument, indent)}'
return 'return'
-def _gen_if(node, indent):
+def _gen_if(node: dict, indent: int) -> str:
test = generate(node['test'], indent)
consequent_code = generate(node['consequent'], indent)
if node['consequent'].get('type') != 'BlockStatement':
@@ -202,19 +203,19 @@ def _gen_if(node, indent):
return f'if ({test}) {consequent_code}'
-def _gen_while(node, indent):
+def _gen_while(node: dict, indent: int) -> str:
test = generate(node['test'], indent)
body = generate(node['body'], indent)
return f'while ({test}) {body}'
-def _gen_do_while(node, indent):
+def _gen_do_while(node: dict, indent: int) -> str:
body = generate(node['body'], indent)
test = generate(node['test'], indent)
return f'do {body} while ({test})'
-def _gen_for(node, indent):
+def _gen_for(node: dict, indent: int) -> str:
init = ''
if node.get('init'):
init = generate(node['init'], indent)
@@ -224,21 +225,21 @@ def _gen_for(node, indent):
return f'for ({init}; {test}; {update}) {body}'
-def _gen_for_in(node, indent):
+def _gen_for_in(node: dict, indent: int) -> str:
left = generate(node['left'], indent)
right = generate(node['right'], indent)
body = generate(node['body'], indent)
return f'for ({left} in {right}) {body}'
-def _gen_for_of(node, indent):
+def _gen_for_of(node: dict, indent: int) -> str:
left = generate(node['left'], indent)
right = generate(node['right'], indent)
body = generate(node['body'], indent)
return f'for ({left} of {right}) {body}'
-def _gen_switch(node, indent):
+def _gen_switch(node: dict, indent: int) -> str:
discriminant = generate(node['discriminant'], indent)
lines = [f'switch ({discriminant}) {{']
for case in node.get('cases', []):
@@ -252,15 +253,15 @@ def _gen_switch(node, indent):
return '\n'.join(lines)
-def _gen_try(node, indent):
+def _gen_try(node: dict, indent: int) -> str:
block = generate(node['block'], indent)
result = f'try {block}'
handler = node.get('handler')
if handler:
- param = generate(handler.get('param'), indent) if handler.get('param') else ''
+ catch_param = generate(handler.get('param'), indent) if handler.get('param') else ''
handler_body = generate(handler['body'], indent)
- if param:
- result += f' catch ({param}) {handler_body}'
+ if catch_param:
+ result += f' catch ({catch_param}) {handler_body}'
else:
result += f' catch {handler_body}'
finalizer = node.get('finalizer')
@@ -269,33 +270,33 @@ def _gen_try(node, indent):
return result
-def _gen_throw(node, indent):
+def _gen_throw(node: dict, indent: int) -> str:
return f'throw {generate(node["argument"], indent)}'
-def _gen_break(node, indent):
+def _gen_break(node: dict, indent: int) -> str:
if node.get('label'):
return f'break {generate(node["label"], indent)}'
return 'break'
-def _gen_continue(node, indent):
+def _gen_continue(node: dict, indent: int) -> str:
if node.get('label'):
return f'continue {generate(node["label"], indent)}'
return 'continue'
-def _gen_labeled(node, indent):
+def _gen_labeled(node: dict, indent: int) -> str:
label = generate(node['label'], indent)
body = _gen_stmt(node['body'], indent)
return f'{label}:\n{body}'
-def _gen_expr_stmt(node, indent):
+def _gen_expr_stmt(node: dict, indent: int) -> str:
return generate(node['expression'], indent)
-def _gen_binary(node, indent):
+def _gen_binary(node: dict, indent: int) -> str:
operator = node.get('operator', '')
left = generate(node['left'], indent)
right = generate(node['right'], indent)
@@ -309,11 +310,11 @@ def _gen_binary(node, indent):
return f'{left} {operator} {right}'
-def _gen_logical(node, indent):
+def _gen_logical(node: dict, indent: int) -> str:
return _gen_binary(node, indent)
-def _gen_unary(node, indent):
+def _gen_unary(node: dict, indent: int) -> str:
operator = node.get('operator', '')
operand = generate(node['argument'], indent)
operand_prec = _expr_precedence(node['argument'])
@@ -326,7 +327,7 @@ def _gen_unary(node, indent):
return f'{operand}{operator}'
-def _gen_update(node, indent):
+def _gen_update(node: dict, indent: int) -> str:
argument = generate(node['argument'], indent)
operator = node.get('operator', '++')
if node.get('prefix'):
@@ -334,14 +335,14 @@ def _gen_update(node, indent):
return f'{argument}{operator}'
-def _gen_assignment(node, indent):
+def _gen_assignment(node: dict, indent: int) -> str:
left = generate(node['left'], indent)
right = generate(node['right'], indent)
operator = node.get('operator', '=')
return f'{left} {operator} {right}'
-def _gen_member(node, indent):
+def _gen_member(node: dict, indent: int) -> str:
object_code = generate(node['object'], indent)
obj_type = node['object'].get('type', '')
computed = node.get('computed')
@@ -371,56 +372,56 @@ def _gen_member(node, indent):
return f'{object_code}{dot}{property_code}'
-def _gen_call(node, indent):
+def _gen_call(node: dict, indent: int) -> str:
callee = generate(node['callee'], indent)
callee_type = node['callee'].get('type', '')
if callee_type in ('FunctionExpression', 'ArrowFunctionExpression', 'SequenceExpression'):
callee = f'({callee})'
- args = ', '.join(generate(a, indent) for a in node.get('arguments', []))
+ args = ', '.join(generate(argument, indent) for argument in node.get('arguments', []))
if node.get('optional'):
return f'{callee}?.({args})'
return f'{callee}({args})'
-def _gen_new(node, indent):
+def _gen_new(node: dict, indent: int) -> str:
callee = generate(node['callee'], indent)
args = node.get('arguments', [])
if args:
- arg_str = ', '.join(generate(a, indent) for a in args)
+ arg_str = ', '.join(generate(argument, indent) for argument in args)
return f'new {callee}({arg_str})'
return f'new {callee}()'
-def _wrap_if_sequence(node, code):
+def _wrap_if_sequence(node: dict | None, code: str) -> str:
"""Wrap code in parens if node is a SequenceExpression."""
if isinstance(node, dict) and node.get('type') == 'SequenceExpression':
return f'({code})'
return code
-def _gen_conditional(node, indent):
+def _gen_conditional(node: dict, indent: int) -> str:
test = generate(node['test'], indent)
consequent_code = _wrap_if_sequence(node.get('consequent'), generate(node['consequent'], indent))
alternate_code = _wrap_if_sequence(node.get('alternate'), generate(node['alternate'], indent))
return f'{test} ? {consequent_code} : {alternate_code}'
-def _gen_sequence(node, indent):
- exprs = ', '.join(generate(e, indent) for e in node.get('expressions', []))
+def _gen_sequence(node: dict, indent: int) -> str:
+ exprs = ', '.join(generate(expression, indent) for expression in node.get('expressions', []))
return exprs
-def _gen_bracket_list(elements, indent):
+def _gen_bracket_list(elements: list, indent: int) -> str:
"""Generate a bracketed list of elements, replacing None with empty slots."""
- elems = [generate(e, indent) if e is not None else '' for e in elements]
+ elems = [generate(element, indent) if element is not None else '' for element in elements]
return '[' + ', '.join(elems) + ']'
-def _gen_array(node, indent):
+def _gen_array(node: dict, indent: int) -> str:
return _gen_bracket_list(node.get('elements', []), indent)
-def _gen_object_property(property_node, indent):
+def _gen_object_property(property_node: dict, indent: int) -> str:
"""Generate a single object property string."""
if property_node.get('type') == 'SpreadElement':
return '...' + generate(property_node['argument'], indent)
@@ -432,7 +433,7 @@ def _gen_object_property(property_node, indent):
kind = property_node.get('kind', 'init')
if kind in ('get', 'set') or property_node.get('method'):
prefix = f'{kind} ' if kind in ('get', 'set') else ''
- params = ', '.join(generate(pp, indent) for pp in property_node['value'].get('params', []))
+ params = ', '.join(generate(param, indent) for param in property_node['value'].get('params', []))
body = generate(property_node['value'].get('body'), indent)
return f'{prefix}{key}({params}) {body}'
@@ -443,7 +444,7 @@ def _gen_object_property(property_node, indent):
return f'{key}: {value}'
-def _gen_object(node, indent):
+def _gen_object(node: dict, indent: int) -> str:
properties = node.get('properties', [])
if not properties:
return '{}'
@@ -454,23 +455,23 @@ def _gen_object(node, indent):
return '{\n' + lines + '\n' + outer_indent + '}'
-def _gen_property(node, indent):
+def _gen_property(node: dict, indent: int) -> str:
key = generate(node['key'], indent)
value = generate(node['value'], indent)
return f'{key}: {value}'
-def _gen_spread(node, indent):
+def _gen_spread(node: dict, indent: int) -> str:
return '...' + generate(node['argument'], indent)
-def _escape_string(val, raw):
+def _escape_string(string_value: str, raw: str | None) -> str:
"""Escape a string value and wrap in the appropriate quotes."""
if raw and len(raw) >= 2 and raw[0] in ('"', "'"):
quote = raw[0]
else:
quote = '"'
- escaped = val.replace('\\', '\\\\')
+ escaped = string_value.replace('\\', '\\\\')
escaped = escaped.replace('\n', '\\n')
escaped = escaped.replace('\r', '\\r')
escaped = escaped.replace('\t', '\\t')
@@ -478,7 +479,7 @@ def _escape_string(val, raw):
return f'{quote}{escaped}{quote}'
-def _gen_literal(node, indent):
+def _gen_literal(node: dict, indent: int) -> str:
raw = node.get('raw')
value = node.get('value')
if isinstance(value, str):
@@ -498,37 +499,37 @@ def _gen_literal(node, indent):
return str(value)
-def _gen_identifier(node, indent):
+def _gen_identifier(node: dict, indent: int) -> str:
return node.get('name', '')
-def _gen_this(node, indent):
+def _gen_this(node: dict, indent: int) -> str:
return 'this'
-def _gen_empty(node, indent):
+def _gen_empty(node: dict, indent: int) -> str:
return ';'
-def _gen_template_literal(node, indent):
+def _gen_template_literal(node: dict, indent: int) -> str:
quasis = node.get('quasis', [])
- exprs = node.get('expressions', [])
+ expressions = node.get('expressions', [])
parts = []
- for i, quasi in enumerate(quasis):
+ for index, quasi in enumerate(quasis):
raw = quasi.get('value', {}).get('raw', '')
parts.append(raw)
- if i < len(exprs):
- parts.append('${' + generate(exprs[i], indent) + '}')
+ if index < len(expressions):
+ parts.append('${' + generate(expressions[index], indent) + '}')
return '`' + ''.join(parts) + '`'
-def _gen_tagged_template(node, indent):
+def _gen_tagged_template(node: dict, indent: int) -> str:
tag = generate(node['tag'], indent)
quasi = generate(node['quasi'], indent)
return f'{tag}{quasi}'
-def _gen_class_decl(node, indent):
+def _gen_class_decl(node: dict, indent: int) -> str:
name = generate(node['id'], indent) if node.get('id') else ''
superclass_clause = ''
if node.get('superClass'):
@@ -539,7 +540,7 @@ def _gen_class_decl(node, indent):
return f'class{superclass_clause} {body}'
-def _gen_class_body(node, indent):
+def _gen_class_body(node: dict, indent: int) -> str:
if not node.get('body'):
return '{}'
lines = ['{']
@@ -549,29 +550,29 @@ def _gen_class_body(node, indent):
return '\n'.join(lines)
-def _gen_method_def(node, indent):
+def _gen_method_def(node: dict, indent: int) -> str:
key = generate(node['key'], indent)
if node.get('computed') or node['key'].get('type') == 'Literal':
key = f'[{key}]'
- static = 'static ' if node.get('static') else ''
+ static_prefix = 'static ' if node.get('static') else ''
kind = node.get('kind', 'method')
value = node.get('value', {})
- params = ', '.join(generate(p, indent) for p in value.get('params', []))
+ params = ', '.join(generate(param, indent) for param in value.get('params', []))
body = generate(value.get('body'), indent)
match kind:
case 'constructor':
- return f'{static}constructor({params}) {body}'
+ return f'{static_prefix}constructor({params}) {body}'
case 'get':
- return f'{static}get {key}({params}) {body}'
+ return f'{static_prefix}get {key}({params}) {body}'
case 'set':
- return f'{static}set {key}({params}) {body}'
+ return f'{static_prefix}set {key}({params}) {body}'
case _:
async_prefix = 'async ' if value.get('async') else ''
gen_prefix = '*' if value.get('generator') else ''
- return f'{static}{async_prefix}{gen_prefix}{key}({params}) {body}'
+ return f'{static_prefix}{async_prefix}{gen_prefix}{key}({params}) {body}'
-def _gen_yield(node, indent):
+def _gen_yield(node: dict, indent: int) -> str:
argument = generate(node.get('argument'), indent) if node.get('argument') else ''
delegate = '*' if node.get('delegate') else ''
if argument:
@@ -579,21 +580,21 @@ def _gen_yield(node, indent):
return f'yield{delegate}'
-def _gen_await(node, indent):
+def _gen_await(node: dict, indent: int) -> str:
return f'await {generate(node["argument"], indent)}'
-def _gen_assignment_pattern(node, indent):
+def _gen_assignment_pattern(node: dict, indent: int) -> str:
left = generate(node['left'], indent)
right = generate(node['right'], indent)
return f'{left} = {right}'
-def _gen_array_pattern(node, indent):
+def _gen_array_pattern(node: dict, indent: int) -> str:
return _gen_bracket_list(node.get('elements', []), indent)
-def _gen_object_pattern_part(property_node, indent):
+def _gen_object_pattern_part(property_node: dict, indent: int) -> str:
"""Generate a single destructuring pattern property."""
if property_node.get('type') == 'RestElement':
return '...' + generate(property_node['argument'], indent)
@@ -604,7 +605,7 @@ def _gen_object_pattern_part(property_node, indent):
return f'{key}: {value}'
-def _gen_object_pattern(node, indent):
+def _gen_object_pattern(node: dict, indent: int) -> str:
properties = [_gen_object_pattern_part(property_node, indent + 1) for property_node in node.get('properties', [])]
if not properties:
return '{}'
@@ -614,75 +615,75 @@ def _gen_object_pattern(node, indent):
return '{\n' + lines + '\n' + outer_indent + '}'
-def _gen_rest_element(node, indent):
+def _gen_rest_element(node: dict, indent: int) -> str:
return '...' + generate(node['argument'], indent)
-def _gen_import_specifier(spec, indent):
+def _gen_import_specifier(specifier: dict, indent: int) -> str:
"""Generate a single import specifier."""
- spec_type = spec.get('type', '')
- if spec_type == 'ImportDefaultSpecifier':
- return generate(spec['local'], indent)
- if spec_type == 'ImportNamespaceSpecifier':
- return '* as ' + generate(spec['local'], indent)
+ specifier_type = specifier.get('type', '')
+ if specifier_type == 'ImportDefaultSpecifier':
+ return generate(specifier['local'], indent)
+ if specifier_type == 'ImportNamespaceSpecifier':
+ return '* as ' + generate(specifier['local'], indent)
# ImportSpecifier
- imported = generate(spec['imported'], indent)
- local = generate(spec['local'], indent)
+ imported = generate(specifier['imported'], indent)
+ local = generate(specifier['local'], indent)
if imported == local:
return imported
return f'{imported} as {local}'
-def _gen_import_declaration(node, indent):
+def _gen_import_declaration(node: dict, indent: int) -> str:
source = generate(node['source'], indent)
specifiers = node.get('specifiers', [])
if not specifiers:
return f'import {source}'
- default_specs = [s for s in specifiers if s.get('type') == 'ImportDefaultSpecifier']
- namespace_specs = [s for s in specifiers if s.get('type') == 'ImportNamespaceSpecifier']
- named_specs = [s for s in specifiers if s.get('type') == 'ImportSpecifier']
+ default_specifiers = [s for s in specifiers if s.get('type') == 'ImportDefaultSpecifier']
+ namespace_specifiers = [s for s in specifiers if s.get('type') == 'ImportNamespaceSpecifier']
+ named_specifiers = [s for s in specifiers if s.get('type') == 'ImportSpecifier']
parts = []
- if default_specs:
- parts.append(_gen_import_specifier(default_specs[0], indent))
- if namespace_specs:
- parts.append(_gen_import_specifier(namespace_specs[0], indent))
- if named_specs:
- names = ', '.join(_gen_import_specifier(s, indent) for s in named_specs)
+ if default_specifiers:
+ parts.append(_gen_import_specifier(default_specifiers[0], indent))
+ if namespace_specifiers:
+ parts.append(_gen_import_specifier(namespace_specifiers[0], indent))
+ if named_specifiers:
+ names = ', '.join(_gen_import_specifier(specifier, indent) for specifier in named_specifiers)
parts.append('{' + names + '}')
return f'import {", ".join(parts)} from {source}'
-def _gen_export_specifier(spec, indent):
- exported = generate(spec['exported'], indent)
- local = generate(spec['local'], indent)
+def _gen_export_specifier(specifier: dict, indent: int) -> str:
+ exported = generate(specifier['exported'], indent)
+ local = generate(specifier['local'], indent)
if exported == local:
return exported
return f'{local} as {exported}'
-def _gen_export_named(node, indent):
+def _gen_export_named(node: dict, indent: int) -> str:
declaration = node.get('declaration')
if declaration:
return f'export {generate(declaration, indent)}'
specifiers = node.get('specifiers', [])
- names = ', '.join(_gen_export_specifier(s, indent) for s in specifiers)
+ names = ', '.join(_gen_export_specifier(specifier, indent) for specifier in specifiers)
source = node.get('source')
if source:
return f'export {{{names}}} from {generate(source, indent)}'
return f'export {{{names}}}'
-def _gen_export_default(node, indent):
+def _gen_export_default(node: dict, indent: int) -> str:
declaration = node.get('declaration', {})
return f'export default {generate(declaration, indent)}'
-def _gen_export_all(node, indent):
+def _gen_export_all(node: dict, indent: int) -> str:
source = generate(node['source'], indent)
return f'export * from {source}'
-def _expr_precedence(node):
+def _expr_precedence(node: dict) -> int:
"""Get the precedence level of an expression node."""
if not isinstance(node, dict):
return 20
diff --git a/pyjsclear/parser.py b/pyjsclear/parser.py
index 049a87f..f794e72 100644
--- a/pyjsclear/parser.py
+++ b/pyjsclear/parser.py
@@ -8,7 +8,7 @@
_ASYNC_MAP = {'isAsync': 'async', 'allowAwait': 'await'}
-def _fast_to_dict(obj):
+def _fast_to_dict(obj: object) -> object:
"""Convert esprima AST objects to plain dicts, ~2x faster than toDict()."""
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
@@ -29,7 +29,7 @@ def _fast_to_dict(obj):
return output
-def parse(code):
+def parse(code: str) -> dict:
"""Parse JavaScript code into an ESTree-compatible AST.
Returns a Program node (dict).
diff --git a/pyjsclear/scope.py b/pyjsclear/scope.py
index f4ee9f5..bdb6769 100644
--- a/pyjsclear/scope.py
+++ b/pyjsclear/scope.py
@@ -1,24 +1,36 @@
"""Variable scope and binding analysis for ESTree ASTs."""
+from collections.abc import Callable
+from typing import Any
+
from .utils.ast_helpers import _CHILD_KEYS
from .utils.ast_helpers import get_child_keys
+# Module-level aliases of builtins for hot-path performance
+_type = type
+_dict = dict
+_list = list
+
+# Maximum recursion depth before falling back to iterative traversal.
+_MAX_RECURSIVE_DEPTH = 500
+
+
class Binding:
"""Represents a variable binding in a scope."""
__slots__ = ('name', 'node', 'kind', 'scope', 'references', 'assignments')
- def __init__(self, name, node, kind, scope):
+ def __init__(self, name: str, node: dict, kind: str, scope: 'Scope') -> None:
self.name = name
self.node = node # The declaration node
self.kind = kind # 'var', 'let', 'const', 'function', 'param'
self.scope = scope
- self.references = [] # List of (node, parent, key, index) where name is referenced
- self.assignments = [] # List of assignment nodes
+ self.references: list = [] # List of (node, parent, key, index) where name is referenced
+ self.assignments: list = [] # List of assignment nodes
@property
- def is_constant(self):
+ def is_constant(self) -> bool:
"""True if the binding is never reassigned after declaration."""
if self.kind == 'const':
return True
@@ -33,21 +45,21 @@ class Scope:
__slots__ = ('parent', 'node', 'bindings', 'children', 'is_function')
- def __init__(self, parent, node, is_function=False):
+ def __init__(self, parent: 'Scope | None', node: dict, is_function: bool = False) -> None:
self.parent = parent
self.node = node
- self.bindings = {} # name -> Binding
- self.children = []
+ self.bindings: dict[str, Binding] = {} # name -> Binding
+ self.children: list['Scope'] = []
self.is_function = is_function
if parent:
parent.children.append(self)
- def add_binding(self, name, node, kind):
+ def add_binding(self, name: str, node: dict, kind: str) -> Binding:
binding = Binding(name, node, kind, self)
self.bindings[name] = binding
return binding
- def get_binding(self, name):
+ def get_binding(self, name: str) -> 'Binding | None':
"""Look up a binding, walking up the scope chain."""
if name in self.bindings:
return self.bindings[name]
@@ -55,18 +67,18 @@ def get_binding(self, name):
return self.parent.get_binding(name)
return None
- def get_own_binding(self, name):
+ def get_own_binding(self, name: str) -> 'Binding | None':
return self.bindings.get(name)
-def _nearest_function_scope(scope):
+def _nearest_function_scope(scope: 'Scope | None') -> 'Scope | None':
"""Walk up to the nearest function (or root) scope."""
while scope and not scope.is_function:
scope = scope.parent
return scope
-def _is_non_reference_identifier(parent, parent_key):
+def _is_non_reference_identifier(parent: dict | None, parent_key: str | None) -> bool:
"""Return True if this Identifier usage is not a variable reference."""
if not parent:
return False
@@ -84,200 +96,340 @@ def _is_non_reference_identifier(parent, parent_key):
return False
-def _recurse_into_children(node, child_keys_map, callback):
- """Walk child nodes, calling callback(child_node) for each dict with 'type'."""
- node_type = node.get('type')
- child_keys = child_keys_map.get(node_type)
- if child_keys is None:
- child_keys = get_child_keys(node)
- for key in child_keys:
- child = node.get(key)
- if child is None:
+def _collect_pattern_names(
+    pattern: dict | None,
+    scope: 'Scope',
+    kind: str,
+    declaration: dict,
+) -> None:
+    """Bind every name introduced by a destructuring pattern.
+
+    Handles ``ArrayPattern``, ``ObjectPattern``, ``RestElement`` and
+    ``AssignmentPattern`` nodes and calls
+    ``scope.add_binding(name, declaration, kind)`` for each ``Identifier``
+    target reached. Targets that are themselves patterns are walked
+    recursively, so nested forms such as ``[a, ...[b, c]]`` or
+    ``{x: {y} = {}}`` are covered. Any other node type (including a bare
+    ``Identifier`` passed as *pattern*) adds nothing.
+
+    Args:
+        pattern: Pattern node to inspect; anything that is not a dict is ignored.
+        scope: Scope that receives the bindings.
+        kind: Binding kind recorded on every binding created here.
+        declaration: Node stored as the declaration node of every binding.
+    """
+    if not isinstance(pattern, _dict):
+        return
+    match pattern.get('type', ''):
+        case 'ArrayPattern':
+            targets = pattern.get('elements', [])
+        case 'ObjectPattern':
+            # The binding target is read from 'value', falling back to
+            # 'argument' for entries that carry no 'value' key.
+            targets = [
+                property_node.get('value', property_node.get('argument'))
+                for property_node in pattern.get('properties', [])
+            ]
+        case 'RestElement':
+            targets = [pattern.get('argument')]
+        case 'AssignmentPattern':
+            targets = [pattern.get('left')]
+        case _:
+            return
+    for target in targets:
+        if not target:
+            continue  # hole in an array pattern, or a missing child
+        if target.get('type') == 'Identifier':
+            scope.add_binding(target['name'], declaration, kind)
+        else:
+            _collect_pattern_names(target, scope, kind, declaration)
+
+
+def _collect_declarations_iterative(ast: dict, root_scope: 'Scope', node_scope: dict, all_scopes: list) -> None:
+    """Pass 1 without recursion: collect declarations using an explicit stack.
+
+    Every popped ``(node, scope)`` pair is handed to
+    ``_process_declaration_node``, which records scopes and bindings in
+    *node_scope* / *all_scopes* and appends follow-up work to the same stack.
+    """
+    known_child_keys = _CHILD_KEYS
+    lookup_child_keys = get_child_keys
+
+    def _push_children(node: dict, scope: 'Scope', stack: list) -> None:
+        # Everything is pushed in reverse so that popping from the end of the
+        # stack visits children in child-key order.
+        keys = known_child_keys.get(node.get('type'))
+        if keys is None:
+            keys = lookup_child_keys(node)
+        for key in reversed(keys):
+            child = node.get(key)
+            if child is None:
+                continue
+            if _type(child) is _list:
+                stack.extend(
+                    (item, scope)
+                    for item in reversed(child)
+                    if _type(item) is _dict and 'type' in item
+                )
+            elif _type(child) is _dict and 'type' in child:
+                stack.append((child, scope))
+
+    pending = [(ast, root_scope)]
+    while pending:
+        node, scope = pending.pop()
+        if _type(node) is not _dict:
+            continue
+        node_type = node.get('type')
+        if node_type is not None:
+            _process_declaration_node(node, node_type, scope, node_scope, all_scopes, pending, _push_children)
+
+
+def _process_declaration_node(
+ node: dict,
+ node_type: str,
+ scope: 'Scope',
+ node_scope: dict[int, 'Scope'],
+ all_scopes: list['Scope'],
+ push_target: list,
+ push_children_fn: Callable,
+) -> None:
+ """Process a single node for declaration collection.
+
+ Creates the scopes and bindings implied by this node, registers every new
+ scope in node_scope (keyed by a node id) and in all_scopes, and appends the
+ (child, scope) pairs that still need a visit to push_target. Node types
+ without a dedicated branch are delegated to push_children_fn.
+
+ push_target: a list to append (node, scope) tuples to.
+ push_children_fn: callable(node, scope, push_target) to push child nodes.
+ """
+ if node_type in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'):
+ new_scope = Scope(scope, node, is_function=True)
+ node_scope[id(node)] = new_scope
+ all_scopes.append(new_scope)
+
+ # A declaration's name is bound in the enclosing scope; a named function
+ # expression's name is bound in the function's own scope.
+ if node_type == 'FunctionDeclaration' and node.get('id'):
+ scope.add_binding(node['id']['name'], node, 'function')
+ elif node_type == 'FunctionExpression' and node.get('id'):
+ new_scope.add_binding(node['id']['name'], node, 'function')
+
+ # Only plain, defaulted and rest identifiers are bound as params.
+ # NOTE(review): other parameter node types (e.g. ObjectPattern/ArrayPattern)
+ # get no binding here - confirm that is intended.
+ for param in node.get('params', []):
+ param_type = param.get('type')
+ if param_type == 'Identifier':
+ new_scope.add_binding(param['name'], param, 'param')
+ elif param_type == 'AssignmentPattern':
+ left = param.get('left', {})
+ if left.get('type') == 'Identifier':
+ new_scope.add_binding(left['name'], param, 'param')
+ elif param_type == 'RestElement':
+ argument = param.get('argument')
+ if argument and argument.get('type') == 'Identifier':
+ new_scope.add_binding(argument['name'], param, 'param')
+
+ body = node.get('body')
+ if not body:
+ return
+ if _type(body) is _dict and body.get('type') == 'BlockStatement':
+ # The block body is mapped to the function scope and its statements are
+ # queued directly (last first), so no separate block scope is created.
+ node_scope[id(body)] = new_scope
+ statements = body.get('body', [])
+ for index in range(len(statements) - 1, -1, -1):
+ push_target.append((statements[index], new_scope))
+ else:
+ push_target.append((body, new_scope))
+
+ elif node_type in ('ClassExpression', 'ClassDeclaration'):
+ class_id = node.get('id')
+ inner_scope = scope
+ if class_id and class_id.get('type') == 'Identifier':
+ name = class_id['name']
+ # Declarations bind their name in the enclosing scope; a named class
+ # expression gets a scope of its own that holds the name.
+ if node_type == 'ClassDeclaration':
+ scope.add_binding(name, node, 'function')
+ else:
+ inner_scope = Scope(scope, node)
+ node_scope[id(node)] = inner_scope
+ all_scopes.append(inner_scope)
+ inner_scope.add_binding(name, node, 'function')
+ superclass = node.get('superClass')
+ body = node.get('body')
+ # The superclass is queued with the enclosing scope, the body with inner_scope.
+ if body:
+ push_target.append((body, inner_scope))
+ if superclass:
+ push_target.append((superclass, scope))
+
+ elif node_type == 'VariableDeclaration':
+ kind = node.get('kind', 'var')
+ # 'var' binds in the nearest function (or root) scope; other kinds bind
+ # in the current scope.
+ target_scope = (_nearest_function_scope(scope) or scope) if kind == 'var' else scope
+ declarations = node.get('declarations', [])
+ inits_to_push = []
+ for declaration in declarations:
+ declaration_id = declaration.get('id')
+ if declaration_id and declaration_id.get('type') == 'Identifier':
+ target_scope.add_binding(declaration_id['name'], declaration, kind)
+ _collect_pattern_names(declaration_id, target_scope, kind, declaration)
+ init = declaration.get('init')
+ if init:
+ # Initializers are queued with the current scope, not target_scope.
+ inits_to_push.append((init, scope))
+ for index in range(len(inits_to_push) - 1, -1, -1):
+ push_target.append(inits_to_push[index])
+
+ # Blocks already present in node_scope (e.g. function bodies) are not given
+ # a second scope.
+ elif node_type == 'BlockStatement' and id(node) not in node_scope:
+ new_scope = Scope(scope, node)
+ node_scope[id(node)] = new_scope
+ all_scopes.append(new_scope)
+ statements = node.get('body', [])
+ for index in range(len(statements) - 1, -1, -1):
+ push_target.append((statements[index], new_scope))
+
+ elif node_type == 'ForStatement':
+ new_scope = Scope(scope, node)
+ node_scope[id(node)] = new_scope
+ all_scopes.append(new_scope)
+ # NOTE(review): only 'init' and 'body' are queued; any other children of
+ # the ForStatement are not visited by this pass - confirm that is intended.
+ if node.get('body'):
+ push_target.append((node['body'], new_scope))
+ if node.get('init'):
+ push_target.append((node['init'], new_scope))
+
+ elif node_type == 'CatchClause':
+ catch_body = node.get('body')
+ if catch_body and catch_body.get('type') == 'BlockStatement':
+ # The catch scope is keyed on the body block, not on the CatchClause node.
+ catch_scope = Scope(scope, catch_body)
+ node_scope[id(catch_body)] = catch_scope
+ all_scopes.append(catch_scope)
+ param = node.get('param')
+ if param and param.get('type') == 'Identifier':
+ catch_scope.add_binding(param['name'], param, 'param')
+ statements = catch_body.get('body', [])
+ for index in range(len(statements) - 1, -1, -1):
+ push_target.append((statements[index], catch_scope))
+
+ else:
+ push_children_fn(node, scope, push_target)
+
+
+def _collect_references_iterative(ast: dict, root_scope: 'Scope', node_scope: dict) -> None:
+ """Iterative Pass 2: Collect references and assignments.
+
+ Walks ast with an explicit stack of (node, scope, parent, parent_key,
+ parent_index) entries. node_scope maps id(node) to the scope that becomes
+ current once that node is reached.
+ """
+ _child_keys_map = _CHILD_KEYS
+ _get_child_keys = get_child_keys
+
+ ref_stack = [(ast, root_scope, None, None, None)]
+
+ while ref_stack:
+ node, scope, parent, parent_key, parent_index = ref_stack.pop()
+
+ if not _type(node) is _dict:
+ continue
+ node_type = node.get('type')
+ if node_type is None:
continue
- if isinstance(child, list):
- for item in child:
- if isinstance(item, dict) and 'type' in item:
- callback(item)
- elif isinstance(child, dict) and 'type' in child:
- callback(child)
+ node_id = id(node)
+ if node_id in node_scope:
+ scope = node_scope[node_id]
-def build_scope_tree(ast):
+ if node_type == 'Identifier':
+ name = node.get('name', '')
+ if _is_non_reference_identifier(parent, parent_key):
+ continue
+ binding = scope.get_binding(name)
+ if not binding:
+ continue
+ # Record the use site; when the identifier is the left side of an
+ # AssignmentExpression or sits under an UpdateExpression, the parent is
+ # also recorded as an assignment.
+ binding.references.append((node, parent, parent_key, parent_index))
+ if parent and parent.get('type') == 'AssignmentExpression' and parent_key == 'left':
+ binding.assignments.append(parent)
+ elif parent and parent.get('type') == 'UpdateExpression':
+ binding.assignments.append(parent)
+ continue
+
+ child_keys = _child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(node)
+ # Push in reverse so that popping visits children in child-key order.
+ for key in reversed(child_keys):
+ child = node.get(key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for i in range(len(child) - 1, -1, -1):
+ item = child[i]
+ if _type(item) is _dict and 'type' in item:
+ ref_stack.append((item, scope, node, key, i))
+ elif _type(child) is _dict and 'type' in child:
+ ref_stack.append((child, scope, node, key, None))
+
+
+def build_scope_tree(ast: dict) -> tuple[Scope, dict[int, Scope]]:
"""Build a scope tree from an AST, collecting bindings and references.
Returns the root Scope and a dict mapping node id -> Scope.
+ Uses recursive traversal with automatic fallback to iterative for deep subtrees.
"""
root_scope = Scope(None, ast, is_function=True)
- # Maps id(node) -> scope for function/block scope nodes
- node_scope = {id(ast): root_scope}
- # We need to collect all declarations first, then references
- all_scopes = [root_scope]
-
- def _get_scope_for(node, current_scope):
- """Get or create the scope for a node."""
- node_id = id(node)
- if node_id in node_scope:
- return node_scope[node_id]
- return current_scope
+ node_scope: dict[int, Scope] = {id(ast): root_scope}
+ all_scopes: list[Scope] = [root_scope]
_child_keys_map = _CHILD_KEYS
+ _get_child_keys = get_child_keys
+ _max_depth = _MAX_RECURSIVE_DEPTH
+
+ # ---- Pass 1: Collect declarations (recursive with iterative fallback) ----
- def _collect_declarations(node, scope):
- """Walk the AST collecting variable declarations into scopes."""
- if not isinstance(node, dict):
+ def _push_children(node: dict, scope: Scope, target_list: list) -> None:
+ """Push child nodes onto a list."""
+ node_type = node['type']
+ child_keys = _child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(node)
+ for key in reversed(child_keys):
+ child = node.get(key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for i in range(len(child) - 1, -1, -1):
+ item = child[i]
+ if _type(item) is _dict and 'type' in item:
+ target_list.append((item, scope))
+ elif _type(child) is _dict and 'type' in child:
+ target_list.append((child, scope))
+
+ def _visit_declaration(node: dict, scope: Scope, depth: int) -> None:
+ if not _type(node) is _dict:
return
node_type = node.get('type')
if node_type is None:
return
- match node_type:
- case 'FunctionDeclaration' | 'FunctionExpression' | 'ArrowFunctionExpression':
- new_scope = Scope(scope, node, is_function=True)
- node_scope[id(node)] = new_scope
- all_scopes.append(new_scope)
-
- # Function name goes in outer scope (for declarations) or inner (for expressions)
- if node_type == 'FunctionDeclaration' and node.get('id'):
- scope.add_binding(node['id']['name'], node, 'function')
- elif node_type == 'FunctionExpression' and node.get('id'):
- new_scope.add_binding(node['id']['name'], node, 'function')
-
- # Params go in function scope
- for param in node.get('params', []):
- match param.get('type'):
- case 'Identifier':
- new_scope.add_binding(param['name'], param, 'param')
- case 'AssignmentPattern' if param.get('left', {}).get('type') == 'Identifier':
- new_scope.add_binding(param['left']['name'], param, 'param')
- case 'RestElement':
- arg = param.get('argument')
- if arg and arg.get('type') == 'Identifier':
- new_scope.add_binding(arg['name'], param, 'param')
-
- # Body - use the new scope
- body = node.get('body')
- if not body:
- return
- if isinstance(body, dict) and body.get('type') == 'BlockStatement':
- node_scope[id(body)] = new_scope
- for statement in body.get('body', []):
- _collect_declarations(statement, new_scope)
- else:
- _collect_declarations(body, new_scope)
-
- case 'ClassExpression' | 'ClassDeclaration':
- class_id = node.get('id')
- inner_scope = scope
- if class_id and class_id.get('type') == 'Identifier':
- name = class_id['name']
- if node_type == 'ClassDeclaration':
- scope.add_binding(name, node, 'function')
- else:
- inner_scope = Scope(scope, node)
- node_scope[id(node)] = inner_scope
- all_scopes.append(inner_scope)
- inner_scope.add_binding(name, node, 'function')
- body = node.get('body')
- if body:
- _collect_declarations(body, inner_scope)
- superclass = node.get('superClass')
- if superclass:
- _collect_declarations(superclass, scope)
-
- case 'VariableDeclaration':
- kind = node.get('kind', 'var')
- target_scope = (_nearest_function_scope(scope) or scope) if kind == 'var' else scope
- for declaration in node.get('declarations', []):
- declaration_id = declaration.get('id')
- if declaration_id and declaration_id.get('type') == 'Identifier':
- target_scope.add_binding(declaration_id['name'], declaration, kind)
- _collect_pattern_names(declaration_id, target_scope, kind, declaration)
- init = declaration.get('init')
- if init:
- _collect_declarations(init, scope)
-
- case 'BlockStatement' if id(node) not in node_scope:
- new_scope = Scope(scope, node)
- node_scope[id(node)] = new_scope
- all_scopes.append(new_scope)
- for statement in node.get('body', []):
- _collect_declarations(statement, new_scope)
-
- case 'ForStatement':
- new_scope = Scope(scope, node)
- node_scope[id(node)] = new_scope
- all_scopes.append(new_scope)
- if node.get('init'):
- _collect_declarations(node['init'], new_scope)
- if node.get('body'):
- _collect_declarations(node['body'], new_scope)
-
- case 'CatchClause':
- catch_body = node.get('body')
- if catch_body and catch_body.get('type') == 'BlockStatement':
- catch_scope = Scope(scope, catch_body)
- node_scope[id(catch_body)] = catch_scope
- all_scopes.append(catch_scope)
- param = node.get('param')
- if param and param.get('type') == 'Identifier':
- catch_scope.add_binding(param['name'], param, 'param')
- for statement in catch_body.get('body', []):
- _collect_declarations(statement, catch_scope)
-
- case _:
- _recurse_into_children(
- node, _child_keys_map, lambda child_node: _collect_declarations(child_node, scope)
- )
-
- def _collect_pattern_names(pattern, scope, kind, declaration):
- """Collect binding names from destructuring patterns."""
- if not isinstance(pattern, dict):
+ if depth > _max_depth:
+ # Fall back to iterative for this subtree
+ _collect_declarations_iterative_from(node, scope)
return
- match pattern.get('type', ''):
- case 'ArrayPattern':
- for element in pattern.get('elements', []):
- if not element:
- continue
- if element.get('type') == 'Identifier':
- scope.add_binding(element['name'], declaration, kind)
- else:
- _collect_pattern_names(element, scope, kind, declaration)
- case 'ObjectPattern':
- for property_node in pattern.get('properties', []):
- value_node = property_node.get('value', property_node.get('argument'))
- if not value_node:
- continue
- if value_node.get('type') == 'Identifier':
- scope.add_binding(value_node['name'], declaration, kind)
- else:
- _collect_pattern_names(value_node, scope, kind, declaration)
- case 'RestElement':
- argument_node = pattern.get('argument')
- if argument_node and argument_node.get('type') == 'Identifier':
- scope.add_binding(argument_node['name'], declaration, kind)
- case 'AssignmentPattern':
- left = pattern.get('left')
- if left and left.get('type') == 'Identifier':
- scope.add_binding(left['name'], declaration, kind)
-
- _collect_declarations(ast, root_scope)
-
- # Second pass: collect references and assignments
- def _collect_references(node, scope, parent=None, parent_key=None, parent_index=None):
- if not isinstance(node, dict):
+
+ # Collect children into a local list, then recurse
+ children: list = []
+ _process_declaration_node(node, node_type, scope, node_scope, all_scopes, children, _push_children)
+ next_depth = depth + 1
+ # Children were appended in stack order (reversed), so iterate
+ # in reverse to get left-to-right processing order.
+ for index in range(len(children) - 1, -1, -1):
+ _visit_declaration(children[index][0], children[index][1], next_depth)
+
+ def _collect_declarations_iterative_from(start_node: dict, start_scope: Scope) -> None:
+ """Run iterative declaration collection starting from a specific node/scope."""
+ decl_stack = [(start_node, start_scope)]
+ while decl_stack:
+ node, scope = decl_stack.pop()
+ if not _type(node) is _dict:
+ continue
+ node_type = node.get('type')
+ if node_type is None:
+ continue
+ _process_declaration_node(
+ node, node_type, scope, node_scope, all_scopes, decl_stack, _push_children
+ )
+
+ _visit_declaration(ast, root_scope, 0)
+
+ # ---- Pass 2: Collect references (recursive with iterative fallback) ----
+
+ def _visit_reference(
+ node: dict,
+ scope: Scope,
+ parent: dict | None,
+ parent_key: str | None,
+ parent_index: int | None,
+ depth: int,
+ ) -> None:
+ if not _type(node) is _dict:
return
- node_type = node.get('type')
+ node_type = node.get('type') # not all dicts have 'type'
if node_type is None:
return
- # Look up scope for this node
- scope = _get_scope_for(node, scope)
+ node_id = id(node)
+ if node_id in node_scope:
+ scope = node_scope[node_id]
if node_type == 'Identifier':
name = node.get('name', '')
if _is_non_reference_identifier(parent, parent_key):
return
-
binding = scope.get_binding(name)
if not binding:
return
@@ -288,22 +440,66 @@ def _collect_references(node, scope, parent=None, parent_key=None, parent_index=
binding.assignments.append(parent)
return
- # Recurse — can't use _recurse_into_children here because we need
- # per-child (key, index) args for reference tracking
+ if depth > _max_depth:
+ _collect_references_iterative_from(node, scope)
+ return
+
child_keys = _child_keys_map.get(node_type)
if child_keys is None:
- child_keys = get_child_keys(node)
+ child_keys = _get_child_keys(node)
+ next_depth = depth + 1
for key in child_keys:
child = node.get(key)
if child is None:
continue
- if isinstance(child, list):
- for i, item in enumerate(child):
- if isinstance(item, dict) and 'type' in item:
- _collect_references(item, scope, node, key, i)
- elif isinstance(child, dict) and 'type' in child:
- _collect_references(child, scope, node, key, None)
-
- _collect_references(ast, root_scope)
+ if _type(child) is _list:
+ for child_index, item in enumerate(child):
+ if _type(item) is _dict and 'type' in item:
+ _visit_reference(item, scope, node, key, child_index, next_depth)
+ elif _type(child) is _dict and 'type' in child:
+ _visit_reference(child, scope, node, key, None, next_depth)
+
+ def _collect_references_iterative_from(start_node: dict, start_scope: Scope) -> None:
+ """Run iterative reference collection starting from a specific node/scope."""
+ ref_stack = [(start_node, start_scope, None, None, None)]
+ while ref_stack:
+ node, scope, parent, parent_key, parent_index = ref_stack.pop()
+ if not _type(node) is _dict:
+ continue
+ node_type = node.get('type')
+ if node_type is None:
+ continue
+ node_id = id(node)
+ if node_id in node_scope:
+ scope = node_scope[node_id]
+ if node_type == 'Identifier':
+ name = node.get('name', '')
+ if _is_non_reference_identifier(parent, parent_key):
+ continue
+ binding = scope.get_binding(name)
+ if not binding:
+ continue
+ binding.references.append((node, parent, parent_key, parent_index))
+ if parent and parent.get('type') == 'AssignmentExpression' and parent_key == 'left':
+ binding.assignments.append(parent)
+ elif parent and parent.get('type') == 'UpdateExpression':
+ binding.assignments.append(parent)
+ continue
+ child_keys = _child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(node)
+ for key in reversed(child_keys):
+ child = node.get(key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for i in range(len(child) - 1, -1, -1):
+ item = child[i]
+ if _type(item) is _dict and 'type' in item:
+ ref_stack.append((item, scope, node, key, i))
+ elif _type(child) is _dict and 'type' in child:
+ ref_stack.append((child, scope, node, key, None))
+
+ _visit_reference(ast, root_scope, None, None, None, 0)
return root_scope, node_scope
diff --git a/pyjsclear/transforms/aa_decode.py b/pyjsclear/transforms/aa_decode.py
index 6f8d042..cec51a3 100644
--- a/pyjsclear/transforms/aa_decode.py
+++ b/pyjsclear/transforms/aa_decode.py
@@ -42,7 +42,7 @@
]
-def is_aa_encoded(code):
+def is_aa_encoded(code: str) -> bool:
"""Check if *code* looks like AAEncoded JavaScript.
Returns True when the characteristic execution pattern is found.
@@ -52,7 +52,7 @@ def is_aa_encoded(code):
return _SIGNATURE in code
-def aa_decode(code):
+def aa_decode(code: str) -> str | None:
"""Decode AAEncoded JavaScript.
Returns the decoded source string, or ``None`` on any failure.
@@ -72,7 +72,7 @@ def aa_decode(code):
# ---------------------------------------------------------------------------
-def _decode_impl(code):
+def _decode_impl(code: str) -> str | None:
"""Core decoding logic."""
# 1. Isolate the data section.
# AAEncode wraps data inside an execution pattern. The encoded payload
@@ -84,8 +84,8 @@ def _decode_impl(code):
# Find the data region: everything after the initial variable setup and
# before the trailing execution portion.
# The data starts at the first separator token.
- sep_idx = code.find(_SEPARATOR)
- if sep_idx == -1:
+ separator_index = code.find(_SEPARATOR)
+ if separator_index == -1:
return None
# The trailing execution wrapper varies but typically looks like:
@@ -95,16 +95,16 @@ def _decode_impl(code):
"(\uff9f\u0414\uff9f)['_']",
'(\uff9f\u0414\uff9f)["_"]',
]
- data = code[sep_idx:]
- for pat in tail_patterns:
- tail_pos = data.rfind(pat)
- if tail_pos != -1:
- data = data[:tail_pos]
+ data = code[separator_index:]
+ for tail_pattern in tail_patterns:
+ tail_position = data.rfind(tail_pattern)
+ if tail_position != -1:
+ data = data[:tail_position]
break
# 2. Apply emoticon-to-digit replacements.
- for old, new in _REPLACEMENTS:
- data = data.replace(old, new)
+ for original, replacement in _REPLACEMENTS:
+ data = data.replace(original, replacement)
# 3. Split on the separator to get individual character segments.
segments = data.split(_SEPARATOR)
diff --git a/pyjsclear/transforms/anti_tamper.py b/pyjsclear/transforms/anti_tamper.py
index 7781461..13d027d 100644
--- a/pyjsclear/transforms/anti_tamper.py
+++ b/pyjsclear/transforms/anti_tamper.py
@@ -38,7 +38,7 @@ class AntiTamperRemover(Transform):
]
@staticmethod
- def _extract_iife_call(expr):
+ def _extract_iife_call(expr: dict) -> dict | None:
"""Extract a CallExpression from an IIFE pattern."""
if expr.get('type') == 'CallExpression':
return expr
@@ -46,23 +46,20 @@ def _extract_iife_call(expr):
return expr.get('argument')
return None
- def _matches_anti_tamper_pattern(self, src):
+ def _matches_anti_tamper_pattern(self, source: str) -> bool:
"""Check if source matches any anti-tamper pattern."""
- for pattern in self._SELF_DEFENDING_PATTERNS:
- if pattern.search(src):
- return True
- if any(p.search(src) for p in self._DEBUG_PATTERNS):
- if re.search(r'\bdebugger\b', src) and (re.search(r'\bwhile\b|\bfor\b|\bsetInterval\b', src)):
- return True
- for pattern in self._CONSOLE_PATTERNS:
- if pattern.search(src):
- return True
+ if any(pattern.search(source) for pattern in self._SELF_DEFENDING_PATTERNS):
+ return True
+ # NOTE(review): the removed version only ran the check below when some
+ # self._DEBUG_PATTERNS entry matched; that precondition is gone here.
+ # Confirm the `debugger` test subsumes those patterns.
+ if re.search(r'\bdebugger\b', source) and re.search(r'\bwhile\b|\bfor\b|\bsetInterval\b', source):
+ return True
+ if any(pattern.search(source) for pattern in self._CONSOLE_PATTERNS):
+ return True
return False
- def execute(self):
+ def execute(self) -> bool:
nodes_to_remove = []
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict, key: str, index: int | None) -> None:
if node.get('type') != 'ExpressionStatement':
return
expr = node.get('expression')
@@ -94,9 +91,9 @@ def enter(node, parent, key, index):
# Remove flagged nodes
if nodes_to_remove:
- remove_set = set(id(n) for n in nodes_to_remove)
+ remove_set = {id(node) for node in nodes_to_remove}
- def remover(node, parent, key, index):
+ def remover(node: dict, parent: dict, key: str, index: int | None) -> object | None:
if id(node) in remove_set:
self.set_changed()
return REMOVE
diff --git a/pyjsclear/transforms/base.py b/pyjsclear/transforms/base.py
index 35219dc..8a9d00d 100644
--- a/pyjsclear/transforms/base.py
+++ b/pyjsclear/transforms/base.py
@@ -1,5 +1,14 @@
"""Base transform class."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ..traverser import build_parent_map
+
+if TYPE_CHECKING:
+ from ..scope import Scope
+
class Transform:
"""Base class for all AST transforms."""
@@ -7,18 +16,43 @@ class Transform:
# Subclasses can set this to True to trigger scope rebuild after execution
rebuild_scope = False
- def __init__(self, ast, scope_tree=None, node_scope=None):
+ def __init__(
+ self,
+ ast: dict,
+ scope_tree: Scope | None = None,
+ node_scope: dict[int, Scope] | None = None,
+ ) -> None:
self.ast = ast
self.scope_tree = scope_tree
self.node_scope = node_scope
self._changed = False
+ self._parent_map = None
- def execute(self):
+ def execute(self) -> bool:
"""Execute the transform. Returns True if the AST was modified."""
raise NotImplementedError
- def set_changed(self):
+ def set_changed(self) -> None:
self._changed = True
- def has_changed(self):
+ def has_changed(self) -> bool:
return self._changed
+
+    def get_parent_map(self) -> dict[int, tuple]:
+        """Lazily build and return a parent map for the AST.
+
+        The map is produced by ``build_parent_map(self.ast)`` on first use and
+        cached on the instance; later calls return the same object until
+        ``invalidate_parent_map()`` clears the cache. Call
+        ``invalidate_parent_map()`` after AST modifications.
+
+        Returns:
+            Dict mapping ``id(node)`` -> ``(parent, key, index)``.
+        """
+        if self._parent_map is None:
+            self._parent_map = build_parent_map(self.ast)
+        return self._parent_map
+
+    def invalidate_parent_map(self) -> None:
+        """Drop the cached parent map so the next ``get_parent_map()`` rebuilds it.
+
+        Intended to be called after AST modifications.
+        """
+        self._parent_map = None
+
+    def find_parent(self, target_node: dict) -> tuple | None:
+        """Look up the parent of *target_node* in the cached parent map.
+
+        Returns:
+            The ``(parent, key, index)`` entry stored under ``id(target_node)``,
+            or ``None`` when the map has no entry for that node.
+        """
+        return self.get_parent_map().get(id(target_node))
diff --git a/pyjsclear/transforms/class_static_resolver.py b/pyjsclear/transforms/class_static_resolver.py
index 78bca39..4080a1e 100644
--- a/pyjsclear/transforms/class_static_resolver.py
+++ b/pyjsclear/transforms/class_static_resolver.py
@@ -11,7 +11,6 @@
... C.id(expr) ... → ... expr ...
"""
-from ..traverser import find_parent
from ..traverser import simple_traverse
from ..traverser import traverse
from ..utils.ast_helpers import deep_copy
@@ -24,7 +23,7 @@
class ClassStaticResolver(Transform):
"""Inline class static constant properties and identity methods."""
- def execute(self):
+ def execute(self) -> bool:
# Step 1: Find class variables (var X = class { ... })
class_vars = {} # name -> ClassExpression node
@@ -113,6 +112,9 @@ def collect_static_props(node, parent):
return False
# Step 4: Replace accesses
+ # Build parent map once before traversal
+ self.get_parent_map()
+
def enter(node, parent, key, index):
if node.get('type') != 'MemberExpression':
return
@@ -141,12 +143,14 @@ def enter(node, parent, key, index):
# Try identity method inlining
if pair in static_methods:
- self._try_inline_identity(node, static_methods[pair])
+ self._try_inline_identity(node, static_methods[pair], parent, key, index)
traverse(self.ast, {'enter': enter})
+ # Invalidate parent map once after all replacements
+ self.invalidate_parent_map()
return self.has_changed()
- def _get_prop_name(self, member_expr):
+ def _get_prop_name(self, member_expr: dict) -> str | None:
"""Get the property name from a MemberExpression."""
prop = member_expr.get('property')
if not prop:
@@ -159,7 +163,7 @@ def _get_prop_name(self, member_expr):
return prop['name']
return None
- def _is_identity_function(self, func_node):
+ def _is_identity_function(self, func_node: dict) -> bool:
"""Check if a function simply returns its first argument."""
params = func_node.get('params', [])
if len(params) != 1:
@@ -173,32 +177,34 @@ def _is_identity_function(self, func_node):
stmts = body.get('body', [])
if len(stmts) != 1 or stmts[0].get('type') != 'ReturnStatement':
return False
- arg = stmts[0].get('argument')
- if not arg or not is_identifier(arg):
+ return_argument = stmts[0].get('argument')
+ if not return_argument or not is_identifier(return_argument):
return False
- return arg['name'] == param['name']
+ return return_argument['name'] == param['name']
- def _try_inline_identity(self, member_expr, method_node):
- """Inline Class.identity(arg) → arg."""
- result = find_parent(self.ast, member_expr)
- if not result:
- return
- call_parent, call_key, call_index = result
- if not call_parent or call_parent.get('type') != 'CallExpression' or call_key != 'callee':
+ def _try_inline_identity(self, member_expr: dict, method_node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
+ """Inline Class.identity(arg) → arg.
+
+ parent/key/index refer to the MemberExpression's parent (from the enter callback).
+ The MemberExpression should be the callee of a CallExpression.
+ Nothing is changed unless that call has exactly one argument and the call
+ itself has an entry in the parent map. method_node and index are accepted
+ but not read by this method.
+ """
+ # Only proceed when parent is a CallExpression that has this MemberExpression as callee
+ if not parent or parent.get('type') != 'CallExpression' or key != 'callee':
return
- args = call_parent.get('arguments', [])
+ args = parent.get('arguments', [])
if len(args) != 1:
return
replacement = deep_copy(args[0])
- # Replace the CallExpression with the argument
- grandparent_result = find_parent(self.ast, call_parent)
+ # Find grandparent of the CallExpression using cached parent map
+ # NOTE(review): the map is cached, so it describes the tree as it was when
+ # the map was built; if an earlier replacement in the same traversal detached
+ # this call, the lookup may still return its old grandparent - confirm.
+ pm = self.get_parent_map()
+ grandparent_result = pm.get(id(parent))
if not grandparent_result:
return
- gp, gp_key, gp_index = grandparent_result
- self._replace_in_parent(call_parent, replacement, gp, gp_key, gp_index)
+ grandparent, grandparent_key, grandparent_index = grandparent_result
+ self._replace_in_parent(parent, replacement, grandparent, grandparent_key, grandparent_index)
self.set_changed()
- def _replace_in_parent(self, target, replacement, parent, key, index):
+ def _replace_in_parent(self, target: dict, replacement: dict, parent: dict, key: str, index: int | None) -> None:
"""Replace target node in the AST using known parent info."""
if index is not None:
parent[key][index] = replacement
diff --git a/pyjsclear/transforms/class_string_decoder.py b/pyjsclear/transforms/class_string_decoder.py
index 42e9961..5061b34 100644
--- a/pyjsclear/transforms/class_string_decoder.py
+++ b/pyjsclear/transforms/class_string_decoder.py
@@ -33,9 +33,9 @@
class ClassStringDecoder(Transform):
"""Resolve class-based string encoder patterns."""
- def execute(self):
- class_props = {}
- decoders = {}
+ def execute(self) -> bool:
+ class_props: dict = {}
+ decoders: dict = {}
self._collect_class_props(class_props)
self._find_decoders(class_props, decoders)
@@ -48,7 +48,7 @@ def execute(self):
self._resolve_calls(decoders)
return self.has_changed()
- def _collect_class_props(self, class_props):
+ def _collect_class_props(self, class_props: dict) -> None:
"""Collect static property assignments on class variables.
Builds: class_props[var_name] = {prop_name: value, ...}
@@ -56,7 +56,7 @@ def _collect_class_props(self, class_props):
Handles assignments in ExpressionStatements and SequenceExpressions.
"""
- def visit(node, parent):
+ def visit(node: dict, parent: dict) -> None:
if node.get('type') != 'AssignmentExpression':
return
if node.get('operator') != '=':
@@ -78,24 +78,24 @@ def visit(node, parent):
simple_traverse(self.ast, visit)
- def _resolve_array(self, class_props, var_name, elements):
+ def _resolve_array(self, class_props: dict, var_name: str, elements: list) -> list | None:
"""Resolve an array of MemberExpression references to string values."""
props = class_props.get(var_name, {})
resolved = []
- for el in elements:
- el_obj, el_prop = get_member_names(el)
- if not el_obj or el_obj != var_name:
+ for element in elements:
+ element_object, element_property = get_member_names(element)
+ if not element_object or element_object != var_name:
return None
- value = props.get(el_prop)
+ value = props.get(element_property)
if not isinstance(value, str):
return None
resolved.append(value)
return resolved
- def _find_decoders(self, class_props, decoders):
+ def _find_decoders(self, class_props: dict, decoders: dict) -> None:
"""Find decoder methods and their associated lookup tables."""
- def visit(node, parent):
+ def visit(node: dict, parent: dict) -> None:
if node.get('type') != 'MethodDefinition':
return
if not node.get('static'):
@@ -111,21 +111,21 @@ def visit(node, parent):
case _:
return
- func = node.get('value')
- if not func or func.get('type') != 'FunctionExpression':
+ function_node = node.get('value')
+ if not function_node or function_node.get('type') != 'FunctionExpression':
return
- params = func.get('params', [])
+ params = function_node.get('params', [])
if len(params) != 1:
return
- body = func.get('body')
+ body = function_node.get('body')
if not body or body.get('type') != 'BlockStatement':
return
- stmts = body.get('body', [])
- if len(stmts) < 3:
+ statements = body.get('body', [])
+ if len(statements) < 3:
return
- table_info = self._extract_decoder_table(stmts, class_props)
+ table_info = self._extract_decoder_table(statements, class_props)
if not table_info:
return
@@ -139,12 +139,12 @@ def visit(node, parent):
simple_traverse(self.ast, visit)
- def _resolve_aliases(self, decoders):
+ def _resolve_aliases(self, decoders: dict) -> None:
"""Find identifier aliases (X = Y) where Y is a decoder class, and register X too."""
decoder_classes = {cls for cls, _ in decoders}
new_entries = {}
- def visit(node, parent):
+ def visit(node: dict, parent: dict) -> None:
if node.get('type') != 'AssignmentExpression':
return
if node.get('operator') != '=':
@@ -164,22 +164,23 @@ def visit(node, parent):
simple_traverse(self.ast, visit)
decoders.update(new_entries)
- def _extract_decoder_table(self, stmts, class_props):
+ def _extract_decoder_table(self, statements: list, class_props: dict) -> tuple | None:
"""Extract the lookup table and offset from decoder method body."""
table_class_var = None
table_prop = None
- for stmt in stmts:
- if stmt.get('type') != 'VariableDeclaration':
+ for statement in statements:
+ if statement.get('type') != 'VariableDeclaration':
continue
- for decl in stmt.get('declarations', []):
- init = decl.get('init')
+ for declaration in statement.get('declarations', []):
+ init = declaration.get('init')
obj_name, prop_name = get_member_names(init)
- if obj_name and prop_name:
- decl_id = decl.get('id')
- if decl_id and decl_id.get('type') == 'Identifier':
- table_class_var = obj_name
- table_prop = prop_name
+ if not obj_name or not prop_name:
+ continue
+ declaration_id = declaration.get('id')
+ if declaration_id and declaration_id.get('type') == 'Identifier':
+ table_class_var = obj_name
+ table_prop = prop_name
if not table_class_var:
return None
@@ -193,14 +194,14 @@ def _extract_decoder_table(self, stmts, class_props):
if not resolved:
return None
- offset = self._find_offset(stmts)
+ offset = self._find_offset(statements)
return resolved, offset
- def _find_offset(self, stmts):
+ def _find_offset(self, statements: list) -> int:
"""Find the subtraction offset in the decoder loop (e.g., - 48)."""
offset = 48
- def scan(node, parent):
+ def scan(node: dict, parent: dict) -> None:
nonlocal offset
if not isinstance(node, dict):
return
@@ -211,17 +212,17 @@ def scan(node, parent):
if isinstance(val, (int, float)) and val > 0:
offset = int(val)
- for stmt in stmts:
- if stmt.get('type') == 'ForStatement':
- simple_traverse(stmt, scan)
+ for statement in statements:
+ if statement.get('type') == 'ForStatement':
+ simple_traverse(statement, scan)
return offset
- def _find_enclosing_class_var(self, method_node):
+ def _find_enclosing_class_var(self, method_node: dict) -> str | None:
"""Find the variable name of the class containing this method."""
result = [None]
- def _check_class_body(class_expr, var_name):
+ def _check_class_body(class_expr: dict, var_name: str) -> bool:
body = class_expr.get('body')
if body and body.get('type') == 'ClassBody':
for member in body.get('body', []):
@@ -230,15 +231,15 @@ def _check_class_body(class_expr, var_name):
return True
return False
- def scan(node, parent):
+ def scan(node: dict, parent: dict) -> None:
if result[0]:
return
if node.get('type') == 'VariableDeclarator':
init = node.get('init')
if init and init.get('type') == 'ClassExpression':
- decl_id = node.get('id')
- if decl_id and decl_id.get('type') == 'Identifier':
- _check_class_body(init, decl_id['name'])
+ declaration_id = node.get('id')
+ if declaration_id and declaration_id.get('type') == 'Identifier':
+ _check_class_body(init, declaration_id['name'])
elif node.get('type') == 'AssignmentExpression':
right = node.get('right')
if right and right.get('type') == 'ClassExpression':
@@ -249,7 +250,7 @@ def scan(node, parent):
simple_traverse(self.ast, scan)
return result[0]
- def _decode_call(self, lookup_table, offset, args):
+ def _decode_call(self, lookup_table: list, offset: int, args: list) -> str | None:
"""Statically evaluate a decoder call: decode([0x4f, 0x3a, ...])."""
if len(args) != 1:
return None
@@ -258,10 +259,10 @@ def _decode_call(self, lookup_table, offset, args):
return None
elements = arg.get('elements', [])
result = ''
- for el in elements:
- if not is_numeric_literal(el):
+ for element in elements:
+ if not is_numeric_literal(element):
return None
- idx = int(el['value']) - offset
+ idx = int(element['value']) - offset
if idx < 0 or idx >= len(lookup_table):
return None
entry = lookup_table[idx]
@@ -270,34 +271,34 @@ def _decode_call(self, lookup_table, offset, args):
result += entry[0]
return result
- def _resolve_calls(self, decoders):
+ def _resolve_calls(self, decoders: dict) -> None:
"""Replace all decoder calls with their decoded string literals."""
- decoded_constants = {}
+ decoded_constants: dict = {}
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict, key: str, index: int | None) -> dict | None:
if node.get('type') != 'CallExpression':
- return
+ return None
callee = node.get('callee')
obj_name, method_name = get_member_names(callee)
if not obj_name:
- return
+ return None
decoder_key = (obj_name, method_name)
if decoder_key not in decoders:
- return
+ return None
lookup_table, offset = decoders[decoder_key]
decoded = self._decode_call(lookup_table, offset, node.get('arguments', []))
if decoded is None:
- return
+ return None
replacement = make_literal(decoded)
# Track the assignment target so we can inline the constant later
if parent and parent.get('type') == 'AssignmentExpression' and key == 'right':
- lobj, lprop = get_member_names(parent.get('left'))
- if lobj and lprop:
- decoded_constants[(lobj, lprop)] = decoded
+ left_object, left_property = get_member_names(parent.get('left'))
+ if left_object and left_property:
+ decoded_constants[(left_object, left_property)] = decoded
self.set_changed()
return replacement
@@ -307,23 +308,23 @@ def enter(node, parent, key, index):
if decoded_constants:
self._inline_decoded_constants(decoded_constants)
- def _inline_decoded_constants(self, decoded_constants):
+ def _inline_decoded_constants(self, decoded_constants: dict) -> None:
"""Replace references like _0x279589["propName"] with the decoded string."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict, key: str, index: int | None) -> dict | None:
if node.get('type') != 'MemberExpression':
- return
+ return None
# Skip assignment targets
if parent and parent.get('type') == 'AssignmentExpression' and key == 'left':
- return
+ return None
obj_name, prop_name = get_member_names(node)
if not obj_name:
- return
+ return None
lookup_key = (obj_name, prop_name)
if lookup_key not in decoded_constants:
- return
+ return None
decoded = decoded_constants[lookup_key]
self.set_changed()
diff --git a/pyjsclear/transforms/cleanup.py b/pyjsclear/transforms/cleanup.py
index e79dd17..3219dcd 100644
--- a/pyjsclear/transforms/cleanup.py
+++ b/pyjsclear/transforms/cleanup.py
@@ -22,9 +22,9 @@ class EmptyIfRemover(Transform):
- ``if (expr) {} else { body }`` → ``if (!expr) { body }``
"""
- def execute(self):
+ def execute(self) -> bool:
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') != 'IfStatement':
return
consequent = node.get('consequent')
@@ -52,14 +52,14 @@ def enter(node, parent, key, index):
return self.has_changed()
@staticmethod
- def _is_empty_block(node):
+ def _is_empty_block(node: object) -> bool:
"""Check if a node is an empty block statement ``{}``."""
if not isinstance(node, dict):
return False
if node.get('type') != 'BlockStatement':
return False
body = node.get('body')
- return not body or len(body) == 0
+ return not body
class TrailingReturnRemover(Transform):
@@ -71,20 +71,24 @@ class TrailingReturnRemover(Transform):
_FUNC_TYPES = frozenset({'FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'})
- def execute(self):
+ def execute(self) -> bool:
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') not in self._FUNC_TYPES:
return
body = node.get('body')
if not isinstance(body, dict) or body.get('type') != 'BlockStatement':
return
- stmts = body.get('body')
- if not stmts or not isinstance(stmts, list):
+ statements = body.get('body')
+ if not statements or not isinstance(statements, list):
return
- last = stmts[-1]
- if isinstance(last, dict) and last.get('type') == 'ReturnStatement' and last.get('argument') is None:
- stmts.pop()
+ last_statement = statements[-1]
+ if (
+ isinstance(last_statement, dict)
+ and last_statement.get('type') == 'ReturnStatement'
+ and last_statement.get('argument') is None
+ ):
+ statements.pop()
self.set_changed()
traverse(self.ast, {'enter': enter})
@@ -94,9 +98,9 @@ def enter(node, parent, key, index):
class OptionalCatchBinding(Transform):
"""Remove unused catch clause parameters (ES2019 optional catch binding)."""
- def execute(self):
+ def execute(self) -> bool:
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'CatchClause':
return
param = node.get('param')
@@ -114,32 +118,33 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
return self.has_changed()
- def _is_name_used(self, body, name):
+ def _is_name_used(self, body: dict, name: str) -> bool:
"""Check if an identifier name is used anywhere in the subtree."""
- found = [False]
+ found = False
- def cb(node, parent):
- if found[0]:
+ def callback(node: dict, parent: dict | None) -> None:
+ nonlocal found
+ if found:
return
if is_identifier(node) and node.get('name') == name:
- found[0] = True
+ found = True
- simple_traverse(body, cb)
- return found[0]
+ simple_traverse(body, callback)
+ return found
class ReturnUndefinedCleanup(Transform):
"""Simplify `return undefined;` to `return;`."""
- def execute(self):
+ def execute(self) -> bool:
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'ReturnStatement':
return
- arg = node.get('argument')
- if not arg:
+ argument = node.get('argument')
+ if not argument:
return
- if is_identifier(arg) and arg.get('name') == 'undefined':
+ if is_identifier(argument) and argument.get('name') == 'undefined':
node['argument'] = None
self.set_changed()
@@ -156,28 +161,31 @@ class LetToConst(Transform):
- The binding has no assignments after declaration
"""
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
- safe_declarators = set()
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
+ safe_declarators: set[int] = set()
self._collect_let_const_candidates(scope_tree, safe_declarators)
if not safe_declarators:
return False
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'VariableDeclaration':
return
if node.get('kind') != 'let':
return
- decls = node.get('declarations', [])
- if len(decls) == 1 and id(decls[0]) in safe_declarators:
+ declarations = node.get('declarations', [])
+ if len(declarations) == 1 and id(declarations[0]) in safe_declarators:
node['kind'] = 'const'
self.set_changed()
traverse(self.ast, {'enter': enter})
return self.has_changed()
- def _collect_let_const_candidates(self, scope, safe_declarators):
+ def _collect_let_const_candidates(self, scope: object, safe_declarators: set[int]) -> None:
"""Find let bindings that are never reassigned and have initializers."""
for name, binding in scope.bindings.items():
if binding.kind != 'let':
@@ -206,19 +214,22 @@ class VarToConst(Transform):
but const is block-scoped
"""
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
- safe_declarators = set()
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
+ safe_declarators: set[int] = set()
self._collect_const_candidates(scope_tree, safe_declarators, in_function=True)
if not safe_declarators:
return False
# Track which BlockStatements are direct function bodies
- func_body_ids = set()
+ func_body_ids: set[int] = set()
self._collect_func_bodies(self.ast, func_body_ids)
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'VariableDeclaration':
return
if node.get('kind') != 'var':
@@ -234,26 +245,26 @@ def enter(node, parent, key, index):
return # Inside a nested block — unsafe
else:
return
- decls = node.get('declarations', [])
- if len(decls) == 1 and id(decls[0]) in safe_declarators:
+ declarations = node.get('declarations', [])
+ if len(declarations) == 1 and id(declarations[0]) in safe_declarators:
node['kind'] = 'const'
self.set_changed()
traverse(self.ast, {'enter': enter})
return self.has_changed()
- def _collect_func_bodies(self, ast, func_body_ids):
+ def _collect_func_bodies(self, ast: dict, func_body_ids: set[int]) -> None:
"""Collect ids of BlockStatements that are direct function bodies."""
- def cb(node, parent):
+ def callback(node: dict, parent: dict | None) -> None:
if node.get('type') in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'):
body = node.get('body')
if body and body.get('type') == 'BlockStatement':
func_body_ids.add(id(body))
- simple_traverse(ast, cb)
+ simple_traverse(ast, callback)
- def _collect_const_candidates(self, scope, safe_declarators, in_function=False):
+ def _collect_const_candidates(self, scope: object, safe_declarators: set[int], in_function: bool = False) -> None:
"""Find var bindings that are never reassigned and have initializers."""
if in_function:
for name, binding in scope.bindings.items():
diff --git a/pyjsclear/transforms/constant_prop.py b/pyjsclear/transforms/constant_prop.py
index d218ecf..a6fcab9 100644
--- a/pyjsclear/transforms/constant_prop.py
+++ b/pyjsclear/transforms/constant_prop.py
@@ -1,5 +1,7 @@
"""Constant propagation — replace references to constant variables with their literal values."""
+from ..scope import Binding
+from ..scope import Scope
from ..scope import build_scope_tree
from ..traverser import REMOVE
from ..traverser import SKIP
@@ -9,16 +11,16 @@
from .base import Transform
-def _should_skip_reference(ref_parent, ref_key):
+def _should_skip_reference(reference_parent: dict | None, reference_key: str | None) -> bool:
"""Return True if this reference should not be replaced with its literal value."""
- if not ref_parent:
+ if not reference_parent:
return True
- match ref_parent.get('type'):
- case 'AssignmentExpression' if ref_key == 'left':
+ match reference_parent.get('type'):
+ case 'AssignmentExpression' if reference_key == 'left':
return True
case 'UpdateExpression':
return True
- case 'VariableDeclarator' if ref_key == 'id':
+ case 'VariableDeclarator' if reference_key == 'id':
return True
return False
@@ -28,8 +30,11 @@ class ConstantProp(Transform):
rebuild_scope = True
- def execute(self):
- scope_tree, node_scope = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree, node_scope = self.scope_tree, self.node_scope
+ else:
+ scope_tree, node_scope = build_scope_tree(self.ast)
replacements = dict(self._iter_constant_bindings(scope_tree))
if not replacements:
@@ -39,7 +44,9 @@ def execute(self):
self._remove_fully_propagated(replacements, bindings_replaced)
return self.has_changed()
- def _iter_constant_bindings(self, scope):
+ def _iter_constant_bindings(
+ self, scope: Scope
+ ) -> list[tuple[int, tuple[Binding, dict]]]:
"""Yield (binding_id, (binding, literal)) for constant bindings with literal values."""
for name, binding in scope.bindings.items():
if not binding.is_constant:
@@ -47,56 +54,58 @@ def _iter_constant_bindings(self, scope):
node = binding.node
if not isinstance(node, dict) or node.get('type') != 'VariableDeclarator':
continue
- init_val = node.get('init')
- if not init_val or not is_literal(init_val):
+ init_value = node.get('init')
+ if not init_value or not is_literal(init_value):
continue
- yield id(binding), (binding, init_val)
+ yield id(binding), (binding, init_value)
for child in scope.children:
yield from self._iter_constant_bindings(child)
- def _replace_references(self, replacements):
+ def _replace_references(self, replacements: dict[int, tuple[Binding, dict]]) -> set[int]:
"""Replace all qualifying references with their literal values."""
bindings_replaced = set()
- for bind_id, (binding, literal) in replacements.items():
- for ref_node, ref_parent, ref_key, ref_index in binding.references:
- if _should_skip_reference(ref_parent, ref_key):
+ for binding_id, (binding, literal) in replacements.items():
+ for reference_node, reference_parent, reference_key, reference_index in binding.references:
+ if _should_skip_reference(reference_parent, reference_key):
continue
new_node = deep_copy(literal)
- if ref_index is not None:
- ref_parent[ref_key][ref_index] = new_node
+ if reference_index is not None:
+ reference_parent[reference_key][reference_index] = new_node
else:
- ref_parent[ref_key] = new_node
+ reference_parent[reference_key] = new_node
self.set_changed()
- bindings_replaced.add(bind_id)
+ bindings_replaced.add(binding_id)
return bindings_replaced
- def _remove_fully_propagated(self, replacements, bindings_replaced):
+ def _remove_fully_propagated(
+ self, replacements: dict[int, tuple[Binding, dict]], bindings_replaced: set[int]
+ ) -> None:
"""Remove declarations whose bindings were fully propagated."""
- for bind_id in bindings_replaced:
- binding = replacements[bind_id][0]
+ for binding_id in bindings_replaced:
+ binding = replacements[binding_id][0]
if binding.assignments:
continue
- decl_node = binding.node
- if not isinstance(decl_node, dict):
+ declarator_node = binding.node
+ if not isinstance(declarator_node, dict):
continue
- if decl_node.get('type') != 'VariableDeclarator':
+ if declarator_node.get('type') != 'VariableDeclarator':
continue
- self._remove_declarator(decl_node)
+ self._remove_declarator(declarator_node)
- def _remove_declarator(self, declarator_node):
+ def _remove_declarator(self, declarator_node: dict) -> None:
"""Remove a VariableDeclarator from its parent VariableDeclaration."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None):
if node.get('type') != 'VariableDeclaration':
return
- decls = node.get('declarations', [])
- for i, declaration in enumerate(decls):
+ declarations = node.get('declarations', [])
+ for i, declaration in enumerate(declarations):
if declaration is not declarator_node:
continue
- decls.pop(i)
+ declarations.pop(i)
self.set_changed()
- if not decls:
+ if not declarations:
return REMOVE
return SKIP
diff --git a/pyjsclear/transforms/control_flow.py b/pyjsclear/transforms/control_flow.py
index 06646a2..aecfaa6 100644
--- a/pyjsclear/transforms/control_flow.py
+++ b/pyjsclear/transforms/control_flow.py
@@ -7,10 +7,7 @@
And reconstructs the linear statement sequence.
"""
-from ..utils.ast_helpers import get_child_keys
-from ..utils.ast_helpers import is_identifier
-from ..utils.ast_helpers import is_literal
-from ..utils.ast_helpers import is_string_literal
+from ..utils.ast_helpers import get_child_keys, is_identifier, is_literal, is_string_literal
from .base import Transform
@@ -19,11 +16,11 @@ class ControlFlowRecoverer(Transform):
rebuild_scope = True
- def execute(self):
+ def execute(self) -> bool:
self._recover_in_bodies(self.ast)
return self.has_changed()
- def _recover_in_bodies(self, root):
+ def _recover_in_bodies(self, root: dict) -> None:
"""Walk through the AST looking for bodies containing CFF patterns."""
stack = [root]
visited = set()
@@ -47,7 +44,7 @@ def _recover_in_bodies(self, root):
self._queue_children(node, stack)
@staticmethod
- def _queue_children(node, stack):
+ def _queue_children(node: dict, stack: list) -> None:
"""Add all child nodes to the traversal stack."""
for key in get_child_keys(node):
child = node.get(key)
@@ -60,140 +57,140 @@ def _queue_children(node, stack):
elif isinstance(child, dict) and 'type' in child:
stack.append(child)
- def _try_recover_body(self, parent_node, body_key, body):
+ def _try_recover_body(self, parent_node: dict, body_key: str, body: list) -> None:
"""Try to find and recover CFF patterns in a body array."""
- i = 0
- while i < len(body):
- statement = body[i]
+ index = 0
+ while index < len(body):
+ statement = body[index]
if not isinstance(statement, dict):
- i += 1
+ index += 1
continue
- if self._try_recover_variable_pattern(body, i, statement):
+ if self._try_recover_variable_pattern(body, index, statement):
continue
- if self._try_recover_expression_pattern(body, i, statement):
+ if self._try_recover_expression_pattern(body, index, statement):
continue
- i += 1
+ index += 1
- def _try_recover_variable_pattern(self, body, i, stmt):
+ def _try_recover_variable_pattern(self, body: list, index: int, statement: dict) -> bool:
"""Try Pattern 1: VariableDeclaration with split + loop. Returns True if recovered."""
- if stmt.get('type') != 'VariableDeclaration':
+ if statement.get('type') != 'VariableDeclaration':
return False
- state_info = self._find_state_array_in_decl(stmt)
+ state_info = self._find_state_array_in_decl(statement)
if not state_info:
return False
states, state_var, counter_var = state_info
- next_idx = i + 1
- if next_idx >= len(body):
+ next_index = index + 1
+ if next_index >= len(body):
return False
- recovered = self._try_recover_from_loop(body[next_idx], states, state_var, counter_var)
+ recovered = self._try_recover_from_loop(body[next_index], states, state_var, counter_var)
if recovered is None:
return False
- body[i : next_idx + 1] = recovered
+ body[index : next_index + 1] = recovered
self.set_changed()
return True
- def _try_recover_expression_pattern(self, body, i, stmt):
+ def _try_recover_expression_pattern(self, body: list, index: int, statement: dict) -> bool:
"""Try Pattern 2: ExpressionStatement with split assignment + loop."""
- if stmt.get('type') != 'ExpressionStatement':
+ if statement.get('type') != 'ExpressionStatement':
return False
- expr = stmt.get('expression')
- if not expr or expr.get('type') != 'AssignmentExpression':
+ expression = statement.get('expression')
+ if not expression or expression.get('type') != 'AssignmentExpression':
return False
- state_info = self._find_state_from_assignment(expr)
+ state_info = self._find_state_from_assignment(expression)
if not state_info:
return False
states, state_var = state_info
- next_idx = i + 1
+ next_index = index + 1
counter_var = None
- if next_idx < len(body):
- counter_variable = self._find_counter_init(body[next_idx])
+ if next_index < len(body):
+ counter_variable = self._find_counter_init(body[next_index])
if counter_variable is not None:
counter_var = counter_variable
- next_idx += 1
- if next_idx >= len(body):
+ next_index += 1
+ if next_index >= len(body):
return False
- recovered = self._try_recover_from_loop(body[next_idx], states, state_var, counter_var or '_index')
+ recovered = self._try_recover_from_loop(body[next_index], states, state_var, counter_var or '_index')
if recovered is None:
return False
- body[i : next_idx + 1] = recovered
+ body[index : next_index + 1] = recovered
self.set_changed()
return True
- def _find_state_array_in_decl(self, decl):
+ def _find_state_array_in_decl(self, declaration: dict) -> tuple | None:
"""Find "X".split("|") pattern in a VariableDeclaration."""
- for declaration in decl.get('declarations', []):
- initializer = declaration.get('init')
+ for declarator in declaration.get('declarations', []):
+ initializer = declarator.get('init')
if not initializer or not self._is_split_call(initializer):
continue
states = self._extract_split_states(initializer)
if not states:
continue
- if declaration.get('id', {}).get('type') != 'Identifier':
+ if declarator.get('id', {}).get('type') != 'Identifier':
continue
- state_var = declaration['id']['name']
- counter_var = self._find_counter_in_declaration(decl, exclude=declaration)
+ state_var = declarator['id']['name']
+ counter_var = self._find_counter_in_declaration(declaration, exclude=declarator)
return states, state_var, counter_var
return None
- def _find_counter_in_declaration(self, decl, exclude):
+ def _find_counter_in_declaration(self, declaration: dict, exclude: dict) -> str | None:
"""Find a numeric-initialized counter variable in a declaration, skipping *exclude*."""
- for declaration in decl.get('declarations', []):
- if declaration is exclude:
+ for declarator in declaration.get('declarations', []):
+ if declarator is exclude:
continue
- if declaration.get('id', {}).get('type') != 'Identifier':
+ if declarator.get('id', {}).get('type') != 'Identifier':
continue
- initializer = declaration.get('init')
+ initializer = declarator.get('init')
if (
initializer
and initializer.get('type') == 'Literal'
and isinstance(initializer.get('value'), (int, float))
):
- return declaration['id']['name']
+ return declarator['id']['name']
return None
- def _find_state_from_assignment(self, expr):
+ def _find_state_from_assignment(self, expression: dict) -> tuple | None:
"""Find state array from assignment expression."""
- if expr.get('type') != 'AssignmentExpression':
+ if expression.get('type') != 'AssignmentExpression':
return None
- if not is_identifier(expr.get('left')):
+ if not is_identifier(expression.get('left')):
return None
- right = expr.get('right')
+ right = expression.get('right')
if self._is_split_call(right):
states = self._extract_split_states(right)
if states:
- return states, expr['left']['name']
+ return states, expression['left']['name']
return None
- def _find_counter_init(self, statement):
+ def _find_counter_init(self, statement: dict) -> str | None:
"""Find counter variable initialization."""
if not isinstance(statement, dict):
return None
match statement.get('type'):
case 'VariableDeclaration':
- for declaration in statement.get('declarations', []):
- if declaration.get('id', {}).get('type') == 'Identifier':
- initializer = declaration.get('init')
+ for declarator in statement.get('declarations', []):
+ if declarator.get('id', {}).get('type') == 'Identifier':
+ initializer = declarator.get('init')
if (
initializer
and initializer.get('type') == 'Literal'
and isinstance(initializer.get('value'), (int, float))
):
- return declaration['id']['name']
+ return declarator['id']['name']
case 'ExpressionStatement':
- expr = statement.get('expression')
+ expression = statement.get('expression')
if (
- expr
- and expr.get('type') == 'AssignmentExpression'
- and is_identifier(expr.get('left'))
- and is_literal(expr.get('right'))
- and isinstance(expr['right'].get('value'), (int, float))
+ expression
+ and expression.get('type') == 'AssignmentExpression'
+ and is_identifier(expression.get('left'))
+ and is_literal(expression.get('right'))
+ and isinstance(expression['right'].get('value'), (int, float))
):
- return expr['left']['name']
+ return expression['left']['name']
return None
- def _is_split_call(self, node):
+ def _is_split_call(self, node: dict) -> bool:
"""Check if node is "X".split("|")."""
if not isinstance(node, dict):
return False
@@ -210,36 +207,36 @@ def _is_split_call(self, node):
is_string_literal(property_expression) and property_expression.get('value') == 'split'
):
return False
- args = node.get('arguments', [])
- if len(args) != 1 or not is_string_literal(args[0]):
+ arguments = node.get('arguments', [])
+ if len(arguments) != 1 or not is_string_literal(arguments[0]):
return False
return True
- def _extract_split_states(self, node):
+ def _extract_split_states(self, node: dict) -> list:
"""Extract states from "1|0|3|2|4".split("|")."""
callee = node['callee']
string = callee['object']['value']
separator = node['arguments'][0]['value']
return string.split(separator)
- def _try_recover_from_loop(self, loop, states, state_var, counter_var):
+ def _try_recover_from_loop(
+ self, loop: dict, states: list, state_var: str, counter_var: str | None
+ ) -> list | None:
"""Try to recover statements from a for/while loop with switch dispatcher."""
if not isinstance(loop, dict):
return None
- loop_type = loop.get('type', '')
- switch_body = None
initial_value = 0
+ switch_body = None
- if loop_type == 'ForStatement':
- # for(var _i = 0; ...) { switch(_array[_i++]) { ... } break; }
- initial_value = self._extract_for_init_value(loop.get('init'))
- switch_body = self._extract_switch_from_loop_body(loop.get('body'))
-
- elif loop_type == 'WhileStatement':
- test = loop.get('test')
- if self._is_truthy(test):
+ match loop.get('type', ''):
+ case 'ForStatement':
+ # for(var _i = 0; ...) { switch(_array[_i++]) { ... } break; }
+ initial_value = self._extract_for_init_value(loop.get('init'))
switch_body = self._extract_switch_from_loop_body(loop.get('body'))
+ case 'WhileStatement':
+ if self._is_truthy(loop.get('test')):
+ switch_body = self._extract_switch_from_loop_body(loop.get('body'))
if switch_body is None:
return None
@@ -248,20 +245,20 @@ def _try_recover_from_loop(self, loop, states, state_var, counter_var):
return self._reconstruct_statements(cases_map, states, initial_value)
@staticmethod
- def _extract_for_init_value(initializer):
+ def _extract_for_init_value(initializer: dict | None) -> int:
"""Extract the initial counter value from a for-loop init clause."""
if not initializer:
return 0
if initializer.get('type') == 'VariableDeclaration':
- for declaration in initializer.get('declarations', []):
- if declaration.get('init') and declaration['init'].get('type') == 'Literal':
- return int(declaration['init'].get('value', 0))
+ for declarator in initializer.get('declarations', []):
+ if declarator.get('init') and declarator['init'].get('type') == 'Literal':
+ return int(declarator['init'].get('value', 0))
elif initializer.get('type') == 'AssignmentExpression' and is_literal(initializer.get('right')):
return int(initializer['right'].get('value', 0))
return 0
@staticmethod
- def _build_case_map(cases):
+ def _build_case_map(cases: list) -> dict:
"""Build map from case test value to (filtered statements, original statements)."""
cases_map = {}
for case in cases:
@@ -282,11 +279,11 @@ def _build_case_map(cases):
return cases_map
@staticmethod
- def _reconstruct_statements(cases_map, states, initial_value):
+ def _reconstruct_statements(cases_map: dict, states: list, initial_value: int) -> list | None:
"""Reconstruct linear statement sequence from case map and state order."""
recovered = []
- for idx in range(initial_value, len(states)):
- state = states[idx]
+ for index in range(initial_value, len(states)):
+ state = states[index]
if state not in cases_map:
break
statements, original = cases_map[state]
@@ -296,20 +293,20 @@ def _reconstruct_statements(cases_map, states, initial_value):
break
return recovered or None
- def _extract_switch_from_loop_body(self, body):
+ def _extract_switch_from_loop_body(self, body: dict | None) -> dict | None:
"""Extract SwitchStatement from loop body."""
if not isinstance(body, dict):
return None
if body.get('type') == 'BlockStatement':
- stmts = body.get('body', [])
- for stmt in stmts:
- if stmt.get('type') == 'SwitchStatement':
- return stmt
+ statements = body.get('body', [])
+ for statement in statements:
+ if statement.get('type') == 'SwitchStatement':
+ return statement
elif body.get('type') == 'SwitchStatement':
return body
return None
- def _is_truthy(self, node):
+ def _is_truthy(self, node: dict | None) -> bool:
"""Check if a test expression is always truthy."""
if not isinstance(node, dict):
return False
@@ -317,14 +314,14 @@ def _is_truthy(self, node):
return bool(node.get('value'))
# !0 = true, !![] = true
if node.get('type') == 'UnaryExpression' and node.get('operator') == '!':
- arg = node.get('argument')
- if arg and arg.get('type') == 'Literal' and arg.get('value') == 0:
+ argument = node.get('argument')
+ if argument and argument.get('type') == 'Literal' and argument.get('value') == 0:
return True
- if arg and arg.get('type') == 'ArrayExpression':
+ if argument and argument.get('type') == 'ArrayExpression':
return False # ![] = false, but !![] = true
- if arg and arg.get('type') == 'UnaryExpression' and arg.get('operator') == '!':
+ if argument and argument.get('type') == 'UnaryExpression' and argument.get('operator') == '!':
# !!something
- inner = arg.get('argument')
+ inner = argument.get('argument')
if inner and inner.get('type') == 'ArrayExpression':
return True
return False
diff --git a/pyjsclear/transforms/dead_branch.py b/pyjsclear/transforms/dead_branch.py
index ec7a9b0..7c222ea 100644
--- a/pyjsclear/transforms/dead_branch.py
+++ b/pyjsclear/transforms/dead_branch.py
@@ -5,7 +5,7 @@
from .base import Transform
-def _is_truthy_literal(node):
+def _is_truthy_literal(node: dict) -> bool | None:
"""Check if node is a literal that is truthy in JS. Returns None if unknown."""
if not isinstance(node, dict):
return None
@@ -34,15 +34,15 @@ def _is_truthy_literal(node):
case 'LogicalExpression':
left = _is_truthy_literal(node.get('left'))
right = _is_truthy_literal(node.get('right'))
- op = node.get('operator')
- if op == '&&':
+ operator = node.get('operator')
+ if operator == '&&':
# falsy && anything → falsy
if left is False:
return False
# truthy && right → right (if right is known)
if left is True and right is not None:
return right
- elif op == '||':
+ elif operator == '||':
# truthy || anything → truthy
if left is True:
return True
@@ -52,7 +52,7 @@ def _is_truthy_literal(node):
return None
-def _unwrap_block(node):
+def _unwrap_block(node: dict) -> dict:
"""Unwrap a single-statement block to its contents."""
if isinstance(node, dict) and node.get('type') == 'BlockStatement':
body = node.get('body', [])
@@ -64,8 +64,8 @@ def _unwrap_block(node):
class DeadBranchRemover(Transform):
"""Remove dead branches from if statements and ternary expressions."""
- def execute(self):
- def enter(node, parent, key, index):
+ def execute(self) -> bool:
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
node_type = node.get('type', '')
if node_type == 'IfStatement':
@@ -75,8 +75,8 @@ def enter(node, parent, key, index):
self.set_changed()
if truthy:
return node.get('consequent')
- alt = node.get('alternate')
- return alt if alt else REMOVE
+ alternate_branch = node.get('alternate')
+ return alternate_branch if alternate_branch else REMOVE
if node_type == 'ConditionalExpression':
truthy = _is_truthy_literal(node.get('test'))
diff --git a/pyjsclear/transforms/dead_class_props.py b/pyjsclear/transforms/dead_class_props.py
index 205135f..1c328fe 100644
--- a/pyjsclear/transforms/dead_class_props.py
+++ b/pyjsclear/transforms/dead_class_props.py
@@ -26,73 +26,73 @@
class DeadClassPropRemover(Transform):
"""Remove dead property assignments on class variables."""
- def execute(self):
+ def execute(self) -> bool:
# Step 1: Find class variable names, aliases, and class-id-to-name mapping
# in a single traversal
- class_vars = set()
- class_aliases = {} # inner_name -> outer_name
- reverse_aliases = {} # outer_name -> set of inner_names
- class_node_to_name = {} # id(ClassExpression node) -> outer name
+ class_vars: set[str] = set()
+ class_aliases: dict[str, str] = {} # inner_name -> outer_name
+ reverse_aliases: dict[str, set[str]] = {} # outer_name -> set of inner_names
+ class_node_to_name: dict[int, str] = {} # id(ClassExpression node) -> outer name
- def find_classes(node, parent):
+ def find_classes(node: dict, parent: dict | None) -> None:
if node.get('type') == 'VariableDeclarator':
init = node.get('init')
if init and init.get('type') == 'ClassExpression':
decl_id = node.get('id')
if decl_id and is_identifier(decl_id):
- outer = decl_id['name']
- class_vars.add(outer)
- class_node_to_name[id(init)] = outer
+ outer_name = decl_id['name']
+ class_vars.add(outer_name)
+ class_node_to_name[id(init)] = outer_name
class_id = init.get('id')
if class_id and is_identifier(class_id):
- inner = class_id['name']
- if inner != outer:
- class_vars.add(inner)
- class_aliases[inner] = outer
- reverse_aliases.setdefault(outer, set()).add(inner)
+ inner_name = class_id['name']
+ if inner_name != outer_name:
+ class_vars.add(inner_name)
+ class_aliases[inner_name] = outer_name
+ reverse_aliases.setdefault(outer_name, set()).add(inner_name)
elif node.get('type') == 'AssignmentExpression':
right = node.get('right')
if right and right.get('type') == 'ClassExpression':
left = node.get('left')
if left and is_identifier(left):
- outer = left['name']
- class_vars.add(outer)
- class_node_to_name[id(right)] = outer
+ outer_name = left['name']
+ class_vars.add(outer_name)
+ class_node_to_name[id(right)] = outer_name
class_id = right.get('id')
if class_id and is_identifier(class_id):
- inner = class_id['name']
- if inner != outer:
- class_vars.add(inner)
- class_aliases[inner] = outer
- reverse_aliases.setdefault(outer, set()).add(inner)
+ inner_name = class_id['name']
+ if inner_name != outer_name:
+ class_vars.add(inner_name)
+ class_aliases[inner_name] = outer_name
+ reverse_aliases.setdefault(outer_name, set()).add(inner_name)
simple_traverse(self.ast, find_classes)
if not class_vars:
return False
- def _normalize(obj_name):
+ def _normalize(obj_name: str) -> str:
"""Resolve class aliases to their canonical (outer) name."""
return class_aliases.get(obj_name, obj_name)
- def _has_standalone(name):
+ def _has_standalone(name: str) -> bool:
if standalone_refs.get(name, 0) > 0:
return True
canonical = class_aliases.get(name, name)
if standalone_refs.get(canonical, 0) > 0:
return True
- for inner in reverse_aliases.get(name, ()):
- if standalone_refs.get(inner, 0) > 0:
+ for inner_name in reverse_aliases.get(name, ()):
+ if standalone_refs.get(inner_name, 0) > 0:
return True
return False
# Step 2: Classify identifier references and collect this.prop reads
# in a single traversal
- member_refs = {v: 0 for v in class_vars}
- standalone_refs = {v: 0 for v in class_vars}
- this_reads = set()
+ member_refs: dict[str, int] = {var: 0 for var in class_vars}
+ standalone_refs: dict[str, int] = {var: 0 for var in class_vars}
+ this_reads: set[tuple[str, str]] = set()
- def classify_and_collect(node, parent):
+ def classify_and_collect(node: dict, parent: dict | None) -> None:
# Collect this.prop reads inside class bodies
if node.get('type') == 'ClassExpression':
class_name = class_node_to_name.get(id(node))
@@ -133,43 +133,46 @@ def classify_and_collect(node, parent):
else:
standalone_refs[name] = standalone_refs.get(name, 0) + 1
- def _collect_this_reads_in_class(class_node, class_name):
- def _inner(n, p):
- if n.get('type') != 'MemberExpression':
+ def _collect_this_reads_in_class(class_node: dict, class_name: str) -> None:
+ def _visit(node: dict, parent: dict | None) -> None:
+ if node.get('type') != 'MemberExpression':
return
- obj = n.get('object')
+ obj = node.get('object')
if not obj or obj.get('type') != 'ThisExpression':
return
- prop = n.get('property')
+ prop = node.get('property')
if not prop:
return
- if n.get('computed'):
+ if node.get('computed'):
if is_string_literal(prop):
this_reads.add((class_name, prop['value']))
elif is_identifier(prop):
this_reads.add((class_name, prop['name']))
- simple_traverse(class_node.get('body', {}), _inner)
+ simple_traverse(class_node.get('body', {}), _visit)
simple_traverse(self.ast, classify_and_collect)
# Classes with `this.prop` reads use their own properties — not fully dead
- classes_with_this = set()
+ classes_with_this: set[str] = set()
for name, prop in this_reads:
classes_with_this.add(name)
classes_with_this.add(class_aliases.get(name, name))
# Classes that never escape (only used as X.prop) — all their props are dead
- fully_dead_classes = {v for v in class_vars if not _has_standalone(v) and v not in classes_with_this}
+ fully_dead_classes = {
+ var for var in class_vars
+ if not _has_standalone(var) and var not in classes_with_this
+ }
# Classes that have escaped — skip individual prop dead-code analysis
- escaped_classes = {_normalize(v) for v in class_vars if _has_standalone(v)}
+ escaped_classes = {_normalize(var) for var in class_vars if _has_standalone(var)}
# Step 3: For non-fully-dead classes, find individually dead properties
- writes = set()
- reads = set()
+ writes: set[tuple[str, str]] = set()
+ reads: set[tuple[str, str]] = set()
- def count_prop_refs(node, parent):
+ def count_prop_refs(node: dict, parent: dict | None) -> None:
if node.get('type') != 'MemberExpression':
return
obj_name, prop_name = get_member_names(node)
@@ -192,16 +195,16 @@ def count_prop_refs(node, parent):
reads |= {(_normalize(name), prop) for name, prop in this_reads}
# Dead props: written but never read, OR belonging to fully dead classes
- dead_props = set()
+ dead_props: set[tuple[str, str]] = set()
for pair in writes:
if pair not in reads:
dead_props.add(pair)
# Collect all props of fully dead classes in a single traversal
if fully_dead_classes:
- fully_dead_canonical = {_normalize(v) for v in fully_dead_classes} | fully_dead_classes
+ fully_dead_canonical = {_normalize(var) for var in fully_dead_classes} | fully_dead_classes
- def collect_all_dead(node, parent):
+ def collect_all_dead(node: dict, parent: dict | None) -> None:
if node.get('type') != 'AssignmentExpression' or node.get('operator') != '=':
return
obj_name, prop_name = get_member_names(node.get('left'))
@@ -214,11 +217,11 @@ def collect_all_dead(node, parent):
return False
# Step 4: Remove dead assignment expressions
- def _is_dead(obj_name, prop_name):
+ def _is_dead(obj_name: str, prop_name: str) -> bool:
canonical = _normalize(obj_name)
return (canonical, prop_name) in dead_props or (obj_name, prop_name) in dead_props
- def remove_dead_stmts(node, parent, key, index):
+ def remove_dead_stmts(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') != 'ExpressionStatement':
return
expr = node.get('expression')
@@ -232,13 +235,13 @@ def remove_dead_stmts(node, parent, key, index):
if expr.get('type') == 'SequenceExpression':
exprs = expr.get('expressions', [])
remaining = []
- for e in exprs:
- if e.get('type') == 'AssignmentExpression' and e.get('operator') == '=':
- obj_name, prop_name = get_member_names(e.get('left'))
+ for expression in exprs:
+ if expression.get('type') == 'AssignmentExpression' and expression.get('operator') == '=':
+ obj_name, prop_name = get_member_names(expression.get('left'))
if obj_name and _is_dead(obj_name, prop_name):
self.set_changed()
continue
- remaining.append(e)
+ remaining.append(expression)
if not remaining:
return REMOVE
if len(remaining) < len(exprs):
diff --git a/pyjsclear/transforms/dead_expressions.py b/pyjsclear/transforms/dead_expressions.py
index 0d7d17d..15df3d6 100644
--- a/pyjsclear/transforms/dead_expressions.py
+++ b/pyjsclear/transforms/dead_expressions.py
@@ -12,14 +12,14 @@ class DeadExpressionRemover(Transform):
like `(0, fn())`, and other numeric literal statements.
"""
- def execute(self):
- def enter(node, parent, key, index):
+ def execute(self) -> bool:
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') != 'ExpressionStatement':
return
- expr = node.get('expression')
- if not isinstance(expr, dict) or expr.get('type') != 'Literal':
+ expression = node.get('expression')
+ if not isinstance(expression, dict) or expression.get('type') != 'Literal':
return
- value = expr.get('value')
+ value = expression.get('value')
# Only remove numeric literals (not strings/booleans/null/regex)
if isinstance(value, (int, float)) and not isinstance(value, bool):
self.set_changed()
diff --git a/pyjsclear/transforms/dead_object_props.py b/pyjsclear/transforms/dead_object_props.py
index 7e97beb..a6acbc0 100644
--- a/pyjsclear/transforms/dead_object_props.py
+++ b/pyjsclear/transforms/dead_object_props.py
@@ -46,61 +46,60 @@
class DeadObjectPropRemover(Transform):
"""Remove object property assignments where the property is never read."""
- def execute(self):
+ def execute(self) -> bool:
# Phase 1: Find all obj.PROP = value statements and all obj.PROP reads.
# Also track which objects "escape" (are assigned to external refs, passed as
# function arguments, or returned) — their properties may be read externally.
- writes = {} # (obj_name, prop_name) -> count
- reads = set() # set of (obj_name, prop_name)
- escaped = set() # set of obj_name that escape
+ writes: dict[tuple[str, str], int] = {} # (obj_name, prop_name) -> count
+ reads: set[tuple[str, str]] = set() # set of (obj_name, prop_name)
+ escaped: set[str] = set() # set of obj_name that escape
# Phase 0: Collect locally declared variable names (var/let/const).
# Only properties on locally declared objects are candidates for removal.
- local_vars = set()
+ local_vars: set[str] = set()
- def collect_locals(node, parent):
+ def collect_locals(node: dict, parent: dict | None) -> None:
if not isinstance(node, dict):
return
- t = node.get('type')
- if t == 'VariableDeclarator':
- vid = node.get('id')
- if vid and is_identifier(vid):
- local_vars.add(vid['name'])
+ node_type = node.get('type')
+ if node_type == 'VariableDeclarator':
+ variable_id = node.get('id')
+ if variable_id and is_identifier(variable_id):
+ local_vars.add(variable_id['name'])
# Function/arrow params are externally provided — mark as escaped
- if t in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'):
+ if node_type in ('FunctionDeclaration', 'FunctionExpression', 'ArrowFunctionExpression'):
for param in node.get('params', []):
if is_identifier(param):
escaped.add(param['name'])
simple_traverse(self.ast, collect_locals)
- def collect(node, parent):
+ def collect(node: dict, parent: dict | None) -> None:
if not isinstance(node, dict):
return
- t = node.get('type')
+ node_type = node.get('type')
# Track identifiers that escape
- if t == 'Identifier' and parent:
+ if node_type == 'Identifier' and parent:
name = node.get('name', '')
if name in _GLOBAL_OBJECTS:
escaped.add(name)
- pt = parent.get('type')
+ parent_type = parent.get('type')
# RHS of assignment to a member (e.g., r.exports = obj)
- if pt == 'AssignmentExpression' and node is parent.get('right'):
+ if parent_type == 'AssignmentExpression' and node is parent.get('right'):
left = parent.get('left')
if left and left.get('type') == 'MemberExpression':
escaped.add(name)
# Function/method argument
- if pt == 'CallExpression' or pt == 'NewExpression':
- args = parent.get('arguments', [])
- if node in args:
+ if parent_type in ('CallExpression', 'NewExpression'):
+ if node in parent.get('arguments', []):
escaped.add(name)
# Return value
- if pt == 'ReturnStatement':
+ if parent_type == 'ReturnStatement':
escaped.add(name)
# Track member access patterns
- if t != 'MemberExpression':
+ if node_type != 'MemberExpression':
return
if node.get('computed'):
return
@@ -125,7 +124,7 @@ def collect(node, parent):
return False
# Phase 2: Remove dead assignment statements
- def remove_dead(node, parent, key, index):
+ def remove_dead(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') != 'ExpressionStatement':
return
expr = node.get('expression')
@@ -139,12 +138,13 @@ def remove_dead(node, parent, key, index):
if not obj or not is_identifier(obj) or not prop or not is_identifier(prop):
return
pair = (obj['name'], prop['name'])
- if pair in dead_props:
- # Only remove if the RHS is side-effect-free
- rhs = expr.get('right')
- if is_side_effect_free(rhs):
- self.set_changed()
- return REMOVE
+ if pair not in dead_props:
+ return
+ # Only remove if the RHS is side-effect-free
+ rhs = expr.get('right')
+ if is_side_effect_free(rhs):
+ self.set_changed()
+ return REMOVE
traverse(self.ast, {'enter': remove_dead})
return self.has_changed()
diff --git a/pyjsclear/transforms/else_if_flatten.py b/pyjsclear/transforms/else_if_flatten.py
index 8559627..8a3fb27 100644
--- a/pyjsclear/transforms/else_if_flatten.py
+++ b/pyjsclear/transforms/else_if_flatten.py
@@ -11,22 +11,22 @@ class ElseIfFlattener(Transform):
where each else block wraps a single if statement.
"""
- def execute(self):
- def enter(node, parent, key, index):
- if node.get('type') != 'IfStatement':
- return
- alt = node.get('alternate')
- if not alt or alt.get('type') != 'BlockStatement':
- return
- body = alt.get('body', [])
- if len(body) != 1:
- return
- inner = body[0]
- if inner.get('type') != 'IfStatement':
- return
- # Flatten: replace the block with the inner if
- node['alternate'] = inner
- self.set_changed()
+ def _enter_node(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
+ if node.get('type') != 'IfStatement':
+ return
+ alternate_node = node.get('alternate')
+ if not alternate_node or alternate_node.get('type') != 'BlockStatement':
+ return
+ body = alternate_node.get('body', [])
+ if len(body) != 1:
+ return
+ inner_if = body[0]
+ if inner_if.get('type') != 'IfStatement':
+ return
+ # Flatten: replace the block with the inner if
+ node['alternate'] = inner_if
+ self.set_changed()
- traverse(self.ast, {'enter': enter})
+ def execute(self) -> bool:
+ traverse(self.ast, {'enter': self._enter_node})
return self.has_changed()
diff --git a/pyjsclear/transforms/enum_resolver.py b/pyjsclear/transforms/enum_resolver.py
index 6c9ee94..9555043 100644
--- a/pyjsclear/transforms/enum_resolver.py
+++ b/pyjsclear/transforms/enum_resolver.py
@@ -21,10 +21,10 @@
class EnumResolver(Transform):
"""Replace TypeScript enum member accesses with their numeric values."""
- def execute(self):
+ def execute(self) -> bool:
# Phase 1: Detect enum IIFEs and extract member values.
# Pattern: (function(param) { param[param.NAME = VALUE] = "NAME"; ... })(E || (E = {}))
- enum_members = {} # (enum_name, member_name) -> numeric value
+ enum_members: dict[tuple[str, str], int | float] = {} # (enum_name, member_name) -> numeric value
def find_enums(node, parent):
if node.get('type') != 'CallExpression':
@@ -91,7 +91,7 @@ def resolve(node, parent, key, index):
traverse(self.ast, {'enter': resolve})
return self.has_changed()
- def _extract_enum_name(self, arg):
+ def _extract_enum_name(self, argument_node: dict) -> str | None:
"""Extract the enum name from the IIFE argument pattern.
Handles:
@@ -99,21 +99,21 @@ def _extract_enum_name(self, arg):
- E = X.Y || (X.Y = {}) (export-assigned variant)
"""
# Simple case: just an identifier
- if is_identifier(arg):
- return arg['name']
+ if is_identifier(argument_node):
+ return argument_node['name']
# Assignment wrapper: E = X.Y || (X.Y = {})
- if arg.get('type') == 'AssignmentExpression' and arg.get('operator') == '=':
- assign_left = arg.get('left')
+ if argument_node.get('type') == 'AssignmentExpression' and argument_node.get('operator') == '=':
+ assign_left = argument_node.get('left')
if is_identifier(assign_left):
- inner = arg.get('right')
+ inner = argument_node.get('right')
if inner and inner.get('type') == 'LogicalExpression':
return assign_left['name']
return None
# Logical OR pattern: E || (E = {})
- if arg.get('type') != 'LogicalExpression' or arg.get('operator') != '||':
+ if argument_node.get('type') != 'LogicalExpression' or argument_node.get('operator') != '||':
return None
- left = arg.get('left')
- right = arg.get('right')
+ left = argument_node.get('left')
+ right = argument_node.get('right')
if not is_identifier(left):
return None
name = left['name']
@@ -124,15 +124,15 @@ def _extract_enum_name(self, arg):
return name
return None
- def _extract_enum_assignment(self, expr, param_name):
+ def _extract_enum_assignment(self, expression: dict | None, param_name: str) -> tuple[str | None, int | float | None]:
"""Extract (member_name, value) from: param[param.NAME = VALUE] = "NAME".
Returns (member_name, numeric_value) or (None, None).
"""
- if not expr or expr.get('type') != 'AssignmentExpression':
+ if not expression or expression.get('type') != 'AssignmentExpression':
return None, None
# The outer assignment: param[...] = "NAME"
- left = expr.get('left')
+ left = expression.get('left')
if not left or left.get('type') != 'MemberExpression' or not left.get('computed'):
return None, None
obj = left.get('object')
@@ -159,7 +159,7 @@ def _extract_enum_assignment(self, expr, param_name):
return member_name, value
@staticmethod
- def _get_numeric_value(node):
+ def _get_numeric_value(node: dict | None) -> int | float | None:
"""Extract a numeric value from a literal or unary minus expression."""
if not node:
return None
diff --git a/pyjsclear/transforms/eval_unpack.py b/pyjsclear/transforms/eval_unpack.py
index 7f2baa1..9acf547 100644
--- a/pyjsclear/transforms/eval_unpack.py
+++ b/pyjsclear/transforms/eval_unpack.py
@@ -27,21 +27,21 @@
_EVAL_RE = re.compile(r'^eval\s*\(', re.MULTILINE)
-def is_eval_packed(code):
+def is_eval_packed(code: str) -> bool:
"""Check if code uses eval packing."""
return bool(_PACKER_RE.search(code) or _PACKER_RE2.search(code) or _EVAL_RE.search(code.lstrip()))
-def _dean_edwards_unpack(packed, radix, count, keywords):
+def _dean_edwards_unpack(packed: str, radix: int, count: int, keywords: list[str]) -> str:
"""Pure Python implementation of Dean Edwards unpacker."""
# Build the replacement function
- def base_encode(c):
- prefix = '' if c < radix else base_encode(int(c / radix))
- c = c % radix
- if c > 35:
- return prefix + chr(c + 29)
- return prefix + ('0123456789abcdefghijklmnopqrstuvwxyz'[c] if c < 36 else chr(c + 29))
+ def base_encode(value: int) -> str:
+ prefix = '' if value < radix else base_encode(int(value / radix))
+ remainder = value % radix
+ if remainder > 35:
+ return prefix + chr(remainder + 29)
+ return prefix + ('0123456789abcdefghijklmnopqrstuvwxyz'[remainder] if remainder < 36 else chr(remainder + 29))
# Build dictionary
lookup = {}
@@ -51,34 +51,35 @@ def base_encode(c):
lookup[key] = keywords[count] if count < len(keywords) and keywords[count] else key
# Replace tokens in packed string
- def replacer(match):
- token = match.group(0)
+ def replacer(token_match: re.Match) -> str:
+ token = token_match.group(0)
return lookup.get(token, token)
return re.sub(r'\b\w+\b', replacer, packed)
-def eval_unpack(code):
+def eval_unpack(code: str) -> str | None:
"""Unpack eval-packed JavaScript. Returns unpacked code or None."""
return _try_dean_edwards(code)
-def _try_dean_edwards(code):
+def _try_dean_edwards(code: str) -> str | None:
"""Try to unpack Dean Edwards packer format."""
for pattern in [_PACKER_RE, _PACKER_RE2]:
- m = pattern.search(code)
- if m:
- packed = m.group(1)
- radix = int(m.group(2))
- count = int(m.group(3))
- keywords_str = m.group(4)
- keywords = keywords_str.split('|')
-
- # Unescape the packed string
- packed = packed.replace("\\'", "'").replace('\\\\', '\\')
-
- try:
- return _dean_edwards_unpack(packed, radix, count, keywords)
- except Exception:
- continue
+ pattern_match = pattern.search(code)
+ if not pattern_match:
+ continue
+
+ packed = pattern_match.group(1)
+ radix = int(pattern_match.group(2))
+ count = int(pattern_match.group(3))
+ keywords = pattern_match.group(4).split('|')
+
+ # Unescape the packed string
+ packed = packed.replace("\\'", "'").replace('\\\\', '\\')
+
+ try:
+ return _dean_edwards_unpack(packed, radix, count, keywords)
+ except Exception:
+ continue
return None
diff --git a/pyjsclear/transforms/expression_simplifier.py b/pyjsclear/transforms/expression_simplifier.py
index 073a6f5..743ce97 100644
--- a/pyjsclear/transforms/expression_simplifier.py
+++ b/pyjsclear/transforms/expression_simplifier.py
@@ -1,6 +1,7 @@
"""Evaluate static unary/binary expressions to literals."""
import math
+from typing import Any
from ..traverser import traverse
from ..utils.ast_helpers import is_literal
@@ -41,7 +42,7 @@
class ExpressionSimplifier(Transform):
"""Simplify constant unary/binary expressions to literals."""
- def execute(self):
+ def execute(self) -> bool:
self._simplify_unary_binary()
self._simplify_conditionals()
self._simplify_awaits()
@@ -49,32 +50,38 @@ def execute(self):
self._simplify_method_calls()
return self.has_changed()
- def _simplify_unary_binary(self):
+ def _simplify_unary_binary(self) -> None:
"""Fold constant unary and binary expressions."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
match node.get('type', ''):
case 'UnaryExpression':
result = self._simplify_unary(node)
case 'BinaryExpression':
result = self._simplify_binary(node)
case _:
- return
+ return None
if result is not None:
self.set_changed()
return result
+ return None
traverse(self.ast, {'enter': enter})
- def _simplify_conditionals(self):
+ def _simplify_conditionals(self) -> None:
"""Convert test ? false : true → !test."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
if node.get('type') != 'ConditionalExpression':
- return
- cons = node.get('consequent')
- alt = node.get('alternate')
- if is_literal(cons) and cons.get('value') is False and is_literal(alt) and alt.get('value') is True:
+ return None
+ consequent = node.get('consequent')
+ alternate = node.get('alternate')
+ if (
+ is_literal(consequent)
+ and consequent.get('value') is False
+ and is_literal(alternate)
+ and alternate.get('value') is True
+ ):
self.set_changed()
return {
'type': 'UnaryExpression',
@@ -82,48 +89,49 @@ def enter(node, parent, key, index):
'prefix': True,
'argument': node['test'],
}
+ return None
traverse(self.ast, {'enter': enter})
- def _simplify_awaits(self):
+ def _simplify_awaits(self) -> None:
"""Simplify await (0x0, expr) → await expr."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'AwaitExpression':
return
- arg = node.get('argument')
- if not isinstance(arg, dict) or arg.get('type') != 'SequenceExpression':
+ argument = node.get('argument')
+ if not isinstance(argument, dict) or argument.get('type') != 'SequenceExpression':
return
- exprs = arg.get('expressions', [])
- if len(exprs) <= 1:
+ expressions = argument.get('expressions', [])
+ if len(expressions) <= 1:
return
- node['argument'] = exprs[-1]
+ node['argument'] = expressions[-1]
self.set_changed()
traverse(self.ast, {'enter': enter})
- def _simplify_comma_calls(self):
+ def _simplify_comma_calls(self) -> None:
"""Simplify (0, expr)(args) → expr(args)."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'CallExpression':
return
callee = node.get('callee')
if not isinstance(callee, dict) or callee.get('type') != 'SequenceExpression':
return
- exprs = callee.get('expressions', [])
- if len(exprs) < 2:
+ expressions = callee.get('expressions', [])
+ if len(expressions) < 2:
return
# Only simplify when the leading expressions are side-effect-free literals
- for expr in exprs[:-1]:
- if not isinstance(expr, dict) or expr.get('type') != 'Literal':
+ for expression in expressions[:-1]:
+ if not isinstance(expression, dict) or expression.get('type') != 'Literal':
return
- node['callee'] = exprs[-1]
+ node['callee'] = expressions[-1]
self.set_changed()
traverse(self.ast, {'enter': enter})
- def _simplify_unary(self, node):
+ def _simplify_unary(self, node: dict) -> dict | None:
operator = node.get('operator', '')
if operator not in _RESOLVABLE_UNARY:
return None
@@ -137,8 +145,8 @@ def _simplify_unary(self, node):
argument = node.get('argument')
argument = self._simplify_expr(argument)
- value, ok = self._get_resolvable_value(argument)
- if not ok:
+ value, resolved = self._get_resolvable_value(argument)
+ if not resolved:
return None
try:
@@ -147,7 +155,7 @@ def _simplify_unary(self, node):
return None
return self._value_to_node(result)
- def _simplify_binary(self, node):
+ def _simplify_binary(self, node: dict) -> dict | None:
operator = node.get('operator', '')
if operator not in _RESOLVABLE_BINARY:
return None
@@ -172,7 +180,7 @@ def _simplify_binary(self, node):
return None
return self._value_to_node(result)
- def _simplify_expr(self, node):
+ def _simplify_expr(self, node: Any) -> Any:
if not isinstance(node, dict):
return node
match node.get('type', ''):
@@ -184,7 +192,7 @@ def _simplify_expr(self, node):
return result if result is not None else node
return node
- def _is_negative_numeric(self, node):
+ def _is_negative_numeric(self, node: Any) -> bool:
return (
isinstance(node, dict)
and node.get('type') == 'UnaryExpression'
@@ -193,7 +201,7 @@ def _is_negative_numeric(self, node):
and isinstance(node['argument'].get('value'), (int, float))
)
- def _get_resolvable_value(self, node):
+ def _get_resolvable_value(self, node: Any) -> tuple[Any, bool]:
if not isinstance(node, dict):
return None, False
match node.get('type', ''):
@@ -204,9 +212,9 @@ def _get_resolvable_value(self, node):
# Literal with value None is JS null, not undefined
return (_JS_NULL if value is None else value), True
case 'UnaryExpression' if node.get('operator') == '-':
- arg = node.get('argument')
- if is_literal(arg) and isinstance(arg.get('value'), (int, float)):
- return -arg['value'], True
+ argument = node.get('argument')
+ if is_literal(argument) and isinstance(argument.get('value'), (int, float)):
+ return -argument['value'], True
case 'Identifier' if node.get('name') == 'undefined':
return None, True
case 'ArrayExpression' if len(node.get('elements', [])) == 0:
@@ -215,7 +223,7 @@ def _get_resolvable_value(self, node):
return {}, True
return None, False
- def _apply_unary(self, operator, value):
+ def _apply_unary(self, operator: str, value: Any) -> Any:
match operator:
case '-':
return -self._js_to_number(value)
@@ -224,16 +232,16 @@ def _apply_unary(self, operator, value):
case '!':
return not self._js_truthy(value)
case '~':
- n = self._js_to_number(value)
- if isinstance(n, float) and math.isnan(n):
+ number = self._js_to_number(value)
+ if isinstance(number, float) and math.isnan(number):
return -1 # ~NaN → -1
- return ~int(n)
+ return ~int(number)
case 'typeof':
return self._js_typeof(value)
case 'void':
return None # JS undefined
- def _apply_binary(self, operator, left, right):
+ def _apply_binary(self, operator: str, left: Any, right: Any) -> Any:
match operator:
case '+':
if isinstance(left, str) or isinstance(right, str):
@@ -244,15 +252,15 @@ def _apply_binary(self, operator, left, right):
case '*':
return self._js_to_number(left) * self._js_to_number(right)
case '/':
- result = self._js_to_number(right)
- if result == 0:
+ divisor = self._js_to_number(right)
+ if divisor == 0:
raise ValueError('division by zero')
- return self._js_to_number(left) / result
+ return self._js_to_number(left) / divisor
case '%':
- result = self._js_to_number(right)
- if result == 0:
+ modulus = self._js_to_number(right)
+ if modulus == 0:
raise ValueError('mod by zero')
- return self._js_to_number(left) % result
+ return self._js_to_number(left) % modulus
case '**':
return self._js_to_number(left) ** self._js_to_number(right)
case '|':
@@ -267,14 +275,14 @@ def _apply_binary(self, operator, left, right):
return self._js_to_int32(left) >> (self._js_to_int32(right) & 31)
case '>>>':
left_operand = self._js_to_int32(left) & 0xFFFFFFFF
- result = self._js_to_int32(right) & 31
- return left_operand >> result
+ shift = self._js_to_int32(right) & 31
+ return left_operand >> shift
case '==' | '!=':
- eq = self._js_abstract_eq(left, right)
- return eq if operator == '==' else not eq
+ equal = self._js_abstract_eq(left, right)
+ return equal if operator == '==' else not equal
case '===' | '!==':
- eq = self._js_strict_eq(left, right)
- return eq if operator == '===' else not eq
+ equal = self._js_strict_eq(left, right)
+ return equal if operator == '===' else not equal
case '<':
return self._js_compare(left, right) < 0
case '<=':
@@ -286,7 +294,7 @@ def _apply_binary(self, operator, left, right):
case _:
raise ValueError(f'Unknown operator: {operator}')
- def _js_abstract_eq(self, left, right):
+ def _js_abstract_eq(self, left: Any, right: Any) -> bool:
"""JS == (null and undefined are equal to each other only)."""
if (left is None or left is _JS_NULL) and (right is None or right is _JS_NULL):
return True
@@ -294,7 +302,7 @@ def _js_abstract_eq(self, left, right):
return False
return left == right
- def _js_strict_eq(self, left, right):
+ def _js_strict_eq(self, left: Any, right: Any) -> bool:
"""JS === (null !== undefined)."""
if left is _JS_NULL:
return right is _JS_NULL
@@ -302,7 +310,7 @@ def _js_strict_eq(self, left, right):
return False
return left == right and type(left) == type(right)
- def _js_truthy(self, value):
+ def _js_truthy(self, value: Any) -> bool:
if value is None or value is _JS_NULL:
return False
match value:
@@ -317,7 +325,7 @@ def _js_truthy(self, value):
case _:
return bool(value)
- def _js_typeof(self, value):
+ def _js_typeof(self, value: Any) -> str:
if value is _JS_NULL:
return 'object' # typeof null === 'object' in JS
if value is None:
@@ -334,11 +342,11 @@ def _js_typeof(self, value):
case _:
return 'undefined'
- def _js_to_int32(self, value):
+ def _js_to_int32(self, value: Any) -> int:
"""Coerce to 32-bit integer (for bitwise ops)."""
return int(self._js_to_number(value))
- def _js_to_number(self, value):
+ def _js_to_number(self, value: Any) -> int | float:
if value is _JS_NULL:
return 0 # Number(null) → 0
if value is None:
@@ -362,7 +370,7 @@ def _js_to_number(self, value):
case _:
return 0
- def _js_to_string(self, value):
+ def _js_to_string(self, value: Any) -> str:
if value is _JS_NULL:
return 'null'
if value is None:
@@ -381,7 +389,7 @@ def _js_to_string(self, value):
case _:
return str(value)
- def _js_compare(self, left, right):
+ def _js_compare(self, left: Any, right: Any) -> int | float:
# JS compares strings lexicographically, not numerically
if isinstance(left, str) and isinstance(right, str):
if left < right:
@@ -403,7 +411,7 @@ def _js_compare(self, left, right):
return 1
return 0
- def _value_to_node(self, value):
+ def _value_to_node(self, value: Any) -> dict | None:
if value is _JS_NULL:
return make_literal(None) # null literal
if value is None:
@@ -426,7 +434,7 @@ def _value_to_node(self, value):
return make_literal(value)
return None
- def _simplify_method_calls(self):
+ def _simplify_method_calls(self) -> None:
"""Statically evaluate simple method calls on literals.
Handles:
@@ -434,27 +442,27 @@ def _simplify_method_calls(self):
(N).toString() → "N"
"""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
if node.get('type') != 'CallExpression':
- return
+ return None
callee = node.get('callee')
if not isinstance(callee, dict) or callee.get('type') != 'MemberExpression':
- return
+ return None
prop = callee.get('property')
if not prop:
- return
+ return None
method_name = prop.get('name') if prop.get('type') == 'Identifier' else None
if not method_name:
- return
+ return None
# (N).toString() → "N"
if method_name == 'toString' and len(node.get('arguments', [])) == 0:
obj = callee.get('object')
if is_numeric_literal(obj):
val = obj['value']
- s = str(int(val)) if isinstance(val, float) and val == int(val) else str(val)
+ string_value = str(int(val)) if isinstance(val, float) and val == int(val) else str(val)
self.set_changed()
- return make_literal(s)
+ return make_literal(string_value)
# Buffer.from([...nums...]).toString(encoding) → string literal
if method_name == 'toString' and len(node.get('arguments', [])) <= 1:
@@ -463,9 +471,11 @@ def enter(node, parent, key, index):
self.set_changed()
return make_literal(result)
+ return None
+
traverse(self.ast, {'enter': enter})
- def _try_eval_buffer_from_tostring(self, obj, args):
+ def _try_eval_buffer_from_tostring(self, obj: Any, arguments: list) -> str | None:
"""Try to evaluate Buffer.from([...nums...]).toString(encoding)."""
if not isinstance(obj, dict) or obj.get('type') != 'CallExpression':
return None
@@ -473,11 +483,19 @@ def _try_eval_buffer_from_tostring(self, obj, args):
if not isinstance(callee, dict) or callee.get('type') != 'MemberExpression':
return None
# Check for Buffer.from
- buf_obj = callee.get('object')
- buf_prop = callee.get('property')
- if not (buf_obj and buf_obj.get('type') == 'Identifier' and buf_obj.get('name') == 'Buffer'):
+ buffer_object = callee.get('object')
+ buffer_property = callee.get('property')
+ if not (
+ buffer_object
+ and buffer_object.get('type') == 'Identifier'
+ and buffer_object.get('name') == 'Buffer'
+ ):
return None
- if not (buf_prop and buf_prop.get('type') == 'Identifier' and buf_prop.get('name') == 'from'):
+ if not (
+ buffer_property
+ and buffer_property.get('type') == 'Identifier'
+ and buffer_property.get('name') == 'from'
+ ):
return None
# First arg must be an array of numbers
call_args = obj.get('arguments', [])
@@ -485,17 +503,17 @@ def _try_eval_buffer_from_tostring(self, obj, args):
return None
elements = call_args[0].get('elements', [])
byte_values = []
- for el in elements:
- if not is_numeric_literal(el):
+ for element in elements:
+ if not is_numeric_literal(element):
return None
- val = el['value']
+ val = element['value']
if not isinstance(val, (int, float)) or val != int(val) or val < 0 or val > 255:
return None
byte_values.append(int(val))
# Determine encoding for toString
encoding = 'utf8'
- if args and is_literal(args[0]) and isinstance(args[0].get('value'), str):
- encoding = args[0]['value']
+ if arguments and is_literal(arguments[0]) and isinstance(arguments[0].get('value'), str):
+ encoding = arguments[0]['value']
try:
data = bytes(byte_values)
if encoding in ('utf8', 'utf-8'):
diff --git a/pyjsclear/transforms/global_alias.py b/pyjsclear/transforms/global_alias.py
index ffab53b..1c644e2 100644
--- a/pyjsclear/transforms/global_alias.py
+++ b/pyjsclear/transforms/global_alias.py
@@ -54,64 +54,75 @@ class GlobalAliasInliner(Transform):
of mangled variable names is extremely unlikely.
"""
- def execute(self):
- aliases = {}
-
- # Phase 1: Find `var X = GLOBAL` patterns
- def find_aliases(node, parent, key, index):
- if node.get('type') != 'VariableDeclarator':
- return
- decl_id = node.get('id')
- init = node.get('init')
- if not is_identifier(decl_id) or not is_identifier(init):
- return
- if init['name'] in _WELL_KNOWN_GLOBALS:
- aliases[decl_id['name']] = init['name']
-
- traverse(self.ast, {'enter': find_aliases})
-
- if not aliases:
+ def _find_var_aliases(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
+ """Collect `var alias = GLOBAL` patterns into self._aliases."""
+ if node.get('type') != 'VariableDeclarator':
+ return
+ declaration_id = node.get('id')
+ initializer = node.get('init')
+ if not is_identifier(declaration_id) or not is_identifier(initializer):
+ return
+ if initializer['name'] in _WELL_KNOWN_GLOBALS:
+ self._aliases[declaration_id['name']] = initializer['name']
+
+ def _find_assignment_aliases(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
+ """Collect `alias = GLOBAL` assignment patterns into self._aliases."""
+ if node.get('type') != 'AssignmentExpression':
+ return
+ if node.get('operator') != '=':
+ return
+ left_node = node.get('left')
+ right_node = node.get('right')
+ if not is_identifier(left_node) or not is_identifier(right_node):
+ return
+ if right_node['name'] in _WELL_KNOWN_GLOBALS:
+ self._aliases[left_node['name']] = right_node['name']
+
+ def _is_non_reference_position(self, parent: dict | None, key: str | None) -> bool:
+ """Return True if the identifier is in a non-reference (definition/key) position."""
+ if not parent:
return False
+ parent_type = parent.get('type')
+ # Non-computed property name
+ if parent_type == 'MemberExpression' and key == 'property' and not parent.get('computed'):
+ return True
+ # Variable declaration target
+ if parent_type == 'VariableDeclarator' and key == 'id':
+ return True
+ # Assignment left-hand side
+ if parent_type == 'AssignmentExpression' and key == 'left':
+ return True
+ # Function/method name
+ if parent_type in ('FunctionDeclaration', 'FunctionExpression') and key == 'id':
+ return True
+ # Non-computed property key
+ if parent_type == 'Property' and key == 'key' and not parent.get('computed'):
+ return True
+ return False
+
+ def _replace_alias_refs(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
+ """Replace aliased identifier references with the global name."""
+ if not is_identifier(node):
+ return None
+ if self._is_non_reference_position(parent, key):
+ return None
+ name = node.get('name')
+ if name in self._aliases:
+ self.set_changed()
+ return make_identifier(self._aliases[name])
+ return None
+
+ def execute(self) -> bool:
+ self._aliases: dict[str, str] = {}
+
+ # Phase 1: collect `var X = GLOBAL` and `X = GLOBAL` patterns
+ traverse(self.ast, {'enter': self._find_var_aliases})
+
+ if not self._aliases:
+ return False
+
+ traverse(self.ast, {'enter': self._find_assignment_aliases})
- # Also find assignment aliases: X = GLOBAL (not just var X = GLOBAL)
- def find_assignment_aliases(node, parent, key, index):
- if node.get('type') != 'AssignmentExpression':
- return
- if node.get('operator') != '=':
- return
- left = node.get('left')
- right = node.get('right')
- if not is_identifier(left) or not is_identifier(right):
- return
- if right['name'] in _WELL_KNOWN_GLOBALS:
- aliases[left['name']] = right['name']
-
- traverse(self.ast, {'enter': find_assignment_aliases})
-
- # Phase 2: Replace all references
- def replace_refs(node, parent, key, index):
- if not is_identifier(node):
- return
- # Skip non-computed property names
- if parent and parent.get('type') == 'MemberExpression' and key == 'property' and not parent.get('computed'):
- return
- # Skip declaration targets
- if parent and parent.get('type') == 'VariableDeclarator' and key == 'id':
- return
- # Skip assignment left-hand sides
- if parent and parent.get('type') == 'AssignmentExpression' and key == 'left':
- return
- # Skip function/method names
- if parent and parent.get('type') in ('FunctionDeclaration', 'FunctionExpression') and key == 'id':
- return
- # Skip property keys
- if parent and parent.get('type') == 'Property' and key == 'key' and not parent.get('computed'):
- return
-
- name = node.get('name')
- if name in aliases:
- self.set_changed()
- return make_identifier(aliases[name])
-
- traverse(self.ast, {'enter': replace_refs})
+ # Phase 2: replace all alias references
+ traverse(self.ast, {'enter': self._replace_alias_refs})
return self.has_changed()
diff --git a/pyjsclear/transforms/hex_escapes.py b/pyjsclear/transforms/hex_escapes.py
index 6af0c2b..01b6965 100644
--- a/pyjsclear/transforms/hex_escapes.py
+++ b/pyjsclear/transforms/hex_escapes.py
@@ -9,17 +9,15 @@
class HexEscapes(Transform):
"""Pre-AST regex pass to decode hex escape sequences."""
- def execute(self):
- # This works on the raw source before/after AST
- # But since we operate on AST, we decode hex in string literal raw values
+ def execute(self) -> bool:
+ # Decode hex/unicode escapes in string literal raw values (value already decoded by parser)
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'Literal' or not isinstance(node.get('value'), str):
return
raw_string = node.get('raw', '')
if '\\x' not in raw_string and '\\u' not in raw_string:
return
- # The value is already decoded by parser, just fix raw
value = node['value']
new_raw = (
'"'
@@ -38,7 +36,7 @@ def enter(node, parent, key, index):
return self.has_changed()
-def decode_hex_escapes_source(code):
+def decode_hex_escapes_source(code: str) -> str:
"""Decode hex escapes in source code string (pre-parse pass).
Only decodes hex escapes that produce printable characters (0x20-0x7e),
@@ -47,22 +45,21 @@ def decode_hex_escapes_source(code):
are left as \\xHH to avoid breaking the parser.
"""
- def replace_in_string(match_result):
+ def replace_in_string(match_result: re.Match) -> str:
quote = match_result.group(1)
content = match_result.group(2)
# Decode hex escapes, but skip backslash, both quote chars,
# and control chars. Quote chars are left for AST-level handling
# which normalizes to double quotes like Babel.
- def replace_hex_in_context(hex_match):
- value = int(hex_match.group(1), 16)
- if 0x20 <= value <= 0x7E and value not in (0x22, 0x27, 0x5C):
- return chr(value)
+ def replace_hex_in_context(hex_match: re.Match) -> str:
+ char_value = int(hex_match.group(1), 16)
+ if 0x20 <= char_value <= 0x7E and char_value not in (0x22, 0x27, 0x5C):
+ return chr(char_value)
return hex_match.group(0)
decoded = re.sub(r'\\x([0-9a-fA-F]{2})', replace_hex_in_context, content)
return quote + decoded + quote
# Match string literals and decode hex escapes within them
- result = re.sub(r"""(['"])((?:(?!\1|\\).|\\.)*?)\1""", replace_in_string, code)
- return result
+ return re.sub(r"""(['"])((?:(?!\1|\\).|\\.)*?)\1""", replace_in_string, code)
diff --git a/pyjsclear/transforms/hex_numerics.py b/pyjsclear/transforms/hex_numerics.py
index 2eb5256..97461b0 100644
--- a/pyjsclear/transforms/hex_numerics.py
+++ b/pyjsclear/transforms/hex_numerics.py
@@ -7,9 +7,9 @@
class HexNumerics(Transform):
"""Convert hex numeric literals to decimal representation."""
- def execute(self):
+ def execute(self) -> bool:
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'Literal':
return
value = node.get('value')
diff --git a/pyjsclear/transforms/jj_decode.py b/pyjsclear/transforms/jj_decode.py
index 99b3924..148f5b7 100644
--- a/pyjsclear/transforms/jj_decode.py
+++ b/pyjsclear/transforms/jj_decode.py
@@ -18,7 +18,7 @@
# ---------------------------------------------------------------------------
-def is_jj_encoded(code):
+def is_jj_encoded(code: str) -> bool:
"""Return True if *code* looks like JJEncoded JavaScript.
Checks for the ``VARNAME=~[]`` initialisation pattern that begins every
@@ -37,84 +37,91 @@ def is_jj_encoded(code):
_OBJECT_STR = '[object Object]'
+# Single-char JS escape sequences
+_SINGLE_CHAR_ESCAPES: dict[str, str] = {
+ 'n': '\n', 'r': '\r', 't': '\t',
+ '\\': '\\', "'": "'", '"': '"',
+ '/': '/', 'b': '\b', 'f': '\f',
+}
-def _split_at_depth_zero(text, delimiter):
+
+def _split_at_depth_zero(text: str, delimiter: str) -> list[str]:
"""Split *text* on *delimiter* only when bracket/paren depth is 0
and not inside a string literal. All logic is iterative."""
parts = []
current = []
depth = 0
- i = 0
+ index = 0
in_string = None
- while i < len(text):
- ch = text[i]
+ while index < len(text):
+ character = text[index]
if in_string is not None:
- current.append(ch)
- if ch == '\\' and i + 1 < len(text):
- i += 1
- current.append(text[i])
- elif ch == in_string:
+ current.append(character)
+ if character == '\\' and index + 1 < len(text):
+ index += 1
+ current.append(text[index])
+ elif character == in_string:
in_string = None
- i += 1
+ index += 1
continue
- if ch in ('"', "'"):
- in_string = ch
- current.append(ch)
- i += 1
+ if character in ('"', "'"):
+ in_string = character
+ current.append(character)
+ index += 1
continue
- if ch in ('(', '[', '{'):
+ if character in ('(', '[', '{'):
depth += 1
- current.append(ch)
- i += 1
+ current.append(character)
+ index += 1
continue
- if ch in (')', ']', '}'):
+ if character in (')', ']', '}'):
depth -= 1
- current.append(ch)
- i += 1
+ current.append(character)
+ index += 1
continue
- if depth == 0 and text[i:i + len(delimiter)] == delimiter:
+ if depth == 0 and text[index:index + len(delimiter)] == delimiter:
parts.append(''.join(current))
current = []
- i += len(delimiter)
+ index += len(delimiter)
continue
- current.append(ch)
- i += 1
+ current.append(character)
+ index += 1
parts.append(''.join(current))
return parts
-def _find_matching_close(text, start, open_ch, close_ch):
- """Return index of *close_ch* matching *open_ch* at *start*.
+def _find_matching_close(text: str, start: int, open_character: str, close_character: str) -> int:
+ """Return index of *close_character* matching *open_character* at *start*.
Iterative, respects strings."""
depth = 0
in_string = None
- i = start
- while i < len(text):
- ch = text[i]
+ index = start
+ while index < len(text):
+ character = text[index]
if in_string is not None:
- if ch == '\\' and i + 1 < len(text):
- i += 2
+ if character == '\\' and index + 1 < len(text):
+ index += 2
continue
- if ch == in_string:
+ if character == in_string:
in_string = None
- i += 1
+ index += 1
continue
- if ch in ('"', "'"):
- in_string = ch
- elif ch == open_ch:
+ if character in ('"', "'"):
+ in_string = character
+ elif character == open_character:
depth += 1
- elif ch == close_ch:
+ elif character == close_character:
depth -= 1
if depth == 0:
- return i
- i += 1
+ return index
+ index += 1
return -1
@@ -123,7 +130,7 @@ def _find_matching_close(text, start, open_ch, close_ch):
# ---------------------------------------------------------------------------
-def _parse_symbol_table(stmt, varname):
+def _parse_symbol_table(stmt: str, varname: str) -> dict[str, int | str] | None:
"""Parse the ``$={___:++$, ...}`` statement and return a dict
mapping property names to their resolved values (ints or chars)."""
prefix = varname + '='
@@ -137,7 +144,7 @@ def _parse_symbol_table(stmt, varname):
else:
return None
- table = {}
+ table: dict[str, int | str] = {}
counter = -1 # ~[] = -1
entries = _split_at_depth_zero(body, ',')
@@ -145,11 +152,11 @@ def _parse_symbol_table(stmt, varname):
entry = entry.strip()
if not entry:
continue
- colon_idx = entry.find(':')
- if colon_idx == -1:
+ colon_index = entry.find(':')
+ if colon_index == -1:
continue
- key = entry[:colon_idx].strip()
- value_expr = entry[colon_idx + 1:].strip()
+ key = entry[:colon_index].strip()
+ value_expr = entry[colon_index + 1:].strip()
if value_expr.startswith('++'):
counter += 1
@@ -163,16 +170,16 @@ def _parse_symbol_table(stmt, varname):
# Walk backwards to find matching [
depth = 0
bracket_start = -1
- j = bracket_end
- while j >= 0:
- if value_expr[j] == ']':
+ scan = bracket_end
+ while scan >= 0:
+ if value_expr[scan] == ']':
depth += 1
- elif value_expr[j] == '[':
+ elif value_expr[scan] == '[':
depth -= 1
if depth == 0:
- bracket_start = j
+ bracket_start = scan
break
- j -= 1
+ scan -= 1
if bracket_start <= 0:
continue
@@ -192,7 +199,7 @@ def _parse_symbol_table(stmt, varname):
return table
-def _eval_coercion(expr, varname):
+def _eval_coercion(expr: str, varname: str) -> str | None:
"""Evaluate a coercion expression to a string.
Handles: (![]+"") -> "false", (!""+"") -> "true",
@@ -212,7 +219,7 @@ def _eval_coercion(expr, varname):
if expr == '![]':
return 'false'
- if expr == '!""' or expr == "!''":
+ if expr in ('!""', "!''"):
return 'true'
if expr == '{}':
return _OBJECT_STR
@@ -220,7 +227,7 @@ def _eval_coercion(expr, varname):
if expr == varname + '[' + varname + ']':
return 'undefined'
# (!VARNAME) where VARNAME is object -> false
- if expr == '!' + varname or expr == '(!' + varname + ')':
+ if expr in ('!' + varname, '(!' + varname + ')'):
return 'false'
# General X[X] pattern
if re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\[\1\]$', expr):
@@ -239,14 +246,14 @@ def _eval_coercion(expr, varname):
_MAX_EVAL_DEPTH = 100
-def _eval_expr(expr, table, varname, _depth=0):
+def _eval_expr(expr: str, table: dict[str, int | str], varname: str, depth: int = 0) -> str | None:
"""Evaluate a JJEncode expression to a string value.
Handles symbol-table references, string literals, coercion
expressions with indexing, sub-assignments, and concatenation.
Returns the resolved string or None.
"""
- if _depth > _MAX_EVAL_DEPTH:
+ if depth > _MAX_EVAL_DEPTH:
return None
expr = expr.strip()
@@ -281,15 +288,15 @@ def _eval_expr(expr, table, varname, _depth=0):
else:
break
- val = _eval_inner(inner, table, varname, _depth + 1)
+ val = _eval_inner(inner, table, varname, depth + 1)
if not rest:
return val
# Check for [index] after the paren
if rest.startswith('[') and rest.endswith(']'):
if val is None:
return None
- idx_expr = rest[1:-1].strip()
- idx = _resolve_int(idx_expr, table, varname)
+ index_expr = rest[1:-1].strip()
+ idx = _resolve_int(index_expr, table, varname)
if isinstance(val, str) and idx is not None and 0 <= idx < len(val):
return val[idx]
return None
@@ -309,8 +316,8 @@ def _eval_expr(expr, table, varname, _depth=0):
key = expr[len(prefix):bracket_pos]
str_val = table.get(key)
if isinstance(str_val, str) and expr.endswith(']'):
- idx_expr = expr[bracket_pos + 1:-1]
- idx = _resolve_int(idx_expr, table, varname)
+ index_expr = expr[bracket_pos + 1:-1]
+ idx = _resolve_int(index_expr, table, varname)
if idx is not None and 0 <= idx < len(str_val):
return str_val[idx]
return None
@@ -326,20 +333,20 @@ def _eval_expr(expr, table, varname, _depth=0):
tokens = _split_at_depth_zero(expr, '+')
if len(tokens) > 1:
parts = []
- for t in tokens:
- v = _eval_expr(t, table, varname, _depth + 1)
- if v is None:
+ for token in tokens:
+ token_val = _eval_expr(token, table, varname, depth + 1)
+ if token_val is None:
return None
- parts.append(v)
+ parts.append(token_val)
return ''.join(parts)
return None
-def _eval_inner(inner, table, varname, _depth=0):
+def _eval_inner(inner: str, table: dict[str, int | str], varname: str, depth: int = 0) -> str | None:
"""Evaluate the inside of a parenthesised expression.
Handles sub-assignments and simple expressions."""
- if _depth > _MAX_EVAL_DEPTH:
+ if depth > _MAX_EVAL_DEPTH:
return None
prefix = varname + '.'
@@ -350,7 +357,7 @@ def _eval_inner(inner, table, varname, _depth=0):
if eq_pos is not None:
key = inner[len(prefix):eq_pos]
rhs = inner[eq_pos + 1:]
- val = _eval_expr(rhs, table, varname, _depth + 1)
+ val = _eval_expr(rhs, table, varname, depth + 1)
if val is not None:
table[key] = val
return val
@@ -361,41 +368,41 @@ def _eval_inner(inner, table, varname, _depth=0):
return coercion_str
# Just a nested expression
- return _eval_expr(inner, table, varname, _depth + 1)
+ return _eval_expr(inner, table, varname, depth + 1)
-def _find_top_level_eq(expr):
+def _find_top_level_eq(expr: str) -> int | None:
"""Find the position of the first ``=`` at depth 0 that is not ``==``."""
depth = 0
in_string = None
- i = 0
- while i < len(expr):
- ch = expr[i]
+ index = 0
+ while index < len(expr):
+ character = expr[index]
if in_string is not None:
- if ch == '\\' and i + 1 < len(expr):
- i += 2
+ if character == '\\' and index + 1 < len(expr):
+ index += 2
continue
- if ch == in_string:
+ if character == in_string:
in_string = None
- i += 1
+ index += 1
continue
- if ch in ('"', "'"):
- in_string = ch
- elif ch in ('(', '[', '{'):
+ if character in ('"', "'"):
+ in_string = character
+ elif character in ('(', '[', '{'):
depth += 1
- elif ch in (')', ']', '}'):
+ elif character in (')', ']', '}'):
depth -= 1
- elif ch == '=' and depth == 0:
+ elif character == '=' and depth == 0:
# Check not ==
- if i + 1 < len(expr) and expr[i + 1] == '=':
- i += 2
+ if index + 1 < len(expr) and expr[index + 1] == '=':
+ index += 2
continue
- return i
- i += 1
+ return index
+ index += 1
return None
-def _eval_coercion_indexed(expr, table, varname):
+def _eval_coercion_indexed(expr: str, table: dict[str, int | str], varname: str) -> str | None:
"""Handle ``(![]+"")[$._$_]`` — coercion string indexed by a
symbol table reference."""
if not expr.endswith(']'):
@@ -404,16 +411,16 @@ def _eval_coercion_indexed(expr, table, varname):
bracket_end = len(expr) - 1
depth = 0
bracket_start = -1
- j = bracket_end
- while j >= 0:
- if expr[j] == ']':
+ scan = bracket_end
+ while scan >= 0:
+ if expr[scan] == ']':
depth += 1
- elif expr[j] == '[':
+ elif expr[scan] == '[':
depth -= 1
if depth == 0:
- bracket_start = j
+ bracket_start = scan
break
- j -= 1
+ scan -= 1
if bracket_start <= 0:
return None
@@ -434,7 +441,7 @@ def _eval_coercion_indexed(expr, table, varname):
return ''
-def _resolve_int(expr, table, varname):
+def _resolve_int(expr: str, table: dict[str, int | str], varname: str) -> int | None:
"""Resolve an expression to an integer."""
expr = expr.strip()
prefix = varname + '.'
@@ -455,7 +462,7 @@ def _resolve_int(expr, table, varname):
# ---------------------------------------------------------------------------
-def _parse_augment_statement(stmt, table, varname):
+def _parse_augment_statement(stmt: str, table: dict[str, int | str], varname: str) -> None:
"""Parse statements that build multi-character strings like
"constructor" and "return" by concatenation, and store
intermediate single-char sub-assignments into the table."""
@@ -492,51 +499,51 @@ def _parse_augment_statement(stmt, table, varname):
# ---------------------------------------------------------------------------
-def _decode_js_string_literal(s):
+def _decode_js_string_literal(content: str) -> str:
"""Decode escapes in a JS string literal content (between quotes).
Only handles \\\\ -> \\, \\\" -> \", \\' -> ', and leaves everything
else (like \\1, \\x, \\u) as-is for later processing."""
result = []
- i = 0
- while i < len(s):
- if s[i] == '\\' and i + 1 < len(s):
- nch = s[i + 1]
- if nch in ('"', "'", '\\'):
- result.append(nch)
- i += 2
+ index = 0
+ while index < len(content):
+ if content[index] == '\\' and index + 1 < len(content):
+ next_character = content[index + 1]
+ if next_character in ('"', "'", '\\'):
+ result.append(next_character)
+ index += 2
continue
- result.append(s[i])
- i += 1
+ result.append(content[index])
+ index += 1
return ''.join(result)
-def _decode_escapes(s):
+def _decode_escapes(text: str) -> str:
"""Decode octal (\\NNN), hex (\\xNN), unicode (\\uNNNN) escape
sequences in a single left-to-right pass. Also handles standard
single-char escapes."""
result = []
- i = 0
- while i < len(s):
- if s[i] == '\\' and i + 1 < len(s):
- nch = s[i + 1]
+ index = 0
+ while index < len(text):
+ if text[index] == '\\' and index + 1 < len(text):
+ next_character = text[index + 1]
# Unicode escape \uNNNN
- if nch == 'u' and i + 5 < len(s):
- hex_str = s[i + 2:i + 6]
+ if next_character == 'u' and index + 5 < len(text):
+ hex_str = text[index + 2:index + 6]
try:
result.append(chr(int(hex_str, 16)))
- i += 6
+ index += 6
continue
except ValueError:
pass
# Hex escape \xNN
- if nch == 'x' and i + 3 < len(s):
- hex_str = s[i + 2:i + 4]
+ if next_character == 'x' and index + 3 < len(text):
+ hex_str = text[index + 2:index + 4]
try:
result.append(chr(int(hex_str, 16)))
- i += 4
+ index += 4
continue
except ValueError:
pass
@@ -544,35 +551,30 @@ def _decode_escapes(s):
# Octal escape: JS allows \0-\377 (max value 255).
# First digit 0-3: up to 3 total digits (\000-\377).
# First digit 4-7: up to 2 total digits (\40-\77).
- if '0' <= nch <= '7':
- max_digits = 3 if nch <= '3' else 2
+ if '0' <= next_character <= '7':
+ max_digits = 3 if next_character <= '3' else 2
octal = ''
- j = i + 1
- while j < len(s) and j < i + 1 + max_digits and '0' <= s[j] <= '7':
- octal += s[j]
- j += 1
+ scan = index + 1
+ while scan < len(text) and scan < index + 1 + max_digits and '0' <= text[scan] <= '7':
+ octal += text[scan]
+ scan += 1
result.append(chr(int(octal, 8)))
- i = j
+ index = scan
continue
# Standard single-char escapes
- _esc = {
- 'n': '\n', 'r': '\r', 't': '\t',
- '\\': '\\', "'": "'", '"': '"',
- '/': '/', 'b': '\b', 'f': '\f',
- }
- if nch in _esc:
- result.append(_esc[nch])
- i += 2
+ if next_character in _SINGLE_CHAR_ESCAPES:
+ result.append(_SINGLE_CHAR_ESCAPES[next_character])
+ index += 2
continue
# Unknown escape — keep literal
- result.append(nch)
- i += 2
+ result.append(next_character)
+ index += 2
continue
- result.append(s[i])
- i += 1
+ result.append(text[index])
+ index += 1
return ''.join(result)
@@ -582,7 +584,7 @@ def _decode_escapes(s):
# ---------------------------------------------------------------------------
-def _extract_payload_expression(stmt, varname):
+def _extract_payload_expression(stmt: str, varname: str) -> str | None:
"""Extract the inner concatenation expression from the payload
statement ``$.$($.$(EXPR)())()``."""
# Find VARNAME.$(VARNAME.$(
@@ -596,28 +598,28 @@ def _extract_payload_expression(stmt, varname):
# Find matching ) for the inner $.$(
depth = 1
in_string = None
- i = start
- while i < len(stmt):
- ch = stmt[i]
+ index = start
+ while index < len(stmt):
+ character = stmt[index]
if in_string is not None:
- if ch == '\\' and i + 1 < len(stmt):
- i += 2
+ if character == '\\' and index + 1 < len(stmt):
+ index += 2
continue
- if ch == in_string:
+ if character == in_string:
in_string = None
- i += 1
+ index += 1
continue
- if ch in ('"', "'"):
- in_string = ch
- i += 1
+ if character in ('"', "'"):
+ in_string = character
+ index += 1
continue
- if ch == '(':
+ if character == '(':
depth += 1
- elif ch == ')':
+ elif character == ')':
depth -= 1
if depth == 0:
- return stmt[start:i]
- i += 1
+ return stmt[start:index]
+ index += 1
return None
@@ -627,7 +629,7 @@ def _extract_payload_expression(stmt, varname):
# ---------------------------------------------------------------------------
-def jj_decode(code):
+def jj_decode(code: str) -> str | None:
"""Decode JJEncoded JavaScript. Returns the decoded string, or
``None`` on any failure."""
try:
@@ -637,16 +639,16 @@ def jj_decode(code):
return None
-def _jj_decode_inner(code):
+def _jj_decode_inner(code: str) -> str | None:
if not code or not code.strip():
return None
stripped = code.strip()
- m = re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\s*=\s*~\s*\[\s*\]', stripped)
- if not m:
+ match = re.match(r'^([a-zA-Z_$][a-zA-Z0-9_$]*)\s*=\s*~\s*\[\s*\]', stripped)
+ if not match:
return None
- varname = m.group(1)
+ varname = match.group(1)
# Find the JJEncode line
jj_line = None
@@ -661,7 +663,7 @@ def _jj_decode_inner(code):
# Split into semicolon-delimited statements at depth 0
stmts = _split_at_depth_zero(jj_line, ';')
- stmts = [s.strip() for s in stmts if s.strip()]
+ stmts = [statement.strip() for statement in stmts if statement.strip()]
if len(stmts) < 5:
return None
diff --git a/pyjsclear/transforms/jsfuck_decode.py b/pyjsclear/transforms/jsfuck_decode.py
index dac5f49..6dbe547 100644
--- a/pyjsclear/transforms/jsfuck_decode.py
+++ b/pyjsclear/transforms/jsfuck_decode.py
@@ -6,7 +6,20 @@
string passed to Function().
"""
-def is_jsfuck(code):
+from enum import StrEnum
+
+
+class _JSType(StrEnum):
+ ARRAY = 'array'
+ BOOL = 'bool'
+ NUMBER = 'number'
+ STRING = 'string'
+ UNDEFINED = 'undefined'
+ OBJECT = 'object'
+ FUNCTION = 'function'
+
+
+def is_jsfuck(code: str) -> bool:
"""Check if code is JSFuck-encoded.
JSFuck code consists only of []()!+ characters (with optional whitespace/semicolons).
@@ -18,7 +31,7 @@ def is_jsfuck(code):
# Only count the six JSFuck operator characters — whitespace and
# semicolons are not distinctive and inflate the ratio on minified JS.
jsfuck_chars = set('[]()!+')
- jsfuck_count = sum(1 for c in stripped if c in jsfuck_chars)
+ jsfuck_count = sum(1 for character in stripped if character in jsfuck_chars)
return jsfuck_count / len(stripped) > 0.95
@@ -32,45 +45,45 @@ class _JSValue:
__slots__ = ('val', 'type')
- def __init__(self, val, typ):
+ def __init__(self, val: object, js_type: _JSType | str) -> None:
self.val = val
- self.type = typ # 'array', 'bool', 'number', 'string', 'undefined', 'object', 'function'
+ self.type = js_type
# -- coercion helpers ---------------------------------------------------
- def to_number(self):
+ def to_number(self) -> int | float:
match self.type:
- case 'number':
+ case _JSType.NUMBER:
return self.val
- case 'bool':
+ case _JSType.BOOL:
return 1 if self.val else 0
- case 'string':
- s = self.val.strip()
- if s == '':
+ case _JSType.STRING:
+ stripped = self.val.strip()
+ if stripped == '':
return 0
try:
- return int(s)
+ return int(stripped)
except ValueError:
try:
- return float(s)
+ return float(stripped)
except ValueError:
return float('nan')
- case 'array':
+ case _JSType.ARRAY:
if len(self.val) == 0:
return 0
if len(self.val) == 1:
return _JSValue(self.val[0], _guess_type(self.val[0])).to_number()
return float('nan')
- case 'undefined':
+ case _JSType.UNDEFINED:
return float('nan')
case _:
return float('nan')
- def to_string(self):
+ def to_string(self) -> str:
match self.type:
- case 'string':
+ case _JSType.STRING:
return self.val
- case 'number':
+ case _JSType.NUMBER:
if isinstance(self.val, float):
if self.val != self.val: # NaN
return 'NaN'
@@ -82,9 +95,9 @@ def to_string(self):
return str(int(self.val))
return str(self.val)
return str(self.val)
- case 'bool':
+ case _JSType.BOOL:
return 'true' if self.val else 'false'
- case 'array':
+ case _JSType.ARRAY:
parts = []
for item in self.val:
if item is None:
@@ -94,139 +107,143 @@ def to_string(self):
else:
parts.append(_JSValue(item, _guess_type(item)).to_string())
return ','.join(parts)
- case 'undefined':
+ case _JSType.UNDEFINED:
return 'undefined'
- case 'object':
+ case _JSType.OBJECT:
return '[object Object]'
case _:
return str(self.val)
- def to_bool(self):
+ def to_bool(self) -> bool:
match self.type:
- case 'bool':
+ case _JSType.BOOL:
return self.val
- case 'number':
+ case _JSType.NUMBER:
return self.val != 0 and self.val == self.val # 0 and NaN are falsy
- case 'string':
+ case _JSType.STRING:
return len(self.val) > 0
- case 'array':
+ case _JSType.ARRAY:
return True # arrays are always truthy in JS
- case 'undefined':
+ case _JSType.UNDEFINED:
return False
- case 'object':
+ case _JSType.OBJECT:
return True
case _:
return bool(self.val)
- def get_property(self, key):
+ def get_property(self, key: '_JSValue') -> '_JSValue':
"""Property access: self[key]."""
- key_str = key.to_string() if isinstance(key, _JSValue) else str(key)
-
- if self.type == 'string':
- # String indexing
- try:
- idx = int(key_str)
- if 0 <= idx < len(self.val):
- return _JSValue(self.val[idx], 'string')
- except (ValueError, IndexError):
- pass
- # String properties
- if key_str == 'length':
- return _JSValue(len(self.val), 'number')
- if key_str == 'constructor':
- return _STRING_CONSTRUCTOR
- # String.prototype methods
- return _get_string_method(self, key_str)
-
- if self.type == 'array':
- try:
- idx = int(key_str)
- if 0 <= idx < len(self.val):
- item = self.val[idx]
- if isinstance(item, _JSValue):
- return item
- return _JSValue(item, _guess_type(item))
- except (ValueError, IndexError):
- pass
- if key_str == 'length':
- return _JSValue(len(self.val), 'number')
- if key_str == 'constructor':
- return _ARRAY_CONSTRUCTOR
- # Array methods that JSFuck commonly accesses
- if key_str in (
- 'flat',
- 'fill',
- 'find',
- 'filter',
- 'entries',
- 'concat',
- 'join',
- 'sort',
- 'reverse',
- 'slice',
- 'map',
- 'forEach',
- 'reduce',
- 'some',
- 'every',
- 'indexOf',
- 'includes',
- 'keys',
- 'values',
- 'at',
- 'pop',
- 'push',
- 'shift',
- 'unshift',
- 'splice',
- 'toString',
- 'valueOf',
- ):
- return _JSValue(key_str, 'function')
-
- if self.type == 'number':
- if key_str == 'constructor':
- return _NUMBER_CONSTRUCTOR
- if key_str == 'toString':
- return _JSValue('toString', 'function')
- return _JSValue(None, 'undefined')
-
- if self.type == 'bool':
- if key_str == 'constructor':
- return _BOOLEAN_CONSTRUCTOR
- return _JSValue(None, 'undefined')
-
- if self.type == 'function':
- if key_str == 'constructor':
- return _FUNCTION_CONSTRUCTOR
- return _JSValue(None, 'undefined')
-
- if self.type == 'object':
- if key_str == 'constructor':
- return _OBJECT_CONSTRUCTOR
- return _JSValue(None, 'undefined')
-
- return _JSValue(None, 'undefined')
-
- def __repr__(self):
+ key_string = key.to_string() if isinstance(key, _JSValue) else str(key)
+
+ match self.type:
+ case _JSType.STRING:
+ return _get_string_property(self, key_string)
+ case _JSType.ARRAY:
+ return _get_array_property(self, key_string)
+ case _JSType.NUMBER:
+ if key_string == 'constructor':
+ return _NUMBER_CONSTRUCTOR
+ if key_string == 'toString':
+ return _JSValue('toString', _JSType.FUNCTION)
+ return _JSValue(None, _JSType.UNDEFINED)
+ case _JSType.BOOL:
+ if key_string == 'constructor':
+ return _BOOLEAN_CONSTRUCTOR
+ return _JSValue(None, _JSType.UNDEFINED)
+ case _JSType.FUNCTION:
+ if key_string == 'constructor':
+ return _FUNCTION_CONSTRUCTOR
+ return _JSValue(None, _JSType.UNDEFINED)
+ case _JSType.OBJECT:
+ if key_string == 'constructor':
+ return _OBJECT_CONSTRUCTOR
+ return _JSValue(None, _JSType.UNDEFINED)
+ case _:
+ return _JSValue(None, _JSType.UNDEFINED)
+
+ def __repr__(self) -> str:
return f'_JSValue({self.val!r}, {self.type!r})'
-def _guess_type(val):
- if isinstance(val, bool):
- return 'bool'
- if isinstance(val, (int, float)):
- return 'number'
- if isinstance(val, str):
- return 'string'
- if isinstance(val, list):
- return 'array'
- if val is None:
- return 'undefined'
- return 'object'
+def _get_string_property(string_value: '_JSValue', key_string: str) -> '_JSValue':
+ """Return the result of property access on a string value."""
+ try:
+ index = int(key_string)
+ if 0 <= index < len(string_value.val):
+ return _JSValue(string_value.val[index], _JSType.STRING)
+ except (ValueError, IndexError):
+ pass
+ if key_string == 'length':
+ return _JSValue(len(string_value.val), _JSType.NUMBER)
+ if key_string == 'constructor':
+ return _STRING_CONSTRUCTOR
+ return _get_string_method(key_string)
+
+
+def _get_array_property(array_value: '_JSValue', key_string: str) -> '_JSValue':
+ """Return the result of property access on an array value."""
+ try:
+ index = int(key_string)
+ if 0 <= index < len(array_value.val):
+ item = array_value.val[index]
+ if isinstance(item, _JSValue):
+ return item
+ return _JSValue(item, _guess_type(item))
+ except (ValueError, IndexError):
+ pass
+ if key_string == 'length':
+ return _JSValue(len(array_value.val), _JSType.NUMBER)
+ if key_string == 'constructor':
+ return _ARRAY_CONSTRUCTOR
+ # Array methods that JSFuck commonly accesses
+ if key_string in (
+ 'flat',
+ 'fill',
+ 'find',
+ 'filter',
+ 'entries',
+ 'concat',
+ 'join',
+ 'sort',
+ 'reverse',
+ 'slice',
+ 'map',
+ 'forEach',
+ 'reduce',
+ 'some',
+ 'every',
+ 'indexOf',
+ 'includes',
+ 'keys',
+ 'values',
+ 'at',
+ 'pop',
+ 'push',
+ 'shift',
+ 'unshift',
+ 'splice',
+ 'toString',
+ 'valueOf',
+ ):
+ return _JSValue(key_string, _JSType.FUNCTION)
+ return _JSValue(None, _JSType.UNDEFINED)
+
+
+def _guess_type(value: object) -> _JSType:
+ if isinstance(value, bool):
+ return _JSType.BOOL
+ if isinstance(value, (int, float)):
+ return _JSType.NUMBER
+ if isinstance(value, str):
+ return _JSType.STRING
+ if isinstance(value, list):
+ return _JSType.ARRAY
+ if value is None:
+ return _JSType.UNDEFINED
+ return _JSType.OBJECT
-def _get_string_method(string_val, method_name):
+def _get_string_method(method_name: str) -> '_JSValue':
"""Return a callable _JSValue wrapping a string method."""
if method_name in (
'italics',
@@ -265,17 +282,17 @@ def _get_string_method(string_val, method_name):
'normalize',
'flat',
):
- return _JSValue(method_name, 'function')
- return _JSValue(None, 'undefined')
+ return _JSValue(method_name, _JSType.FUNCTION)
+ return _JSValue(None, _JSType.UNDEFINED)
# Sentinel constructors for property chain resolution
-_STRING_CONSTRUCTOR = _JSValue('String', 'function')
-_NUMBER_CONSTRUCTOR = _JSValue('Number', 'function')
-_BOOLEAN_CONSTRUCTOR = _JSValue('Boolean', 'function')
-_ARRAY_CONSTRUCTOR = _JSValue('Array', 'function')
-_OBJECT_CONSTRUCTOR = _JSValue('Object', 'function')
-_FUNCTION_CONSTRUCTOR = _JSValue('Function', 'function')
+_STRING_CONSTRUCTOR = _JSValue('String', _JSType.FUNCTION)
+_NUMBER_CONSTRUCTOR = _JSValue('Number', _JSType.FUNCTION)
+_BOOLEAN_CONSTRUCTOR = _JSValue('Boolean', _JSType.FUNCTION)
+_ARRAY_CONSTRUCTOR = _JSValue('Array', _JSType.FUNCTION)
+_OBJECT_CONSTRUCTOR = _JSValue('Object', _JSType.FUNCTION)
+_FUNCTION_CONSTRUCTOR = _JSValue('Function', _JSType.FUNCTION)
# Known constructor-of-constructor chain results
_CONSTRUCTOR_MAP = {
@@ -293,12 +310,12 @@ def _get_string_method(string_val, method_name):
# ---------------------------------------------------------------------------
-def _tokenize(code):
+def _tokenize(code: str) -> list[str]:
"""Tokenize JSFuck code into a list of characters/tokens."""
tokens = []
- for ch in code:
- if ch in '[]()!+':
- tokens.append(ch)
+ for character in code:
+ if character in '[]()!+':
+ tokens.append(character)
# Skip whitespace, semicolons
return tokens
@@ -334,166 +351,166 @@ class _Parser:
arbitrarily deep nesting never overflows the Python call stack.
"""
- def __init__(self, tokens):
+ def __init__(self, tokens: list[str]) -> None:
self.tokens = tokens
self.pos = 0
- self.captured = None # Result from Function(body)()
+ self.captured: str | None = None # Body string of a Function(body) function, recorded when it is called
- def peek(self):
+ def peek(self) -> str | None:
if self.pos < len(self.tokens):
return self.tokens[self.pos]
return None
- def consume(self, expected=None):
+ def consume(self, expected: str | None = None) -> str:
if self.pos >= len(self.tokens):
raise _ParseError('Unexpected end of input')
- tok = self.tokens[self.pos]
- if expected is not None and tok != expected:
- raise _ParseError(f'Expected {expected!r}, got {tok!r}')
+ token = self.tokens[self.pos]
+ if expected is not None and token != expected:
+ raise _ParseError(f'Expected {expected!r}, got {token!r}')
self.pos += 1
- return tok
+ return token
# ------------------------------------------------------------------
- def parse(self):
+ def parse(self) -> _JSValue:
"""Parse and evaluate the full expression (iterative)."""
- val_stack = []
- cont = [(_K_DONE,)]
+ value_stack: list[_JSValue] = []
+ continuation: list[tuple] = [(_K_DONE,)]
state = _S_EXPR
while True:
if state == _S_EXPR:
# expression = unary ('+' unary)*
- cont.append((_K_EXPR_LOOP,))
+ continuation.append((_K_EXPR_LOOP,))
state = _S_UNARY
elif state == _S_UNARY:
# Collect prefix operators, then parse postfix
- ops = []
+ operators = []
while self.peek() in ('!', '+'):
- ops.append(self.consume())
- cont.append((_K_UNARY_APPLY, ops))
+ operators.append(self.consume())
+ continuation.append((_K_UNARY_APPLY, operators))
state = _S_POSTFIX
elif state == _S_POSTFIX:
# Parse primary, then handle postfix [ ] and ( )
- cont.append((_K_POSTFIX_LOOP, None)) # receiver=None
+ continuation.append((_K_POSTFIX_LOOP, None)) # receiver=None
state = _S_PRIMARY
elif state == _S_PRIMARY:
- tok = self.peek()
- if tok == '(':
+ token = self.peek()
+ if token == '(':
self.consume('(')
- cont.append((_K_PAREN_CLOSE,))
+ continuation.append((_K_PAREN_CLOSE,))
state = _S_EXPR
- elif tok == '[':
+ elif token == '[':
self.consume('[')
if self.peek() == ']':
self.consume(']')
- val_stack.append(_JSValue([], 'array'))
+ value_stack.append(_JSValue([], _JSType.ARRAY))
state = _S_RESUME
else:
- cont.append((_K_ARRAY_ELEM, []))
+ continuation.append((_K_ARRAY_ELEM, []))
state = _S_EXPR
else:
raise _ParseError(
- f'Unexpected token: {tok!r} at pos {self.pos}')
+ f'Unexpected token: {token!r} at pos {self.pos}')
elif state == _S_RESUME:
- k = cont.pop()
- ktype = k[0]
+ continuation_frame = continuation.pop()
+ continuation_type = continuation_frame[0]
- if ktype == _K_DONE:
- return val_stack.pop()
+ if continuation_type == _K_DONE:
+ return value_stack.pop()
- elif ktype == _K_PAREN_CLOSE:
+ elif continuation_type == _K_PAREN_CLOSE:
self.consume(')')
state = _S_RESUME
- elif ktype == _K_ARRAY_ELEM:
- elements = k[1]
- elements.append(val_stack.pop())
+ elif continuation_type == _K_ARRAY_ELEM:
+ elements = continuation_frame[1]
+ elements.append(value_stack.pop())
if self.peek() not in (']', None):
- cont.append((_K_ARRAY_ELEM, elements))
+ continuation.append((_K_ARRAY_ELEM, elements))
state = _S_EXPR
else:
self.consume(']')
- val_stack.append(_JSValue(elements, 'array'))
+ value_stack.append(_JSValue(elements, _JSType.ARRAY))
state = _S_RESUME
- elif ktype == _K_POSTFIX_LOOP:
- receiver = k[1]
- val = val_stack[-1]
+ elif continuation_type == _K_POSTFIX_LOOP:
+ receiver = continuation_frame[1]
+ current_value = value_stack[-1]
if self.peek() == '[':
self.consume('[')
- val_stack.pop()
- cont.append((_K_POSTFIX_BRACKET, val))
+ value_stack.pop()
+ continuation.append((_K_POSTFIX_BRACKET, current_value))
state = _S_EXPR
elif self.peek() == '(':
self.consume('(')
if self.peek() == ')':
self.consume(')')
- val_stack.pop()
- result = self._call(val, [], receiver)
- val_stack.append(result)
- cont.append((_K_POSTFIX_LOOP, None))
+ value_stack.pop()
+ result = self._call(current_value, [], receiver)
+ value_stack.append(result)
+ continuation.append((_K_POSTFIX_LOOP, None))
state = _S_RESUME
else:
- val_stack.pop()
- cont.append((_K_POSTFIX_ARGDONE, val, receiver))
+ value_stack.pop()
+ continuation.append((_K_POSTFIX_ARGDONE, current_value, receiver))
state = _S_EXPR
else:
# No more postfix ops
state = _S_RESUME
- elif ktype == _K_POSTFIX_BRACKET:
- parent_val = k[1]
- key = val_stack.pop()
+ elif continuation_type == _K_POSTFIX_BRACKET:
+ parent_value = continuation_frame[1]
+ key = value_stack.pop()
self.consume(']')
- val_stack.append(parent_val.get_property(key))
- cont.append((_K_POSTFIX_LOOP, parent_val))
+ value_stack.append(parent_value.get_property(key))
+ continuation.append((_K_POSTFIX_LOOP, parent_value))
state = _S_RESUME
- elif ktype == _K_POSTFIX_ARGDONE:
- func = k[1]
- receiver = k[2]
- arg = val_stack.pop()
+ elif continuation_type == _K_POSTFIX_ARGDONE:
+ func = continuation_frame[1]
+ receiver = continuation_frame[2]
+ argument = value_stack.pop()
self.consume(')')
- result = self._call(func, [arg], receiver)
- val_stack.append(result)
- cont.append((_K_POSTFIX_LOOP, None))
+ result = self._call(func, [argument], receiver)
+ value_stack.append(result)
+ continuation.append((_K_POSTFIX_LOOP, None))
state = _S_RESUME
- elif ktype == _K_UNARY_APPLY:
- ops = k[1]
- val = val_stack.pop()
- for op in reversed(ops):
- if op == '!':
- val = _JSValue(not val.to_bool(), 'bool')
- elif op == '+':
- val = _JSValue(val.to_number(), 'number')
- val_stack.append(val)
+ elif continuation_type == _K_UNARY_APPLY:
+ operators = continuation_frame[1]
+ current_value = value_stack.pop()
+ for operator in reversed(operators):
+ if operator == '!':
+ current_value = _JSValue(not current_value.to_bool(), _JSType.BOOL)
+ elif operator == '+':
+ current_value = _JSValue(current_value.to_number(), _JSType.NUMBER)
+ value_stack.append(current_value)
state = _S_RESUME
- elif ktype == _K_EXPR_LOOP:
+ elif continuation_type == _K_EXPR_LOOP:
if self.peek() == '+':
self.consume('+')
- left = val_stack.pop()
- cont.append((_K_EXPR_ADD, left))
+ left = value_stack.pop()
+ continuation.append((_K_EXPR_ADD, left))
state = _S_UNARY
else:
state = _S_RESUME
- elif ktype == _K_EXPR_ADD:
- left = k[1]
- right = val_stack.pop()
- val_stack.append(_js_add(left, right))
- cont.append((_K_EXPR_LOOP,))
+ elif continuation_type == _K_EXPR_ADD:
+ left = continuation_frame[1]
+ right = value_stack.pop()
+ value_stack.append(_js_add(left, right))
+ continuation.append((_K_EXPR_LOOP,))
state = _S_RESUME
# ------------------------------------------------------------------
- def _call(self, func, args, receiver=None):
+ def _call(self, func: _JSValue, args: list[_JSValue], receiver: _JSValue | None = None) -> _JSValue:
"""Handle function call semantics.
Only single-argument calls are supported (e.g. Function(body),
@@ -501,53 +518,53 @@ def _call(self, func, args, receiver=None):
emits multi-argument calls.
"""
# Function constructor: Function(body) returns a new function
- if func.type == 'function' and func.val == 'Function':
+ if func.type == _JSType.FUNCTION and func.val == 'Function':
if args:
body = args[-1].to_string()
- return _JSValue(('__function_body__', body), 'function')
+ return _JSValue(('__function_body__', body), _JSType.FUNCTION)
# Calling a function created by Function(body)
- if func.type == 'function' and isinstance(func.val, tuple):
+ if func.type == _JSType.FUNCTION and isinstance(func.val, tuple):
if func.val[0] == '__function_body__':
self.captured = func.val[1]
- return _JSValue(None, 'undefined')
+ return _JSValue(None, _JSType.UNDEFINED)
# Constructor property access — e.g., []["flat"]["constructor"]
- if func.type == 'function' and isinstance(func.val, str):
+ if func.type == _JSType.FUNCTION and isinstance(func.val, str):
name = func.val
if name in _CONSTRUCTOR_MAP:
if args:
- return _JSValue(args[0].to_string(), 'string')
- return _JSValue('', 'string')
+ return _JSValue(args[0].to_string(), _JSType.STRING)
+ return _JSValue('', _JSType.STRING)
if name == 'italics':
- return _JSValue('', 'string')
+ return _JSValue('', _JSType.STRING)
if name == 'fontcolor':
- return _JSValue('', 'string')
+ return _JSValue('', _JSType.STRING)
# toString with radix — e.g., (10)["toString"](36) → "a"
if name == 'toString' and args and receiver is not None:
radix = args[0].to_number()
if isinstance(radix, (int, float)) and radix == int(radix):
radix = int(radix)
- if 2 <= radix <= 36 and receiver.type == 'number':
+ if 2 <= radix <= 36 and receiver.type == _JSType.NUMBER:
num = receiver.to_number()
if isinstance(num, (int, float)) and num == int(num):
- return _JSValue(_int_to_base(int(num), radix), 'string')
+ return _JSValue(_int_to_base(int(num), radix), _JSType.STRING)
- return _JSValue(None, 'undefined')
+ return _JSValue(None, _JSType.UNDEFINED)
-def _js_add(left, right):
+def _js_add(left: _JSValue, right: _JSValue) -> _JSValue:
"""JS + operator with type coercion."""
- if left.type == 'string' or right.type == 'string':
- return _JSValue(left.to_string() + right.to_string(), 'string')
- if left.type in ('array', 'object') or right.type in ('array', 'object'):
- return _JSValue(left.to_string() + right.to_string(), 'string')
- return _JSValue(left.to_number() + right.to_number(), 'number')
+ if left.type == _JSType.STRING or right.type == _JSType.STRING:
+ return _JSValue(left.to_string() + right.to_string(), _JSType.STRING)
+ if left.type in (_JSType.ARRAY, _JSType.OBJECT) or right.type in (_JSType.ARRAY, _JSType.OBJECT):
+ return _JSValue(left.to_string() + right.to_string(), _JSType.STRING)
+ return _JSValue(left.to_number() + right.to_number(), _JSType.NUMBER)
-def _int_to_base(num, base):
+def _int_to_base(num: int, base: int) -> str:
"""Convert integer to string in given base (2-36), matching JS behavior."""
if num == 0:
return '0'
@@ -572,7 +589,7 @@ class _ParseError(Exception):
# ---------------------------------------------------------------------------
-def jsfuck_decode(code):
+def jsfuck_decode(code: str) -> str | None:
"""Decode JSFuck-encoded JavaScript. Returns decoded string or None."""
if not code or not code.strip():
return None
diff --git a/pyjsclear/transforms/logical_to_if.py b/pyjsclear/transforms/logical_to_if.py
index df895f3..b471470 100644
--- a/pyjsclear/transforms/logical_to_if.py
+++ b/pyjsclear/transforms/logical_to_if.py
@@ -13,65 +13,65 @@
from .base import Transform
-def _negate(expr):
+def _negate(expression: dict) -> dict:
"""Wrap an expression in a logical NOT."""
return {
'type': 'UnaryExpression',
'operator': '!',
'prefix': True,
- 'argument': expr,
+ 'argument': expression,
}
class LogicalToIf(Transform):
"""Convert logical/comma expressions in statement position to if-statements."""
- def execute(self):
+ def execute(self) -> bool:
self._transform_bodies(self.ast)
return self.has_changed()
- def _transform_bodies(self, node):
+ def _transform_bodies(self, node: dict | list | object) -> None:
"""Walk all statement arrays and apply transforms."""
if not isinstance(node, dict):
return
for key, child in node.items():
if isinstance(child, list):
if child and isinstance(child[0], dict) and 'type' in child[0]:
- self._process_stmt_array(child)
+ self._process_statement_array(child)
for item in child:
self._transform_bodies(item)
elif isinstance(child, dict) and 'type' in child:
self._transform_bodies(child)
- def _process_stmt_array(self, stmts):
- i = 0
- while i < len(stmts):
- stmt = stmts[i]
- if not isinstance(stmt, dict):
- i += 1
+ def _process_statement_array(self, statements: list) -> None:
+ index = 0
+ while index < len(statements):
+ statement = statements[index]
+ if not isinstance(statement, dict):
+ index += 1
continue
- replacement = self._try_convert_stmt(stmt)
+ replacement = self._try_convert_stmt(statement)
if replacement is not None:
- stmts[i : i + 1] = replacement
+ statements[index : index + 1] = replacement
self.set_changed()
- i += len(replacement)
+ index += len(replacement)
continue
- i += 1
+ index += 1
- def _try_convert_stmt(self, stmt):
+ def _try_convert_stmt(self, statement: dict) -> list | None:
"""Try to convert a statement. Returns replacement list or None."""
- match stmt.get('type'):
+ match statement.get('type'):
case 'ExpressionStatement':
- return self._handle_expression_stmt(stmt)
+ return self._handle_expression_stmt(statement)
case 'ReturnStatement':
- return self._handle_return_stmt(stmt)
+ return self._handle_return_stmt(statement)
return None
- def _handle_expression_stmt(self, stmt):
+ def _handle_expression_stmt(self, statement: dict) -> list | None:
"""Handle ExpressionStatement with logical or conditional."""
- expression = stmt.get('expression')
+ expression = statement.get('expression')
if not isinstance(expression, dict):
return None
match expression.get('type'):
@@ -81,9 +81,9 @@ def _handle_expression_stmt(self, stmt):
return self._ternary_to_if(expression)
return None
- def _handle_return_stmt(self, stmt):
+ def _handle_return_stmt(self, statement: dict) -> list | None:
"""Handle ReturnStatement with sequence or logical expressions."""
- argument = stmt.get('argument')
+ argument = statement.get('argument')
if not isinstance(argument, dict):
return None
@@ -97,49 +97,49 @@ def _handle_return_stmt(self, stmt):
return None
- def _split_return_sequence(self, seq):
+ def _split_return_sequence(self, sequence: dict) -> list | None:
"""Split return (a, b, c) into a; b; return c."""
- exprs = seq.get('expressions', [])
- if len(exprs) <= 1:
+ expressions = sequence.get('expressions', [])
+ if len(expressions) <= 1:
return None
- new_stmts = []
- for expression in exprs[:-1]:
+ new_statements = []
+ for expression in expressions[:-1]:
if isinstance(expression, dict) and expression.get('type') == 'LogicalExpression':
converted = self._logical_to_if(expression)
if converted:
- new_stmts.extend(converted)
+ new_statements.extend(converted)
continue
- new_stmts.append(make_expression_statement(expression))
- new_stmts.append({'type': 'ReturnStatement', 'argument': exprs[-1]})
- return new_stmts
+ new_statements.append(make_expression_statement(expression))
+ new_statements.append({'type': 'ReturnStatement', 'argument': expressions[-1]})
+ return new_statements
- def _split_return_logical(self, logical):
+ def _split_return_logical(self, logical: dict) -> list | None:
"""Split return a || (b(), c) into if (!a) { b(); } return c."""
right = logical.get('right')
if not (isinstance(right, dict) and right.get('type') == 'SequenceExpression'):
return None
- exprs = right.get('expressions', [])
- if len(exprs) <= 1:
+ expressions = right.get('expressions', [])
+ if len(expressions) <= 1:
return None
test = logical.get('left')
if logical.get('operator') == '||':
test = _negate(test)
- body_stmts = [make_expression_statement(e) for e in exprs[:-1]]
- if_stmt = {
+ body_statements = [make_expression_statement(expression) for expression in expressions[:-1]]
+ if_statement = {
'type': 'IfStatement',
'test': test,
- 'consequent': make_block_statement(body_stmts),
+ 'consequent': make_block_statement(body_statements),
'alternate': None,
}
- ret = {'type': 'ReturnStatement', 'argument': exprs[-1]}
- return [if_stmt, ret]
+ return_statement = {'type': 'ReturnStatement', 'argument': expressions[-1]}
+ return [if_statement, return_statement]
- def _logical_to_if(self, expr):
+ def _logical_to_if(self, expression: dict) -> list | None:
"""Convert a LogicalExpression to if-statement(s). Returns list of stmts or None."""
- left = expr.get('left')
- match expr.get('operator'):
+ left = expression.get('left')
+ match expression.get('operator'):
case '&&':
test = left
case '||':
@@ -147,27 +147,27 @@ def _logical_to_if(self, expr):
case _:
return None
- body_stmts = self._expr_to_stmts(expr.get('right'))
- if_stmt = {
+ body_statements = self._expr_to_stmts(expression.get('right'))
+ if_statement = {
'type': 'IfStatement',
'test': test,
- 'consequent': make_block_statement(body_stmts),
+ 'consequent': make_block_statement(body_statements),
'alternate': None,
}
- return [if_stmt]
+ return [if_statement]
- def _ternary_to_if(self, expr):
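+ def _ternary_to_if(self, expression: dict) -> list:
- """Convert a ConditionalExpression to if-else. Returns list of stmts or None."""
+ """Convert a ConditionalExpression to if-else. Returns a one-element list holding the IfStatement."""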
- if_stmt = {
+ if_statement = {
'type': 'IfStatement',
- 'test': expr.get('test'),
- 'consequent': make_block_statement(self._expr_to_stmts(expr.get('consequent'))),
- 'alternate': make_block_statement(self._expr_to_stmts(expr.get('alternate'))),
+ 'test': expression.get('test'),
+ 'consequent': make_block_statement(self._expr_to_stmts(expression.get('consequent'))),
+ 'alternate': make_block_statement(self._expr_to_stmts(expression.get('alternate'))),
}
- return [if_stmt]
+ return [if_statement]
- def _expr_to_stmts(self, expr):
+ def _expr_to_stmts(self, expression: dict | None) -> list:
"""Convert an expression to a list of statements."""
- if isinstance(expr, dict) and expr.get('type') == 'SequenceExpression':
- return [make_expression_statement(e) for e in expr.get('expressions', [])]
- return [make_expression_statement(expr)]
+ if isinstance(expression, dict) and expression.get('type') == 'SequenceExpression':
+ return [make_expression_statement(e) for e in expression.get('expressions', [])]
+ return [make_expression_statement(expression)]
diff --git a/pyjsclear/transforms/member_chain_resolver.py b/pyjsclear/transforms/member_chain_resolver.py
index 206cd9f..dd610c4 100644
--- a/pyjsclear/transforms/member_chain_resolver.py
+++ b/pyjsclear/transforms/member_chain_resolver.py
@@ -22,7 +22,7 @@
from .base import Transform
-def _is_constant_expr(node):
+def _is_constant_expr(node: dict) -> bool:
"""Check if a node is a constant expression safe to inline."""
if not isinstance(node, dict):
return False
@@ -36,32 +36,46 @@ def _is_constant_expr(node):
return False
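+def _get_property_name(member_expr: dict, property_key: str) -> str | None:
+ """Return the static name stored under ``property_key``: a string literal's value for computed access, an identifier's name otherwise, or None if neither applies."""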
+ prop = member_expr.get(property_key)
+ if not prop:
+ return None
+ if member_expr.get('computed'):
+ if not is_string_literal(prop):
+ return None
+ return prop['value']
+ if is_identifier(prop):
+ return prop['name']
+ return None
+
+
class MemberChainResolver(Transform):
"""Resolve multi-level member chains (A.B.C) to literal values."""
- def execute(self):
- # Maps: (class_name, prop_name) → AST node (constant expression)
- class_constants = {}
- # Maps: prop_name → class_name (from X.prop = ClassIdentifier assignments)
- prop_to_class = {}
+ def execute(self) -> bool:
+ # Maps: (class_name, property_name) → AST node (constant expression)
+ class_constants: dict[tuple[str, str], dict] = {}
+ # Maps: property_name → class_name (from X.prop = ClassIdentifier assignments)
+ property_to_class: dict[str, str] = {}
# Phase 1: Collect X.prop = constant_expr and X.prop = Identifier assignments
- def collect(node, parent):
+ def collect(node: dict, parent: dict | None) -> None:
if node.get('type') != 'AssignmentExpression':
return
if node.get('operator') != '=':
return
left = node.get('left')
right = node.get('right')
- obj_name, prop_name = get_member_names(left)
- if not obj_name:
+ object_name, property_name = get_member_names(left)
+ if not object_name:
return
if _is_constant_expr(right):
- class_constants[(obj_name, prop_name)] = right
+ class_constants[(object_name, property_name)] = right
elif is_identifier(right):
- # X.prop = SomeClass — record prop_name → SomeClass
- prop_to_class[prop_name] = right['name']
+ # X.prop = SomeClass — record property_name → SomeClass
+ property_to_class[property_name] = right['name']
simple_traverse(self.ast, collect)
@@ -69,9 +83,9 @@ def collect(node, parent):
return False
# Phase 1b: Invalidate constants that are reassigned through alias chains.
- # Pattern: A.B.C = expr where B → class_name via prop_to_class
+ # Pattern: A.B.C = expr where B → class_name via property_to_class
# means (class_name, C) is NOT a true constant.
- def invalidate_chain_assignments(node, parent):
+ def invalidate_chain_assignments(node: dict, parent: dict | None) -> None:
if node.get('type') != 'AssignmentExpression':
return
left = node.get('left')
@@ -80,34 +94,16 @@ def invalidate_chain_assignments(node, parent):
inner = left.get('object')
if not inner or inner.get('type') != 'MemberExpression':
return
- # Get C (outer property)
- outer_prop = left.get('property')
- if not outer_prop:
- return
- if left.get('computed'):
- if not is_string_literal(outer_prop):
- return
- c_name = outer_prop['value']
- elif is_identifier(outer_prop):
- c_name = outer_prop['name']
- else:
- return
- # Get B (middle property)
- inner_prop = inner.get('property')
- if not inner_prop:
+ outer_property_name = _get_property_name(left, 'property')
+ if outer_property_name is None:
return
- if inner.get('computed'):
- if not is_string_literal(inner_prop):
- return
- b_name = inner_prop['value']
- elif is_identifier(inner_prop):
- b_name = inner_prop['name']
- else:
+ middle_property_name = _get_property_name(inner, 'property')
+ if middle_property_name is None:
return
# If B resolves to a class, invalidate (class, C)
- class_name = prop_to_class.get(b_name)
- if class_name and (class_name, c_name) in class_constants:
- del class_constants[(class_name, c_name)]
+ class_name = property_to_class.get(middle_property_name)
+ if class_name and (class_name, outer_property_name) in class_constants:
+ del class_constants[(class_name, outer_property_name)]
simple_traverse(self.ast, invalidate_chain_assignments)
@@ -116,56 +112,38 @@ def invalidate_chain_assignments(node, parent):
# Phase 2: Replace A.B.C member chains where B resolves to a class
# and (class, C) maps to a constant expression
- def resolve(node, parent, key, index):
+ def resolve(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
if node.get('type') != 'MemberExpression':
- return
+ return None
# Skip assignment targets
if parent and parent.get('type') == 'AssignmentExpression' and key == 'left':
- return
+ return None
- # Get C (the outer property)
- outer_prop = node.get('property')
- if not outer_prop:
- return
- if node.get('computed'):
- if not is_string_literal(outer_prop):
- return
- c_name = outer_prop['value']
- elif is_identifier(outer_prop):
- c_name = outer_prop['name']
- else:
- return
+ outer_property_name = _get_property_name(node, 'property')
+ if outer_property_name is None:
+ return None
# Get the inner member expression (A.B)
inner = node.get('object')
if not inner or inner.get('type') != 'MemberExpression':
- return
+ return None
- # Get B (the middle property)
- inner_prop = inner.get('property')
- if not inner_prop:
- return
- if inner.get('computed'):
- if not is_string_literal(inner_prop):
- return
- b_name = inner_prop['value']
- elif is_identifier(inner_prop):
- b_name = inner_prop['name']
- else:
- return
+ middle_property_name = _get_property_name(inner, 'property')
+ if middle_property_name is None:
+ return None
# Resolve B → class_name
- class_name = prop_to_class.get(b_name)
+ class_name = property_to_class.get(middle_property_name)
if not class_name:
- return
+ return None
# Resolve (class_name, C) → constant expression
- const_node = class_constants.get((class_name, c_name))
- if const_node is None:
- return
+ constant_node = class_constants.get((class_name, outer_property_name))
+ if constant_node is None:
+ return None
self.set_changed()
- return deep_copy(const_node)
+ return deep_copy(constant_node)
traverse(self.ast, {'enter': resolve})
return self.has_changed()
diff --git a/pyjsclear/transforms/noop_calls.py b/pyjsclear/transforms/noop_calls.py
index 7177442..a54e6ef 100644
--- a/pyjsclear/transforms/noop_calls.py
+++ b/pyjsclear/transforms/noop_calls.py
@@ -7,6 +7,8 @@
obj.methodName('...'); // removed
"""
+from typing import Any
+
from ..traverser import REMOVE
from ..traverser import simple_traverse
from ..traverser import traverse
@@ -17,11 +19,11 @@
class NoopCallRemover(Transform):
"""Remove expression-statement calls to no-op methods."""
- def execute(self):
+ def execute(self) -> bool:
# Phase 1: Find no-op methods (empty body or just 'return;')
- noop_methods = set()
+ noop_methods: set[str] = set()
- def find_noops(node, parent):
+ def find_noops(node: dict, parent: dict | None) -> None:
if node.get('type') != 'MethodDefinition':
return
if node.get('kind') not in ('method', None):
@@ -29,21 +31,21 @@ def find_noops(node, parent):
key = node.get('key')
if not key or not is_identifier(key):
return
- fn = node.get('value')
- if not fn or fn.get('type') != 'FunctionExpression':
+ function_expression = node.get('value')
+ if not function_expression or function_expression.get('type') != 'FunctionExpression':
return
- # Check if async — async no-op still returns a promise, skip
- if fn.get('async'):
+ # Async no-op still returns a promise, skip
+ if function_expression.get('async'):
return
- body = fn.get('body')
+ body = function_expression.get('body')
if not body or body.get('type') != 'BlockStatement':
return
- stmts = body.get('body', [])
- if len(stmts) == 0:
+ statements = body.get('body', [])
+ if not statements:
noop_methods.add(key['name'])
- elif len(stmts) == 1:
- stmt = stmts[0]
- if stmt.get('type') == 'ReturnStatement' and stmt.get('argument') is None:
+ elif len(statements) == 1:
+ statement = statements[0]
+ if statement.get('type') == 'ReturnStatement' and statement.get('argument') is None:
noop_methods.add(key['name'])
simple_traverse(self.ast, find_noops)
@@ -52,13 +54,13 @@ def find_noops(node, parent):
return False
# Phase 2: Remove ExpressionStatement calls to no-op methods
- def remove_calls(node, parent, key, index):
+ def remove_calls(node: dict, parent: dict | None, key: str | None, index: int | None) -> Any:
if node.get('type') != 'ExpressionStatement':
return
- expr = node.get('expression')
- if not expr or expr.get('type') not in ('CallExpression', 'AwaitExpression'):
+ expression = node.get('expression')
+ if not expression or expression.get('type') not in ('CallExpression', 'AwaitExpression'):
return
- call = expr
+ call = expression
if call.get('type') == 'AwaitExpression':
call = call.get('argument')
if not call or call.get('type') != 'CallExpression':
@@ -66,10 +68,10 @@ def remove_calls(node, parent, key, index):
callee = call.get('callee')
if not callee or callee.get('type') != 'MemberExpression':
return
- prop = callee.get('property')
- if not prop or not is_identifier(prop):
+ property_node = callee.get('property')
+ if not property_node or not is_identifier(property_node):
return
- if prop['name'] in noop_methods:
+ if property_node['name'] in noop_methods:
self.set_changed()
return REMOVE
diff --git a/pyjsclear/transforms/nullish_coalescing.py b/pyjsclear/transforms/nullish_coalescing.py
index f53d58a..91edcb9 100644
--- a/pyjsclear/transforms/nullish_coalescing.py
+++ b/pyjsclear/transforms/nullish_coalescing.py
@@ -17,36 +17,47 @@
class NullishCoalescing(Transform):
"""Convert nullish check patterns to ?? operator."""
- def execute(self):
- def enter(node, parent, key, index):
- if node.get('type') != 'ConditionalExpression':
- return
- test = node.get('test')
- if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '&&':
- return
-
- left_cmp = test.get('left')
- right_cmp = test.get('right')
- if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict):
- return
-
- # Pattern: X !== null && X !== undefined ? X : default
- # Where X might be (_0x = value) or just an identifier
- result = self._match_nullish_pattern(left_cmp, right_cmp, node.get('consequent'), node.get('alternate'))
- if result:
- self.set_changed()
- return result
-
- # Also handle reversed order: X !== undefined && X !== null ? X : default
- result = self._match_nullish_pattern(right_cmp, left_cmp, node.get('consequent'), node.get('alternate'))
- if result:
- self.set_changed()
- return result
-
- traverse(self.ast, {'enter': enter})
+ def execute(self) -> bool:
+ traverse(self.ast, {'enter': self._enter})
return self.has_changed()
- def _match_nullish_pattern(self, null_check, undef_check, consequent, alternate):
+ def _enter(self, node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
+ if node.get('type') != 'ConditionalExpression':
+ return None
+ test = node.get('test')
+ if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '&&':
+ return None
+
+ left_cmp = test.get('left')
+ right_cmp = test.get('right')
+ if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict):
+ return None
+
+ consequent = node.get('consequent')
+ alternate = node.get('alternate')
+
+ # Pattern: X !== null && X !== undefined ? X : default
+ # Where X might be (_0x = value) or just an identifier
+ result = self._match_nullish_pattern(left_cmp, right_cmp, consequent, alternate)
+ if result:
+ self.set_changed()
+ return result
+
+ # Also handle reversed order: X !== undefined && X !== null ? X : default
+ result = self._match_nullish_pattern(right_cmp, left_cmp, consequent, alternate)
+ if result:
+ self.set_changed()
+ return result
+
+ return None
+
+ def _match_nullish_pattern(
+ self,
+ null_check: dict,
+ undef_check: dict,
+ consequent: dict | None,
+ alternate: dict | None,
+ ) -> dict | None:
"""Try to match the pattern and return a ?? node if successful."""
# null_check: X !== null (or (tmp = value) !== null)
if null_check.get('type') != 'BinaryExpression' or null_check.get('operator') != '!==':
diff --git a/pyjsclear/transforms/object_packer.py b/pyjsclear/transforms/object_packer.py
index 170f10d..329dd99 100644
--- a/pyjsclear/transforms/object_packer.py
+++ b/pyjsclear/transforms/object_packer.py
@@ -12,11 +12,11 @@
class ObjectPacker(Transform):
"""Pack sequential property assignments into object initializers."""
- def execute(self):
+ def execute(self) -> bool:
self._process_bodies(self.ast)
return self.has_changed()
- def _process_bodies(self, node):
+ def _process_bodies(self, node: dict) -> None:
"""Recursively find body arrays and try packing."""
if not isinstance(node, dict):
return
@@ -30,7 +30,7 @@ def _process_bodies(self, node):
self._process_bodies(child)
@staticmethod
- def _find_empty_object_declaration(stmt):
+ def _find_empty_object_declaration(stmt: dict) -> tuple[str, dict, dict] | None:
"""Find an empty object literal in a VariableDeclaration.
Returns (name, declarator, object_expression) or None.
@@ -48,7 +48,7 @@ def _find_empty_object_declaration(stmt):
return declaration['id']['name'], declaration, initializer
return None
- def _try_pack_body(self, body):
+ def _try_pack_body(self, body: list) -> None:
"""Find empty object declarations followed by property assignments and pack them."""
i = 0
while i < len(body):
@@ -62,7 +62,7 @@ def _try_pack_body(self, body):
i += 1
continue
- obj_name, obj_decl, obj_expr = found
+ object_name, obj_decl, obj_expr = found
# Collect consecutive property assignments
assignments = []
@@ -78,18 +78,16 @@ def _try_pack_body(self, body):
if not left or left.get('type') != 'MemberExpression':
break
object_reference = left.get('object')
- if not is_identifier(object_reference) or object_reference.get('name') != obj_name:
+ if not is_identifier(object_reference) or object_reference.get('name') != object_name:
break
property_node = left.get('property')
right = expr.get('right')
- # Get property key
if property_node is None:
break
- # Support both computed and non-computed property keys
property_key = property_node
# Don't pack self-referential assignments (o.x = o.y)
- if self._references_name(right, obj_name):
+ if self._references_name(right, object_name):
break
computed = left.get('computed', False)
@@ -99,7 +97,7 @@ def _try_pack_body(self, body):
if assignments:
# Pack into the object literal
for property_key, value, computed in assignments:
- property_node = {
+ new_property = {
'type': 'Property',
'key': property_key,
'value': value,
@@ -108,14 +106,14 @@ def _try_pack_body(self, body):
'shorthand': False,
'computed': computed,
}
- obj_expr['properties'].append(property_node)
+ obj_expr['properties'].append(new_property)
# Remove the assignment statements
- del body[i + 1 : j]
+ del body[i + 1:j]
self.set_changed()
i += 1
- def _references_name(self, node, name):
+ def _references_name(self, node: dict, name: str) -> bool:
"""Check if a node references a given identifier name."""
if not isinstance(node, dict) or 'type' not in node:
return False
diff --git a/pyjsclear/transforms/object_simplifier.py b/pyjsclear/transforms/object_simplifier.py
index 7c71e18..7da9432 100644
--- a/pyjsclear/transforms/object_simplifier.py
+++ b/pyjsclear/transforms/object_simplifier.py
@@ -5,7 +5,6 @@
"""
from ..scope import build_scope_tree
-from ..traverser import find_parent
from ..utils.ast_helpers import deep_copy
from ..utils.ast_helpers import is_literal
from ..utils.ast_helpers import is_string_literal
@@ -18,12 +17,15 @@ class ObjectSimplifier(Transform):
rebuild_scope = True
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
self._process_scope(scope_tree)
return self.has_changed()
- def _process_scope(self, scope):
+ def _process_scope(self, scope) -> None:
for name, binding in list(scope.bindings.items()):
if not binding.is_constant:
continue
@@ -39,21 +41,21 @@ def _process_scope(self, scope):
if not self._is_proxy_object(properties):
continue
- prop_map = {}
+ property_map = {}
for property_node in properties:
key = self._get_property_key(property_node)
if key is None:
continue
value = property_node.get('value')
if is_literal(value):
- prop_map[key] = value
+ property_map[key] = value
elif value and value.get('type') in (
'FunctionExpression',
'ArrowFunctionExpression',
):
- prop_map[key] = value
+ property_map[key] = value
- if not prop_map:
+ if not property_map:
continue
if self._has_property_assignment(binding):
@@ -68,10 +70,10 @@ def _process_scope(self, scope):
member_expression = ref_parent
property_name = self._get_member_prop_name(member_expression)
- if property_name is None or property_name not in prop_map:
+ if property_name is None or property_name not in property_map:
continue
- value = prop_map[property_name]
+ value = property_map[property_name]
if is_literal(value):
if self._replace_node(member_expression, deep_copy(value)):
self.set_changed()
@@ -87,25 +89,25 @@ def _process_scope(self, scope):
for child in scope.children:
self._process_scope(child)
- def _has_property_assignment(self, binding):
+ def _has_property_assignment(self, binding) -> bool:
"""Check if any reference to the binding is a property assignment target."""
for reference_node, reference_parent, ref_key, ref_index in binding.references:
if not (reference_parent and reference_parent.get('type') == 'MemberExpression' and ref_key == 'object'):
continue
- me_parent_info = find_parent(self.ast, reference_parent)
- if not me_parent_info:
+ member_expression_parent_info = self.find_parent(reference_parent)
+ if not member_expression_parent_info:
continue
- parent, key, _ = me_parent_info
+ parent, key, _ = member_expression_parent_info
if parent and parent.get('type') == 'AssignmentExpression' and key == 'left':
return True
return False
- def _try_inline_function_call(self, member_expression, function_value):
+ def _try_inline_function_call(self, member_expression, function_value) -> None:
"""Try to inline a function call at a MemberExpression site."""
- me_parent_info = find_parent(self.ast, member_expression)
- if not me_parent_info:
+ member_expression_parent_info = self.find_parent(member_expression)
+ if not member_expression_parent_info:
return
- parent, key, _ = me_parent_info
+ parent, key, _ = member_expression_parent_info
if not (parent and parent.get('type') == 'CallExpression' and key == 'callee'):
return
replacement = self._inline_func(function_value, parent.get('arguments', []))
@@ -114,24 +116,24 @@ def _try_inline_function_call(self, member_expression, function_value):
if self._replace_node(parent, replacement):
self.set_changed()
- def _is_proxy_object(self, properties):
+ def _is_proxy_object(self, properties: list) -> bool:
"""Check if all properties are literals or simple functions."""
- for p in properties:
- if p.get('type') != 'Property':
+ for property_node in properties:
+ if property_node.get('type') != 'Property':
return False
- val = p.get('value')
- if not val:
+ value = property_node.get('value')
+ if not value:
return False
- if is_literal(val):
+ if is_literal(value):
continue
- if val.get('type') in ('FunctionExpression', 'ArrowFunctionExpression'):
+ if value.get('type') in ('FunctionExpression', 'ArrowFunctionExpression'):
continue
return False
return True
- def _get_property_key(self, prop):
+ def _get_property_key(self, property_node) -> str | None:
"""Get the string key of a property."""
- key = prop.get('key')
+ key = property_node.get('key')
if not key:
return None
match key.get('type'):
@@ -141,12 +143,12 @@ def _get_property_key(self, prop):
return key['value']
return None
- def _get_member_prop_name(self, member_expr):
+ def _get_member_prop_name(self, member_expression) -> str | None:
"""Get property name from a member expression."""
- prop = member_expr.get('property')
+ prop = member_expression.get('property')
if not prop:
return None
- if member_expr.get('computed'):
+ if member_expression.get('computed'):
if is_string_literal(prop):
return prop['value']
return None
@@ -154,41 +156,42 @@ def _get_member_prop_name(self, member_expr):
return prop['name']
return None
- def _replace_node(self, target, replacement):
+ def _replace_node(self, target, replacement) -> bool:
"""Replace target node in the AST. Returns True if replaced."""
- result = find_parent(self.ast, target)
+ result = self.find_parent(target)
if result:
parent, key, index = result
if index is not None:
parent[key][index] = replacement
else:
parent[key] = replacement
+ self.invalidate_parent_map()
return True
return False
- def _inline_func(self, func, args):
+ def _inline_func(self, function_node, arguments: list):
"""Inline a simple function call."""
- body = func.get('body')
+ body = function_node.get('body')
if not body:
return None
- if func.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement':
+ if function_node.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement':
expr = deep_copy(body)
elif body.get('type') == 'BlockStatement':
- stmts = body.get('body', [])
- if len(stmts) != 1 or stmts[0].get('type') != 'ReturnStatement':
+ statements = body.get('body', [])
+ if len(statements) != 1 or statements[0].get('type') != 'ReturnStatement':
return None
- argument = stmts[0].get('argument')
+ argument = statements[0].get('argument')
if not argument:
return None
expr = deep_copy(argument)
else:
return None
- params = func.get('params', [])
+ params = function_node.get('params', [])
param_map = {}
- for i, parameter in enumerate(params):
+ for index, parameter in enumerate(params):
if parameter.get('type') == 'Identifier':
- param_map[parameter['name']] = args[i] if i < len(args) else {'type': 'Identifier', 'name': 'undefined'}
+ param_map[parameter['name']] = arguments[index] if index < len(arguments) else {'type': 'Identifier', 'name': 'undefined'}
replace_identifiers(expr, param_map)
return expr
diff --git a/pyjsclear/transforms/optional_chaining.py b/pyjsclear/transforms/optional_chaining.py
index 43edba5..fcead51 100644
--- a/pyjsclear/transforms/optional_chaining.py
+++ b/pyjsclear/transforms/optional_chaining.py
@@ -18,19 +18,19 @@
from .base import Transform
-def _nodes_match(a, b):
+def _nodes_match(node_a: object, node_b: object) -> bool:
"""Check if two AST nodes are structurally equivalent (shallow)."""
- if not isinstance(a, dict) or not isinstance(b, dict):
+ if not isinstance(node_a, dict) or not isinstance(node_b, dict):
return False
- if a.get('type') != b.get('type'):
+ if node_a.get('type') != node_b.get('type'):
return False
- if a.get('type') == 'Identifier':
- return a.get('name') == b.get('name')
- if a.get('type') == 'MemberExpression':
+ if node_a.get('type') == 'Identifier':
+ return node_a.get('name') == node_b.get('name')
+ if node_a.get('type') == 'MemberExpression':
return (
- _nodes_match(a.get('object'), b.get('object'))
- and _nodes_match(a.get('property'), b.get('property'))
- and a.get('computed') == b.get('computed')
+ _nodes_match(node_a.get('object'), node_b.get('object'))
+ and _nodes_match(node_a.get('property'), node_b.get('property'))
+ and node_a.get('computed') == node_b.get('computed')
)
return False
@@ -38,64 +38,65 @@ def _nodes_match(a, b):
class OptionalChaining(Transform):
"""Convert nullish check patterns to ?. operator."""
- def execute(self):
- def enter(node, parent, key, index):
+ def execute(self) -> bool:
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> dict | None:
if node.get('type') != 'ConditionalExpression':
- return
+ return None
test = node.get('test')
if not isinstance(test, dict) or test.get('type') != 'LogicalExpression' or test.get('operator') != '||':
- return
+ return None
alternate = node.get('alternate')
consequent = node.get('consequent')
# consequent must be undefined/void 0
if not is_undefined(consequent):
- return
+ return None
result = self._match_optional_pattern(test.get('left'), test.get('right'), alternate)
if result:
self.set_changed()
return result
+ return None
traverse(self.ast, {'enter': enter})
return self.has_changed()
- def _match_optional_pattern(self, left_cmp, right_cmp, alternate):
+ def _match_optional_pattern(self, left_comparison: object, right_comparison: object, alternate: object) -> dict | None:
"""Try to match X === null || X === undefined ? undefined : X.prop."""
- if not isinstance(left_cmp, dict) or not isinstance(right_cmp, dict):
+ if not isinstance(left_comparison, dict) or not isinstance(right_comparison, dict):
return None
- if left_cmp.get('type') != 'BinaryExpression' or left_cmp.get('operator') != '===':
+ if left_comparison.get('type') != 'BinaryExpression' or left_comparison.get('operator') != '===':
return None
- if right_cmp.get('type') != 'BinaryExpression' or right_cmp.get('operator') != '===':
+ if right_comparison.get('type') != 'BinaryExpression' or right_comparison.get('operator') != '===':
return None
# Figure out which comparison has null and which has undefined
- checked_var = None
+ checked_variable = None
# Try: left has null, right has undefined
- for null_cmp, undef_cmp in [(left_cmp, right_cmp), (right_cmp, left_cmp)]:
- null_left, null_right = null_cmp.get('left'), null_cmp.get('right')
- undef_left, undef_right = undef_cmp.get('left'), undef_cmp.get('right')
+ for null_comparison, undefined_comparison in [(left_comparison, right_comparison), (right_comparison, left_comparison)]:
+ null_comparison_left, null_comparison_right = null_comparison.get('left'), null_comparison.get('right')
+ undefined_comparison_left, undefined_comparison_right = undefined_comparison.get('left'), undefined_comparison.get('right')
# X === null
- if is_null_literal(null_right):
- null_checked = null_left
- elif is_null_literal(null_left):
- null_checked = null_right
+ if is_null_literal(null_comparison_right):
+ null_checked = null_comparison_left
+ elif is_null_literal(null_comparison_left):
+ null_checked = null_comparison_right
else:
continue
# X === undefined
- if is_undefined(undef_right):
- undef_checked = undef_left
- elif is_undefined(undef_left):
- undef_checked = undef_right
+ if is_undefined(undefined_comparison_right):
+ undefined_checked = undefined_comparison_left
+ elif is_undefined(undefined_comparison_left):
+ undefined_checked = undefined_comparison_right
else:
continue
# Case 1: Simple - both check the same identifier
- if _nodes_match(null_checked, undef_checked):
- checked_var = null_checked
+ if _nodes_match(null_checked, undefined_checked):
+ checked_variable = null_checked
break
# Case 2: Temp assignment - (_tmp = expr) === null || _tmp === undefined
@@ -106,31 +107,31 @@ def _match_optional_pattern(self, left_cmp, right_cmp, alternate):
):
tmp_var = null_checked.get('left')
value_expr = null_checked.get('right')
- if identifiers_match(tmp_var, undef_checked):
+ if identifiers_match(tmp_var, undefined_checked):
# The alternate should use tmp_var as the object
- checked_var = tmp_var
+ checked_variable = tmp_var
# We'll replace tmp_var references in alternate with value_expr
- return self._build_optional_chain(value_expr, checked_var, alternate)
+ return self._build_optional_chain(value_expr, checked_variable, alternate)
- if checked_var is None:
+ if checked_variable is None:
return None
- return self._build_optional_chain(checked_var, checked_var, alternate)
+ return self._build_optional_chain(checked_variable, checked_variable, alternate)
- def _build_optional_chain(self, base_expr, checked_var, alternate):
+ def _build_optional_chain(self, base_expr: object, checked_variable: object, alternate: object) -> dict | None:
"""Build an optional chain node: base_expr?.something.
base_expr: the actual expression to use as the object
- checked_var: the variable that was null-checked (may differ from base_expr for temp assignments)
- alternate: the expression that accesses checked_var.prop (or deeper)
+ checked_variable: the variable that was null-checked (may differ from base_expr for temp assignments)
+ alternate: the expression that accesses checked_variable.prop (or deeper)
"""
if not isinstance(alternate, dict):
return None
- # alternate should be a MemberExpression or CallExpression whose object matches checked_var
+ # alternate should be a MemberExpression or CallExpression whose object matches checked_variable
if alternate.get('type') == 'MemberExpression':
obj = alternate.get('object')
- if _nodes_match(obj, checked_var):
+ if _nodes_match(obj, checked_variable):
return {
'type': 'MemberExpression',
'object': base_expr,
@@ -141,7 +142,7 @@ def _build_optional_chain(self, base_expr, checked_var, alternate):
if alternate.get('type') == 'CallExpression':
callee = alternate.get('callee')
- if _nodes_match(callee, checked_var):
+ if _nodes_match(callee, checked_variable):
return {
'type': 'CallExpression',
'callee': base_expr,
diff --git a/pyjsclear/transforms/property_simplifier.py b/pyjsclear/transforms/property_simplifier.py
index e5fe9db..e5df400 100644
--- a/pyjsclear/transforms/property_simplifier.py
+++ b/pyjsclear/transforms/property_simplifier.py
@@ -9,8 +9,8 @@
class PropertySimplifier(Transform):
"""Simplify obj["prop"] to obj.prop when prop is a valid identifier."""
- def execute(self):
- def enter(node, parent, key, index):
+ def execute(self) -> bool:
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'MemberExpression':
return
if not node.get('computed'):
@@ -29,7 +29,7 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
# Also simplify computed property keys in object literals
- def enter_obj(node, parent, key, index):
+ def enter_obj(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'Property':
return
key_node = node.get('key')
@@ -51,7 +51,7 @@ def enter_obj(node, parent, key, index):
# Simplify method definitions with string literal keys:
# static ["name"]() → static name()
# Also handles cases where parser sets computed=False but key is still a Literal
- def enter_method(node, parent, key, index):
+ def enter_method(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'MethodDefinition':
return
key_node = node.get('key')
diff --git a/pyjsclear/transforms/proxy_functions.py b/pyjsclear/transforms/proxy_functions.py
index 5c0a1c0..a34d5f9 100644
--- a/pyjsclear/transforms/proxy_functions.py
+++ b/pyjsclear/transforms/proxy_functions.py
@@ -25,13 +25,16 @@ class ProxyFunctionInliner(Transform):
rebuild_scope = True
def execute(self):
- scope_tree, node_scope = build_scope_tree(self.ast)
+ if self.scope_tree is not None:
+ scope_tree, node_scope = self.scope_tree, self.node_scope
+ else:
+ scope_tree, node_scope = build_scope_tree(self.ast)
# Find proxy functions
- proxy_fns = {} # name -> (func_node, scope, binding)
- self._find_proxy_functions(scope_tree, proxy_fns)
+ proxy_functions = {} # name -> (func_node, scope, binding)
+ self._find_proxy_functions(scope_tree, proxy_functions)
- if not proxy_fns:
+ if not proxy_functions:
return False
# Collect call sites with depth info
@@ -46,33 +49,33 @@ def enter(node, parent, key, index):
if not is_identifier(callee):
return
name = callee.get('name', '')
- if name not in proxy_fns:
+ if name not in proxy_functions:
return
- call_sites.append((node, parent, key, index, proxy_fns[name], depth_counter[0]))
+ call_sites.append((node, parent, key, index, proxy_functions[name], depth_counter[0]))
traverse(self.ast, {'enter': enter})
# Skip helper functions: many call sites + conditional body = not a true proxy
call_counts = {}
- for cs in call_sites:
- fn_id = id(cs[4][0]) # func_node
- call_counts[fn_id] = call_counts.get(fn_id, 0) + 1
+ for call_site in call_sites:
+ function_node_id = id(call_site[4][0]) # func_node
+ call_counts[function_node_id] = call_counts.get(function_node_id, 0) + 1
def _has_conditional(node):
found = [False]
- def cb(n, parent):
+ def check_node(n, parent):
if n.get('type') == 'ConditionalExpression':
found[0] = True
- simple_traverse(node, cb)
+ simple_traverse(node, check_node)
return found[0]
- helper_fn_ids = set()
- for name, (func_node, _, _) in proxy_fns.items():
+ helper_function_ids = set()
+ for name, (func_node, _, _) in proxy_functions.items():
if call_counts.get(id(func_node), 0) > 3 and _has_conditional(func_node):
- helper_fn_ids.add(id(func_node))
- call_sites = [cs for cs in call_sites if id(cs[4][0]) not in helper_fn_ids]
+ helper_function_ids.add(id(func_node))
+ call_sites = [call_site for call_site in call_sites if id(call_site[4][0]) not in helper_function_ids]
# Process innermost calls first
call_sites.sort(key=lambda x: x[5], reverse=True)
@@ -147,18 +150,18 @@ def _is_proxy_function(self, func_node):
# Block with single return
if body.get('type') == 'BlockStatement':
- stmts = body.get('body', [])
- if len(stmts) != 1:
+ statements = body.get('body', [])
+ if len(statements) != 1:
return False
- stmt = stmts[0]
+ stmt = statements[0]
if stmt.get('type') != 'ReturnStatement':
return False
- arg = stmt.get('argument')
- if arg is None:
+ argument = stmt.get('argument')
+ if argument is None:
return True # returns undefined
- if not self._is_proxy_value(arg):
+ if not self._is_proxy_value(argument):
return False
- return self._count_nodes(arg) <= _MAX_PROXY_BODY_NODES
+ return self._count_nodes(argument) <= _MAX_PROXY_BODY_NODES
return False
@@ -167,10 +170,10 @@ def _count_nodes(node):
"""Count AST nodes in a subtree."""
count = [0]
- def cb(n, parent):
+ def increment_count(n, parent):
count[0] += 1
- simple_traverse(node, cb)
+ simple_traverse(node, increment_count)
return count[0]
_DISALLOWED_PROXY_TYPES = frozenset(
@@ -211,25 +214,25 @@ def _get_replacement(self, func_node, args):
if func_node.get('type') == 'ArrowFunctionExpression' and body.get('type') != 'BlockStatement':
expr = deep_copy(body)
elif body.get('type') == 'BlockStatement':
- stmts = body.get('body', [])
- if not stmts or stmts[0].get('type') != 'ReturnStatement':
+ statements = body.get('body', [])
+ if not statements or statements[0].get('type') != 'ReturnStatement':
return None
- arg = stmts[0].get('argument')
- if arg is None:
+ argument = statements[0].get('argument')
+ if argument is None:
return {'type': 'Identifier', 'name': 'undefined'}
- expr = deep_copy(arg)
+ expr = deep_copy(argument)
else:
return None
# Build parameter map
params = func_node.get('params', [])
- param_map = {}
- for i, parameter in enumerate(params):
+ parameter_map = {}
+ for index, parameter in enumerate(params):
if parameter.get('type') == 'Identifier':
- if i < len(args):
- param_map[parameter['name']] = args[i]
+ if index < len(args):
+ parameter_map[parameter['name']] = args[index]
else:
- param_map[parameter['name']] = {'type': 'Identifier', 'name': 'undefined'}
+ parameter_map[parameter['name']] = {'type': 'Identifier', 'name': 'undefined'}
- replace_identifiers(expr, param_map)
+ replace_identifiers(expr, parameter_map)
return expr
diff --git a/pyjsclear/transforms/reassignment.py b/pyjsclear/transforms/reassignment.py
index 555a8f5..e84afbc 100644
--- a/pyjsclear/transforms/reassignment.py
+++ b/pyjsclear/transforms/reassignment.py
@@ -4,12 +4,19 @@
And replaces all references to x with y, then removes x.
"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ..scope import build_scope_tree
from ..traverser import REMOVE
from ..traverser import traverse
from ..utils.ast_helpers import is_identifier
from .base import Transform
+if TYPE_CHECKING:
+ from ..scope import Scope
+
class ReassignmentRemover(Transform):
"""Remove redundant reassignments like x = y where y is used identically."""
@@ -51,13 +58,16 @@ class ReassignmentRemover(Transform):
rebuild_scope = True
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
self._process_scope(scope_tree)
self._inline_assignment_aliases(scope_tree)
return self.has_changed()
- def _process_scope(self, scope):
+ def _process_scope(self, scope: Scope) -> None:
for name, binding in list(scope.bindings.items()):
if not binding.is_constant:
continue
@@ -69,8 +79,8 @@ def _process_scope(self, scope):
continue
# Skip destructuring patterns — id must be a simple Identifier
- decl_id = node.get('id')
- if not decl_id or decl_id.get('type') != 'Identifier':
+ declaration_id = node.get('id')
+ if not declaration_id or declaration_id.get('type') != 'Identifier':
continue
initializer = node.get('init')
@@ -80,12 +90,12 @@ def _process_scope(self, scope):
target_name = initializer.get('name', '')
if target_name == name:
continue
+
target_binding = scope.get_binding(target_name)
# Allow inlining if target is a well-known global or a constant binding
- if target_binding:
- if not target_binding.is_constant:
- continue
- elif target_name not in self._WELL_KNOWN_GLOBALS:
+ if target_binding and not target_binding.is_constant:
+ continue
+ if not target_binding and target_name not in self._WELL_KNOWN_GLOBALS:
continue
# Replace all references to `name` with `target_name`
@@ -108,7 +118,7 @@ def _process_scope(self, scope):
for child in scope.children:
self._process_scope(child)
- def _inline_assignment_aliases(self, scope_tree):
+ def _inline_assignment_aliases(self, scope_tree: Scope) -> None:
"""Inline aliases created by `var x; ... x = y;` patterns.
Handles the obfuscator pattern where a variable is declared without
@@ -116,17 +126,17 @@ def _inline_assignment_aliases(self, scope_tree):
"""
self._process_assignment_aliases(scope_tree)
- def _remove_assignment_statement(self, assignment_node):
+ def _remove_assignment_statement(self, assignment_node: dict) -> None:
"""Remove the ExpressionStatement containing the given assignment expression."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') == 'ExpressionStatement' and node.get('expression') is assignment_node:
self.set_changed()
return REMOVE
traverse(self.ast, {'enter': enter})
- def _process_assignment_aliases(self, scope):
+ def _process_assignment_aliases(self, scope: Scope) -> None:
for name, binding in list(scope.bindings.items()):
if binding.is_constant or binding.kind == 'param':
continue
@@ -167,10 +177,9 @@ def _process_assignment_aliases(self, scope):
# The target must be constant or a well-known global
target_binding = scope.get_binding(target_name)
- if target_binding:
- if not target_binding.is_constant:
- continue
- elif target_name not in self._WELL_KNOWN_GLOBALS:
+ if target_binding and not target_binding.is_constant:
+ continue
+ if not target_binding and target_name not in self._WELL_KNOWN_GLOBALS:
continue
# Replace all reads of `name` with `target_name`
diff --git a/pyjsclear/transforms/require_inliner.py b/pyjsclear/transforms/require_inliner.py
index b1c7431..7611c8e 100644
--- a/pyjsclear/transforms/require_inliner.py
+++ b/pyjsclear/transforms/require_inliner.py
@@ -17,26 +17,24 @@
class RequireInliner(Transform):
"""Replace require polyfill calls with direct require() calls."""
- def execute(self):
- polyfill_names = set()
+ def execute(self) -> bool:
+ polyfill_names: set[str] = set()
# Phase 1: Detect require polyfill pattern.
# We look for: var X = (...)(function(Y) { ... require ... })
# Heuristic: a VariableDeclarator whose init is a CallExpression,
# and somewhere in its body there's `typeof require !== "undefined" ? require : ...`
- def find_polyfills(node, parent):
+ def find_polyfills(node: dict, parent: dict | None) -> None:
if node.get('type') != 'VariableDeclarator':
return
- decl_id = node.get('id')
+ declaration_id = node.get('id')
init = node.get('init')
- if not is_identifier(decl_id):
+ if not is_identifier(declaration_id):
return
if not init or init.get('type') != 'CallExpression':
return
-
- # Check if the init tree contains `typeof require`
if self._contains_typeof_require(init):
- polyfill_names.add(decl_id['name'])
+ polyfill_names.add(declaration_id['name'])
simple_traverse(self.ast, find_polyfills)
@@ -44,7 +42,7 @@ def find_polyfills(node, parent):
return False
# Phase 2: Replace _0x544bfe(X) with require(X)
- def replace_calls(node, parent, key, index):
+ def replace_calls(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'CallExpression':
return
callee = node.get('callee')
@@ -55,26 +53,24 @@ def replace_calls(node, parent, key, index):
args = node.get('arguments', [])
if len(args) != 1:
return
-
- # Replace the callee with require
node['callee'] = make_identifier('require')
self.set_changed()
traverse(self.ast, {'enter': replace_calls})
return self.has_changed()
- def _contains_typeof_require(self, node):
+ def _contains_typeof_require(self, node: dict) -> bool:
"""Check if a subtree contains `typeof require`."""
found = [False]
- def scan(n, parent):
+ def scan(current_node: dict, parent: dict | None) -> None:
if found[0]:
return
- if not isinstance(n, dict):
+ if not isinstance(current_node, dict):
return
- if n.get('type') == 'UnaryExpression' and n.get('operator') == 'typeof':
- arg = n.get('argument')
- if is_identifier(arg) and arg.get('name') == 'require':
+ if current_node.get('type') == 'UnaryExpression' and current_node.get('operator') == 'typeof':
+ argument = current_node.get('argument')
+ if is_identifier(argument) and argument.get('name') == 'require':
found[0] = True
simple_traverse(node, scan)
diff --git a/pyjsclear/transforms/sequence_splitter.py b/pyjsclear/transforms/sequence_splitter.py
index 8878210..3478c60 100644
--- a/pyjsclear/transforms/sequence_splitter.py
+++ b/pyjsclear/transforms/sequence_splitter.py
@@ -14,15 +14,15 @@
class SequenceSplitter(Transform):
"""Split sequence expressions and normalize control flow bodies."""
- def execute(self):
+ def execute(self) -> bool:
self._normalize_bodies(self.ast)
self._split_in_body_arrays(self.ast)
return self.has_changed()
- def _normalize_bodies(self, ast):
+ def _normalize_bodies(self, ast: dict) -> None:
"""Ensure if/while/for bodies are BlockStatements."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
node_type = node.get('type', '')
if node_type not in (
'IfStatement',
@@ -42,7 +42,7 @@ def enter(node, parent, key, index):
traverse(ast, {'enter': enter})
- def _normalize_if_branches(self, node):
+ def _normalize_if_branches(self, node: dict) -> None:
"""Wrap non-block consequent/alternate of IfStatement in BlockStatements."""
consequent = node.get('consequent')
if consequent and consequent.get('type') != 'BlockStatement':
@@ -53,7 +53,7 @@ def _normalize_if_branches(self, node):
node['alternate'] = make_block_statement([alternate])
self.set_changed()
- def _split_in_body_arrays(self, node):
+ def _split_in_body_arrays(self, node: dict) -> None:
"""Find all arrays that contain statements and split sequences + var decls in them."""
if not isinstance(node, dict):
return
@@ -62,13 +62,12 @@ def _split_in_body_arrays(self, node):
# Check if this looks like a statement array
if child and isinstance(child[0], dict) and 'type' in child[0]:
self._process_stmt_array(child)
- # Recurse into items
for item in child:
self._split_in_body_arrays(item)
elif isinstance(child, dict) and 'type' in child:
self._split_in_body_arrays(child)
- def _extract_indirect_call_prefixes(self, stmt):
+ def _extract_indirect_call_prefixes(self, statement: dict) -> list:
"""Extract dead prefix expressions from (0, fn)(args) patterns.
Only extracts from:
@@ -79,7 +78,7 @@ def _extract_indirect_call_prefixes(self, stmt):
"""
prefixes = []
- def extract_from_call(node):
+ def extract_from_call(node: dict | None) -> None:
"""If node is a CallExpression with SequenceExpression callee, extract prefixes."""
if not isinstance(node, dict):
return
@@ -91,45 +90,44 @@ def extract_from_call(node):
callee = target.get('callee')
if not isinstance(callee, dict) or callee.get('type') != 'SequenceExpression':
return
- exprs = callee.get('expressions', [])
- if len(exprs) <= 1:
+ expressions = callee.get('expressions', [])
+ if len(expressions) <= 1:
return
- prefixes.extend(exprs[:-1])
- target['callee'] = exprs[-1]
-
- stype = stmt.get('type', '')
- if stype == 'ExpressionStatement':
- extract_from_call(stmt.get('expression'))
- elif stype == 'VariableDeclaration':
- for d in stmt.get('declarations', []):
- extract_from_call(d.get('init'))
- elif stype == 'ReturnStatement':
- extract_from_call(stmt.get('argument'))
- # Also check assignment expressions in ExpressionStatements:
- # x = (0, fn)(args)
- if stype == 'ExpressionStatement':
- expr = stmt.get('expression')
- if isinstance(expr, dict) and expr.get('type') == 'AssignmentExpression':
- extract_from_call(expr.get('right'))
+ prefixes.extend(expressions[:-1])
+ target['callee'] = expressions[-1]
+
+ statement_type = statement.get('type', '')
+ match statement_type:
+ case 'ExpressionStatement':
+ extract_from_call(statement.get('expression'))
+ # Also check assignment: x = (0, fn)(args)
+ expression = statement.get('expression')
+ if isinstance(expression, dict) and expression.get('type') == 'AssignmentExpression':
+ extract_from_call(expression.get('right'))
+ case 'VariableDeclaration':
+ for declarator in statement.get('declarations', []):
+ extract_from_call(declarator.get('init'))
+ case 'ReturnStatement':
+ extract_from_call(statement.get('argument'))
return prefixes
- def _process_stmt_array(self, statements):
+ def _process_stmt_array(self, statements: list) -> None:
"""Split sequence expressions and multi-var declarations in a statement array."""
- i = 0
- while i < len(statements):
- statement = statements[i]
+ index = 0
+ while index < len(statements):
+ statement = statements[index]
if not isinstance(statement, dict):
- i += 1
+ index += 1
continue
# Extract dead prefix from indirect call patterns: (0, fn)(args) → 0; fn(args);
prefixes = self._extract_indirect_call_prefixes(statement)
if prefixes:
- new_stmts = [make_expression_statement(expression) for expression in prefixes]
- new_stmts.append(statement)
- statements[i : i + 1] = new_stmts
- i += len(new_stmts)
+ new_statements = [make_expression_statement(expression) for expression in prefixes]
+ new_statements.append(statement)
+ statements[index : index + 1] = new_statements
+ index += len(new_statements)
self.set_changed()
continue
@@ -141,44 +139,44 @@ def _process_stmt_array(self, statements):
):
expressions = statement['expression'].get('expressions', [])
if len(expressions) > 1:
- new_stmts = [make_expression_statement(expression) for expression in expressions]
- statements[i : i + 1] = new_stmts
- i += len(new_stmts)
+ new_statements = [make_expression_statement(expression) for expression in expressions]
+ statements[index : index + 1] = new_statements
+ index += len(new_statements)
self.set_changed()
continue
# Split multi-declarator VariableDeclaration
# (but not inside for-loop init — those aren't in body arrays)
if statement.get('type') == 'VariableDeclaration':
- decls = statement.get('declarations', [])
- if len(decls) > 1:
+ declarations = statement.get('declarations', [])
+ if len(declarations) > 1:
kind = statement.get('kind', 'var')
- new_stmts = [
+ new_statements = [
{
'type': 'VariableDeclaration',
'kind': kind,
'declarations': [declaration],
}
- for declaration in decls
+ for declaration in declarations
]
- statements[i : i + 1] = new_stmts
- i += len(new_stmts)
+ statements[index : index + 1] = new_statements
+ index += len(new_statements)
self.set_changed()
continue
# Split SequenceExpression in single declarator init
- if len(decls) == 1:
- split_result = self._try_split_single_declarator_init(statement, decls[0])
+ if len(declarations) == 1:
+ split_result = self._try_split_single_declarator_init(statement, declarations[0])
if split_result:
- statements[i : i + 1] = split_result
- i += len(split_result)
+ statements[index : index + 1] = split_result
+ index += len(split_result)
self.set_changed()
continue
- i += 1
+ index += 1
@staticmethod
- def _try_split_single_declarator_init(stmt, declarator):
+ def _try_split_single_declarator_init(statement: dict, declarator: dict) -> list | None:
"""Split SequenceExpression from a single VariableDeclarator init.
Handles both direct sequences and sequences inside AwaitExpression.
@@ -190,12 +188,12 @@ def _try_split_single_declarator_init(stmt, declarator):
# Direct: const x = (a, b, expr()) → a; b; const x = expr();
if init.get('type') == 'SequenceExpression':
- exprs = init.get('expressions', [])
- if len(exprs) <= 1:
+ expressions = init.get('expressions', [])
+ if len(expressions) <= 1:
return None
- prefix = [make_expression_statement(e) for e in exprs[:-1]]
- declarator['init'] = exprs[-1]
- prefix.append(stmt)
+ prefix = [make_expression_statement(expression) for expression in expressions[:-1]]
+ declarator['init'] = expressions[-1]
+ prefix.append(statement)
return prefix
# Await-wrapped: var x = await (a, b, expr()) → a; b; var x = await expr();
@@ -204,12 +202,12 @@ def _try_split_single_declarator_init(stmt, declarator):
and isinstance(init.get('argument'), dict)
and init['argument'].get('type') == 'SequenceExpression'
):
- exprs = init['argument'].get('expressions', [])
- if len(exprs) <= 1:
+ expressions = init['argument'].get('expressions', [])
+ if len(expressions) <= 1:
return None
- prefix = [make_expression_statement(e) for e in exprs[:-1]]
- init['argument'] = exprs[-1]
- prefix.append(stmt)
+ prefix = [make_expression_statement(expression) for expression in expressions[:-1]]
+ init['argument'] = expressions[-1]
+ prefix.append(statement)
return prefix
return None
diff --git a/pyjsclear/transforms/single_use_vars.py b/pyjsclear/transforms/single_use_vars.py
index d8078d4..30984ea 100644
--- a/pyjsclear/transforms/single_use_vars.py
+++ b/pyjsclear/transforms/single_use_vars.py
@@ -3,11 +3,11 @@
Targets patterns like:
const _0x337161 = require("process");
return _0x337161.env.LOCALAPPDATA;
-→ return require("process").env.LOCALAPPDATA;
+ return require("process").env.LOCALAPPDATA;
const _0x27439f = Buffer.from(_0x162d6f);
return _0x27439f.toString();
-→ return Buffer.from(_0x162d6f).toString();
+ return Buffer.from(_0x162d6f).toString();
Only inlines when:
- The variable is constant (no reassignments)
@@ -15,24 +15,30 @@
- The init expression is not too large (≤ 15 AST nodes)
"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ..scope import build_scope_tree
from ..traverser import REMOVE
-from ..traverser import find_parent
from ..traverser import simple_traverse
from ..traverser import traverse
from ..utils.ast_helpers import deep_copy
from ..utils.ast_helpers import is_identifier
from .base import Transform
+if TYPE_CHECKING:
+ from ..scope import Scope
+
-def _count_nodes(node):
+def _count_nodes(node: dict) -> int:
"""Count AST nodes in a subtree."""
count = [0]
- def cb(n, parent):
+ def increment_count(_node: dict, parent: dict | None) -> None:
count[0] += 1
- simple_traverse(node, cb)
+ simple_traverse(node, increment_count)
return count[0]
@@ -45,15 +51,18 @@ class SingleUseVarInliner(Transform):
# Keeps inlined expressions readable; avoids ballooning line length.
_MAX_INIT_NODES = 15
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
inlined = self._process_scope(scope_tree)
if not inlined:
return False
self._remove_declarators(inlined)
return self.has_changed()
- def _process_scope(self, scope):
+ def _process_scope(self, scope: Scope) -> list[dict]:
"""Find and inline single-use constant bindings."""
inlined_declarators = []
@@ -115,7 +124,7 @@ def _process_scope(self, scope):
return inlined_declarators
- def _is_mutated_member_object(self, ref_parent, ref_key):
+ def _is_mutated_member_object(self, ref_parent: dict | None, ref_key: str | None) -> bool:
"""Check if ref is the object of a member expression that is an assignment target.
Catches: obj[x] = val, obj.x = val, obj[x]++, etc.
@@ -126,30 +135,30 @@ def _is_mutated_member_object(self, ref_parent, ref_key):
if ref_key != 'object':
return False
# Now check if this MemberExpression is an assignment target
- parent_info = find_parent(self.ast, ref_parent)
+ parent_info = self.find_parent(ref_parent)
if not parent_info:
return False
- grandparent, gp_key, _ = parent_info
- if grandparent.get('type') == 'AssignmentExpression' and gp_key == 'left':
+ grandparent, grandparent_key, _ = parent_info
+ if grandparent.get('type') == 'AssignmentExpression' and grandparent_key == 'left':
return True
if grandparent.get('type') == 'UpdateExpression':
return True
return False
- def _remove_declarators(self, declarator_nodes):
+ def _remove_declarators(self, declarator_nodes: list[dict]) -> None:
"""Remove inlined VariableDeclarators from their parent declarations."""
- declarator_ids = {id(d) for d in declarator_nodes}
+ declarator_ids = {id(declarator) for declarator in declarator_nodes}
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
if node.get('type') != 'VariableDeclaration':
return
- decls = node.get('declarations', [])
- original_len = len(decls)
- decls[:] = [d for d in decls if id(d) not in declarator_ids]
- if len(decls) == original_len:
+ declarations = node.get('declarations', [])
+ original_length = len(declarations)
+ declarations[:] = [declarator for declarator in declarations if id(declarator) not in declarator_ids]
+ if len(declarations) == original_length:
return # No match — continue traversing children
self.set_changed()
- if not decls:
+ if not declarations:
return REMOVE
traverse(self.ast, {'enter': enter})
diff --git a/pyjsclear/transforms/string_revealer.py b/pyjsclear/transforms/string_revealer.py
index 3070f8f..eed74f5 100644
--- a/pyjsclear/transforms/string_revealer.py
+++ b/pyjsclear/transforms/string_revealer.py
@@ -2,11 +2,11 @@
import math
import re
+from typing import Any
from ..generator import generate
from ..scope import build_scope_tree
from ..traverser import REMOVE
-from ..traverser import find_parent
from ..traverser import simple_traverse
from ..traverser import traverse
from ..utils.ast_helpers import is_identifier
@@ -24,14 +24,14 @@
_RC4_REGEX = re.compile(r"""fromCharCode.{0,30}\^""")
-def _eval_numeric(node):
+def _eval_numeric(node: Any) -> int | float | None:
"""Evaluate an AST node to a numeric value if it's a constant expression."""
if not isinstance(node, dict):
return None
match node.get('type', ''):
case 'Literal':
- val = node.get('value')
- return val if isinstance(val, (int, float)) else None
+ value = node.get('value')
+ return value if isinstance(value, (int, float)) else None
case 'UnaryExpression':
arg = _eval_numeric(node.get('argument'))
if arg is None:
@@ -51,18 +51,18 @@ def _eval_numeric(node):
return None
-def _js_parse_int(s):
+def _js_parse_int(string: str) -> float:
"""Mimic JavaScript's parseInt: extract leading integer from string."""
- if not isinstance(s, str):
+ if not isinstance(string, str):
return float('nan')
- s = s.strip()
- m = re.match(r'^[+-]?\d+', s)
- if m:
- return int(m.group())
+ string = string.strip()
+ match = re.match(r'^[+-]?\d+', string)
+ if match:
+ return int(match.group())
return float('nan')
-def _apply_arith(operator, left, right):
+def _apply_arith(operator: str, left: int | float, right: int | float) -> int | float | None:
"""Apply a binary arithmetic operator."""
match operator:
case '+':
@@ -79,7 +79,7 @@ def _apply_arith(operator, left, right):
return None
-def _collect_object_literals(ast):
+def _collect_object_literals(ast: dict) -> dict[tuple[str, str], int | float | str]:
"""Collect simple object literal assignments: var o = {a: 0x1b1, b: 'str'}.
Returns a dict mapping (object_name, property_name) -> value (int or str).
@@ -117,7 +117,7 @@ def visitor(node, parent):
return result
-def _resolve_arg_value(arg, object_literals):
+def _resolve_arg_value(arg: dict, object_literals: dict) -> int | float | None:
"""Try to resolve a call argument to a numeric value.
Handles numeric literals, string hex literals, and member expressions
@@ -150,7 +150,7 @@ def _resolve_arg_value(arg, object_literals):
return None
-def _resolve_string_arg(arg, object_literals):
+def _resolve_string_arg(arg: dict, object_literals: dict) -> str | None:
"""Try to resolve a call argument to a string value.
Handles string literals and member expressions referencing known object properties.
@@ -170,14 +170,21 @@ def _resolve_string_arg(arg, object_literals):
class WrapperInfo:
"""Info about a wrapper function that calls the decoder."""
- def __init__(self, name, param_index, wrapper_offset, func_node, key_param_index=None):
+ def __init__(
+ self,
+ name: str,
+ param_index: int,
+ wrapper_offset: int,
+ func_node: dict,
+ key_param_index: int | None = None,
+ ) -> None:
self.name = name
self.param_index = param_index
self.wrapper_offset = wrapper_offset
self.func_node = func_node
self.key_param_index = key_param_index
- def get_effective_index(self, call_args):
+ def get_effective_index(self, call_args: list) -> int | None:
"""Given call argument values, compute the effective decoder index."""
if self.param_index >= len(call_args):
return None
@@ -186,7 +193,7 @@ def get_effective_index(self, call_args):
return None
return int(value) + self.wrapper_offset
- def get_key(self, call_args):
+ def get_key(self, call_args: list) -> str | None:
"""Get the RC4 key argument if applicable."""
if self.key_param_index is not None and self.key_param_index < len(call_args):
return call_args[self.key_param_index]
@@ -199,8 +206,11 @@ class StringRevealer(Transform):
rebuild_scope = True
_rotation_locals = {}
- def execute(self):
- scope_tree, node_scope = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree, node_scope = self.scope_tree, self.node_scope
+ else:
+ scope_tree, node_scope = build_scope_tree(self.ast)
# Strategy 1: Direct string array declarations (var arr = ["a","b","c"])
self._process_direct_arrays(scope_tree)
@@ -220,7 +230,7 @@ def execute(self):
# Strategy 2: Obfuscator.io pattern
# ================================================================
- def _process_obfuscatorio_pattern(self):
+ def _process_obfuscatorio_pattern(self) -> None:
"""Handle obfuscator.io: array func -> decoder func(s) -> wrapper funcs -> rotation."""
body = self.ast.get('body', [])
@@ -312,7 +322,7 @@ def _process_obfuscatorio_pattern(self):
indices_to_remove.add(array_func_idx)
self._remove_body_indices(body, *indices_to_remove)
- def _find_string_array_function(self, body):
+ def _find_string_array_function(self, body: list) -> tuple[str | None, list | None, int | None]:
"""Find the string array function declaration.
Pattern: function X() { var a = ['s1','s2',...]; X = function(){return a;}; return X(); }
@@ -334,7 +344,7 @@ def _find_string_array_function(self, body):
return None, None, None
@staticmethod
- def _string_array_from_expression(node):
+ def _string_array_from_expression(node: dict | None) -> list[str] | None:
"""Return list of string values if node is an ArrayExpression of all string literals."""
if not node or node.get('type') != 'ArrayExpression':
return None
@@ -343,7 +353,7 @@ def _string_array_from_expression(node):
return None
return [e['value'] for e in elements]
- def _extract_array_from_statement(self, stmt):
+ def _extract_array_from_statement(self, stmt: dict) -> list[str] | None:
"""Extract string array from a variable declaration or assignment."""
if stmt.get('type') == 'VariableDeclaration':
for declaration in stmt.get('declarations', []):
@@ -356,7 +366,7 @@ def _extract_array_from_statement(self, stmt):
return self._string_array_from_expression(expr.get('right'))
return None
- def _find_all_decoder_functions(self, body, array_func_name):
+ def _find_all_decoder_functions(self, body: list, array_func_name: str) -> list[tuple[str, int, int, DecoderType]]:
"""Find all decoder functions that call the array function.
Returns list of (func_name, offset, body_index, decoder_type) tuples.
@@ -384,7 +394,7 @@ def _find_all_decoder_functions(self, body, array_func_name):
return results
- def _function_calls(self, func_node, callee_name):
+ def _function_calls(self, func_node: dict, callee_name: str) -> bool:
"""Check if a function body contains a call to callee_name."""
found = [False]
@@ -401,7 +411,7 @@ def visitor(node, parent):
simple_traverse(func_node, visitor)
return found[0]
- def _extract_decoder_offset(self, func_node):
+ def _extract_decoder_offset(self, func_node: dict) -> int:
"""Extract offset from decoder's inner param = param OP EXPR pattern."""
found_offset = [None]
@@ -430,7 +440,7 @@ def find_offset(node, parent):
simple_traverse(func_node, find_offset)
return found_offset[0] if found_offset[0] is not None else 0
- def _create_base_decoder(self, string_array, offset, dtype):
+ def _create_base_decoder(self, string_array: list[str], offset: int, dtype: DecoderType) -> BasicStringDecoder | Base64StringDecoder | Rc4StringDecoder:
"""Create the appropriate decoder instance."""
match dtype:
case DecoderType.RC4:
@@ -440,7 +450,7 @@ def _create_base_decoder(self, string_array, offset, dtype):
case _:
return BasicStringDecoder(string_array, offset)
- def _find_all_wrappers(self, decoder_name):
+ def _find_all_wrappers(self, decoder_name: str) -> dict[str, 'WrapperInfo']:
"""Find all wrapper functions throughout the AST that call the decoder.
Pattern: function W(p0,..,pN) { return DECODER(p_i OP OFFSET, p_j); }
@@ -467,14 +477,14 @@ def visitor(node, parent):
simple_traverse(self.ast, visitor)
return wrappers
- def _analyze_wrapper(self, func_node, decoder_name):
+ def _analyze_wrapper(self, func_node: dict, decoder_name: str) -> 'WrapperInfo | None':
"""Check if a FunctionDeclaration is a wrapper. Returns WrapperInfo or None."""
func_name = func_node.get('id', {}).get('name')
if not func_name:
return None
return self._analyze_wrapper_expr(func_name, func_node, decoder_name)
- def _analyze_wrapper_expr(self, func_name, func_node, decoder_name):
+ def _analyze_wrapper_expr(self, func_name: str, func_node: dict, decoder_name: str) -> 'WrapperInfo | None':
"""Analyze a function node (declaration or expression) as a potential wrapper."""
func_body = func_node.get('body', {})
if func_body.get('type') == 'BlockStatement':
@@ -516,7 +526,7 @@ def _analyze_wrapper_expr(self, func_name, func_node, decoder_name):
return WrapperInfo(func_name, param_index, wrapper_offset, func_node, key_param_index)
- def _extract_wrapper_offset(self, expr, param_names):
+ def _extract_wrapper_offset(self, expr: dict, param_names: list[str]) -> tuple[int | None, int | None]:
"""Extract (param_index, offset) from wrapper's first argument to decoder.
Handles: p_N, p_N + LIT, p_N - LIT, p_N - -LIT, p_N + -LIT
@@ -542,7 +552,7 @@ def _extract_wrapper_offset(self, expr, param_names):
offset = int(-right_value) if operator == '-' else int(right_value)
return param_idx, offset
- def _remove_decoder_aliases(self, decoder_name, aliases):
+ def _remove_decoder_aliases(self, decoder_name: str, aliases: set[str]) -> None:
"""Remove variable declarations that are aliases for the decoder.
Removes: const _0xABC = _0x22e6; and transitive: const _0xDEF = _0xABC;
@@ -577,7 +587,7 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
- def _find_decoder_aliases(self, decoder_name):
+ def _find_decoder_aliases(self, decoder_name: str) -> set[str]:
"""Find all variable declarations that are aliases for the decoder.
Handles transitive aliases: const a = decoder; const b = a; const c = b;
@@ -614,15 +624,15 @@ def visitor(node, parent):
def _find_and_execute_rotation(
self,
- body,
- array_func_name,
- string_array,
- decoder,
- wrappers,
- decoder_aliases=None,
- alias_decoder_map=None,
- all_decoders=None,
- ):
+ body: list,
+ array_func_name: str,
+ string_array: list[str],
+ decoder: Any,
+ wrappers: dict,
+ decoder_aliases: set[str] | None = None,
+ alias_decoder_map: dict | None = None,
+ all_decoders: dict | None = None,
+ ) -> tuple[int, Any] | None:
"""Find rotation IIFE and execute it.
Returns (body_index, rotation_call_expr_or_none) on success, or None.
@@ -669,15 +679,15 @@ def _find_and_execute_rotation(
def _try_execute_rotation_call(
self,
- call_expr,
- array_func_name,
- string_array,
- decoder,
- wrappers,
- decoder_aliases,
- alias_decoder_map=None,
- all_decoders=None,
- ):
+ call_expr: dict,
+ array_func_name: str,
+ string_array: list[str],
+ decoder: Any,
+ wrappers: dict,
+ decoder_aliases: set[str] | None,
+ alias_decoder_map: dict | None = None,
+ all_decoders: dict | None = None,
+ ) -> bool:
"""Try to parse and execute a single rotation call expression. Returns True on success."""
callee = call_expr.get('callee')
args = call_expr.get('arguments', [])
@@ -716,7 +726,7 @@ def _try_execute_rotation_call(
return True
@staticmethod
- def _collect_rotation_locals(iife_func):
+ def _collect_rotation_locals(iife_func: dict) -> dict[str, dict]:
"""Collect local object literal assignments from the rotation IIFE.
Returns dict: var_name -> {prop_name: value}.
@@ -739,21 +749,21 @@ def _collect_rotation_locals(iife_func):
if not key or not value:
continue
if is_identifier(key):
- k = key['name']
+ prop_name = key['name']
elif is_string_literal(key):
- k = key['value']
+ prop_name = key['value']
else:
continue
num = _eval_numeric(value)
if num is not None:
- obj[k] = int(num)
+ obj[prop_name] = int(num)
elif is_string_literal(value):
- obj[k] = value['value']
+ obj[prop_name] = value['value']
if obj:
result[name_node['name']] = obj
return result
- def _extract_rotation_expression(self, iife_func):
+ def _extract_rotation_expression(self, iife_func: dict) -> dict | None:
"""Extract the arithmetic expression from the try block in the rotation loop."""
func_body = iife_func.get('body', {}).get('body', [])
if not func_body:
@@ -782,7 +792,7 @@ def _extract_rotation_expression(self, iife_func):
return None
@staticmethod
- def _expression_from_try_block(first_statement):
+ def _expression_from_try_block(first_statement: dict) -> dict | None:
"""Extract the init/rhs expression from the first statement in a try block."""
if first_statement.get('type') == 'VariableDeclaration':
decls = first_statement.get('declarations', [])
@@ -793,7 +803,7 @@ def _expression_from_try_block(first_statement):
return expr.get('right')
return None
- def _parse_rotation_op(self, expr, wrappers, decoder_aliases=None):
+ def _parse_rotation_op(self, expr: dict, wrappers: dict, decoder_aliases: set[str] | None = None) -> dict | None:
"""Parse a rotation expression into an operation tree."""
if not isinstance(expr, dict):
return None
@@ -830,7 +840,7 @@ def _parse_rotation_op(self, expr, wrappers, decoder_aliases=None):
return None
- def _parse_parseInt_call(self, expr, wrappers, aliases):
+ def _parse_parseInt_call(self, expr: dict, wrappers: dict, aliases: set[str]) -> dict | None:
"""Parse parseInt(wrapperOrDecoder(...)) into an operation node."""
callee = expr.get('callee')
args = expr.get('arguments', [])
@@ -856,20 +866,20 @@ def _parse_parseInt_call(self, expr, wrappers, aliases):
return {'op': 'direct_decoder_call', 'alias_name': cname, 'args': arg_values}
return None
- def _resolve_rotation_arg(self, arg):
+ def _resolve_rotation_arg(self, arg: dict) -> int | str | None:
"""Resolve a rotation call argument to a numeric or string value.
Handles literals, string hex, and MemberExpression referencing local objects.
"""
- val = _eval_numeric(arg)
- if val is not None:
- return int(val)
+ numeric_value = _eval_numeric(arg)
+ if numeric_value is not None:
+ return int(numeric_value)
if is_string_literal(arg):
- s = arg['value']
+ string_value = arg['value']
try:
- return int(s, 16) if s.startswith('0x') else int(s)
+ return int(string_value, 16) if string_value.startswith('0x') else int(string_value)
except (ValueError, TypeError):
- return s
+ return string_value
# MemberExpression: J.A or J['A']
if arg.get('type') == 'MemberExpression':
obj = arg.get('object')
@@ -882,7 +892,7 @@ def _resolve_rotation_arg(self, arg):
return local_obj.get(prop['value'])
return None
- def _decode_and_parse_int(self, decoder, idx, key=None):
+ def _decode_and_parse_int(self, decoder: Any, idx: int | float, key: str | None = None) -> float:
"""Decode a string and parse it as an integer. Raises on failure."""
decoded = decoder.get_string(int(idx), key) if key is not None else decoder.get_string(int(idx))
if decoded is None:
@@ -892,7 +902,7 @@ def _decode_and_parse_int(self, decoder, idx, key=None):
raise ValueError('NaN from parseInt')
return result
- def _apply_rotation_op(self, operation, wrappers, decoder, alias_decoder_map=None):
+ def _apply_rotation_op(self, operation: dict, wrappers: dict, decoder: Any, alias_decoder_map: dict | None = None) -> int | float:
"""Evaluate a parsed rotation operation tree."""
match operation['op']:
case 'literal':
@@ -924,7 +934,7 @@ def _apply_rotation_op(self, operation, wrappers, decoder, alias_decoder_map=Non
case _:
raise ValueError(f'Unknown op: {operation["op"]}')
- def _execute_rotation(self, string_array, operation, wrappers, decoder, stop_value, alias_decoder_map=None):
+ def _execute_rotation(self, string_array: list[str], operation: dict, wrappers: dict, decoder: Any, stop_value: int, alias_decoder_map: dict | None = None) -> bool:
"""Rotate array until the expression evaluates to stop_value."""
# Collect all decoders that need cache clearing on each rotation
all_decoders = set()
@@ -941,14 +951,14 @@ def _execute_rotation(self, string_array, operation, wrappers, decoder, stop_val
except Exception:
string_array.append(string_array.pop(0))
# Clear decoder caches after rotation since array contents shifted
- for d in all_decoders:
- if hasattr(d, '_cache'):
- d._cache.clear()
+ for decoder_instance in all_decoders:
+ if hasattr(decoder_instance, '_cache'):
+ decoder_instance._cache.clear()
return False
# ---- Replacement ----
- def _replace_all_wrapper_calls(self, wrappers, decoder, obj_literals=None):
+ def _replace_all_wrapper_calls(self, wrappers: dict, decoder: Any, obj_literals: dict | None = None) -> bool:
"""Replace all calls to wrapper functions with decoded string literals."""
if not wrappers:
return True
@@ -999,7 +1009,7 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
return all_replaced[0]
- def _replace_direct_decoder_calls(self, decoder_name, decoder, decoder_aliases=None, obj_literals=None):
+ def _replace_direct_decoder_calls(self, decoder_name: str, decoder: Any, decoder_aliases: set[str] | None = None, obj_literals: dict | None = None) -> None:
"""Replace direct calls to the decoder function (and its aliases) with literals."""
names = {decoder_name}
if decoder_aliases:
@@ -1037,7 +1047,7 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
@staticmethod
- def _find_array_expression_in_statement(stmt):
+ def _find_array_expression_in_statement(stmt: dict) -> dict | None:
"""Find the first ArrayExpression node in a variable declaration or assignment."""
if stmt.get('type') == 'VariableDeclaration':
for declaration in stmt.get('declarations', []):
@@ -1052,7 +1062,7 @@ def _find_array_expression_in_statement(stmt):
return right
return None
- def _update_ast_array(self, func_node, rotated_array):
+ def _update_ast_array(self, func_node: dict, rotated_array: list[str]) -> None:
"""Update the AST's array function to contain the rotated string array."""
func_body = func_node.get('body', {}).get('body', [])
if not func_body:
@@ -1061,7 +1071,7 @@ def _update_ast_array(self, func_node, rotated_array):
if arr_expr is not None:
arr_expr['elements'] = [make_literal(s) for s in rotated_array]
- def _remove_body_indices(self, body, *indices):
+ def _remove_body_indices(self, body: list, *indices: int | None) -> None:
"""Remove statements at given indices from body."""
for idx in sorted(set(i for i in indices if i is not None), reverse=True):
if 0 <= idx < len(body):
@@ -1076,7 +1086,7 @@ def _remove_body_indices(self, body, *indices):
# var _0xDEC = function(a, b) { a = a - OFFSET; var x = _0xARR[a]; return x; };
# ================================================================
- def _process_var_array_pattern(self):
+ def _process_var_array_pattern(self) -> None:
"""Handle var-based string array with simple rotation and decoder."""
body = self.ast.get('body', [])
if len(body) < 3:
@@ -1122,7 +1132,7 @@ def _process_var_array_pattern(self):
indices_to_remove.add(rotation_idx)
self._remove_body_indices(body, *indices_to_remove)
- def _find_var_string_array(self, body):
+ def _find_var_string_array(self, body: list) -> tuple[str | None, list[str] | None, int | None]:
"""Find var _0x... = ['s1', 's2', ...] at top of body."""
for i, stmt in enumerate(body[:3]):
if stmt.get('type') != 'VariableDeclaration':
@@ -1142,7 +1152,7 @@ def _find_var_string_array(self, body):
return name_node['name'], [e['value'] for e in elements], i
return None, None, None
- def _find_simple_rotation(self, body, array_name):
+ def _find_simple_rotation(self, body: list, array_name: str) -> tuple[int | None, int | None]:
"""Find (function(arr, count) { ...push/shift... })(array, N) rotation IIFE."""
for i, stmt in enumerate(body):
if stmt.get('type') != 'ExpressionStatement':
@@ -1177,7 +1187,7 @@ def _find_simple_rotation(self, body, array_name):
return None, None
- def _find_var_decoder(self, body, array_name):
+ def _find_var_decoder(self, body: list, array_name: str) -> tuple[str | None, int | None, int | None]:
"""Find var _0xDEC = function(a) { a = a - OFFSET; var x = ARR[a]; return x; }."""
for i, stmt in enumerate(body):
if stmt.get('type') != 'VariableDeclaration':
@@ -1200,7 +1210,7 @@ def _find_var_decoder(self, body, array_name):
# Strategy 1: Direct string array declarations
# ================================================================
- def _try_replace_array_access(self, ref_parent, ref_key, string_array):
+ def _try_replace_array_access(self, ref_parent: dict | None, ref_key: str, string_array: list[str]) -> None:
"""Replace arr[N] member expression with the string literal if valid."""
if not ref_parent or ref_parent.get('type') != 'MemberExpression':
return
@@ -1215,7 +1225,7 @@ def _try_replace_array_access(self, ref_parent, ref_key, string_array):
self._replace_node_in_ast(ref_parent, make_literal(string_array[idx]))
self.set_changed()
- def _process_direct_arrays(self, scope_tree):
+ def _process_direct_arrays(self, scope_tree: Any) -> None:
"""Find direct array declarations and replace indexed accesses."""
for name, binding in list(scope_tree.bindings.items()):
node = binding.node
@@ -1234,7 +1244,7 @@ def _process_direct_arrays(self, scope_tree):
for child in scope_tree.children:
self._process_direct_arrays_in_scope(child, name, string_array)
- def _process_direct_arrays_in_scope(self, scope, name, string_array):
+ def _process_direct_arrays_in_scope(self, scope: Any, name: str, string_array: list[str]) -> None:
"""Process direct array accesses in child scopes."""
binding = scope.get_binding(name)
if not binding:
@@ -1242,20 +1252,21 @@ def _process_direct_arrays_in_scope(self, scope, name, string_array):
for reference_node, reference_parent, reference_key, ref_index in binding.references[:]:
self._try_replace_array_access(reference_parent, reference_key, string_array)
- def _replace_node_in_ast(self, target, replacement):
+ def _replace_node_in_ast(self, target: dict, replacement: dict) -> None:
"""Replace a node in the AST with a replacement."""
- result = find_parent(self.ast, target)
+ result = self.find_parent(target)
if result:
parent, key, index = result
if index is not None:
parent[key][index] = replacement
else:
parent[key] = replacement
+ self.invalidate_parent_map()
# ================================================================
# Strategy 3: Simple static array unpacking
# ================================================================
- def _process_static_arrays(self):
+ def _process_static_arrays(self) -> None:
"""No-op: static array unpacking is handled by _process_direct_arrays."""
pass
diff --git a/pyjsclear/transforms/unreachable_code.py b/pyjsclear/transforms/unreachable_code.py
index 9cd32d4..cb5e4a9 100644
--- a/pyjsclear/transforms/unreachable_code.py
+++ b/pyjsclear/transforms/unreachable_code.py
@@ -11,14 +11,14 @@
class UnreachableCodeRemover(Transform):
"""Remove statements that follow a terminator (return/throw/break/continue) in a block."""
- def execute(self):
- def enter(node, parent, key, index):
- t = node.get('type')
- if t in ('BlockStatement', 'Program'):
+ def execute(self) -> bool:
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
+ node_type = node.get('type')
+ if node_type in ('BlockStatement', 'Program'):
body = node.get('body')
if body and isinstance(body, list):
self._truncate_after_terminator(body, node, 'body')
- elif t == 'SwitchCase':
+ elif node_type == 'SwitchCase':
consequent = node.get('consequent')
if consequent and isinstance(consequent, list):
self._truncate_after_terminator(consequent, node, 'consequent')
@@ -26,12 +26,13 @@ def enter(node, parent, key, index):
traverse(self.ast, {'enter': enter})
return self.has_changed()
- def _truncate_after_terminator(self, stmts, node, key):
- for i, stmt in enumerate(stmts):
- if not isinstance(stmt, dict):
+ def _truncate_after_terminator(self, statements: list, node: dict, key: str) -> None:
+ for statement_index, statement in enumerate(statements):
+ if not isinstance(statement, dict):
continue
- if stmt.get('type') in _TERMINATORS:
- if i + 1 < len(stmts):
- self.set_changed()
- node[key] = stmts[: i + 1]
- return
+ if statement.get('type') not in _TERMINATORS:
+ continue
+ if statement_index + 1 < len(statements):
+ self.set_changed()
+ node[key] = statements[: statement_index + 1]
+ return
diff --git a/pyjsclear/transforms/unused_vars.py b/pyjsclear/transforms/unused_vars.py
index 64ebbba..24622e0 100644
--- a/pyjsclear/transforms/unused_vars.py
+++ b/pyjsclear/transforms/unused_vars.py
@@ -33,17 +33,20 @@ class UnusedVariableRemover(Transform):
rebuild_scope = True
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
- declarators_to_remove = set()
- functions_to_remove = set()
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
+ declarators_to_remove: set[int] = set()
+ functions_to_remove: set[int] = set()
self._collect_unused(scope_tree, declarators_to_remove, functions_to_remove)
if not declarators_to_remove and not functions_to_remove:
return False
self._batch_remove(declarators_to_remove, functions_to_remove)
return self.has_changed()
- def _collect_unused(self, scope, declarators, functions):
+ def _collect_unused(self, scope: object, declarators: set[int], functions: set[int]) -> None:
skip_global = scope.parent is None
for name, binding in scope.bindings.items():
@@ -66,30 +69,31 @@ def _collect_unused(self, scope, declarators, functions):
for child in scope.children:
self._collect_unused(child, declarators, functions)
- def _batch_remove(self, declarators_to_remove, functions_to_remove):
+ def _batch_remove(self, declarators_to_remove: set[int], functions_to_remove: set[int]) -> None:
"""Remove all collected unused declarations in a single traversal."""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> object:
node_type = node.get('type')
if node_type == 'FunctionDeclaration' and id(node) in functions_to_remove:
self.set_changed()
return REMOVE
if node_type != 'VariableDeclaration':
- return
- decls = node.get('declarations')
- if not decls:
- return
- new_decls = [d for d in decls if id(d) not in declarators_to_remove]
- if len(new_decls) == len(decls):
- return
+ return None
+ declarations = node.get('declarations')
+ if not declarations:
+ return None
+ filtered_declarations = [declarator for declarator in declarations if id(declarator) not in declarators_to_remove]
+ if len(filtered_declarations) == len(declarations):
+ return None
self.set_changed()
- if not new_decls:
+ if not filtered_declarations:
return REMOVE
- node['declarations'] = new_decls
+ node['declarations'] = filtered_declarations
+ return None
traverse(self.ast, {'enter': enter})
- def _has_side_effects(self, node):
+ def _has_side_effects(self, node: dict) -> bool:
"""Conservative check for side effects in an expression."""
if not isinstance(node, dict):
return False
@@ -98,8 +102,7 @@ def _has_side_effects(self, node):
return True
if node_type in ('Literal', 'Identifier', 'ThisExpression', 'FunctionExpression', 'ArrowFunctionExpression'):
return False
- # For all other types (including ArrayExpression, ObjectExpression,
- # BinaryExpression, etc.), recurse into children to check for side effects
+ # Recurse into children (handles ArrayExpression, ObjectExpression, BinaryExpression, etc.)
for key in get_child_keys(node):
child = node.get(key)
if child is None:
diff --git a/pyjsclear/transforms/variable_renamer.py b/pyjsclear/transforms/variable_renamer.py
index a3ad71d..e85fc6b 100644
--- a/pyjsclear/transforms/variable_renamer.py
+++ b/pyjsclear/transforms/variable_renamer.py
@@ -135,38 +135,59 @@
'FormData': 'form',
}
+# fs-like methods
+_FS_METHODS = {
+ 'readFileSync',
+ 'writeFileSync',
+ 'existsSync',
+ 'mkdirSync',
+ 'statSync',
+ 'readdirSync',
+ 'unlinkSync',
+ 'createWriteStream',
+ 'createReadStream',
+ 'readFile',
+ 'writeFile',
+ 'appendFileSync',
+}
+
+# path-like methods
+_PATH_METHODS = {'join', 'resolve', 'basename', 'dirname', 'extname', 'normalize'}
+
+_ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
+
-def _name_generator(reserved):
+def _name_generator(reserved: set) -> object:
"""Yield short identifier names, skipping reserved and taken names."""
- for c in 'abcdefghijklmnopqrstuvwxyz':
- if c not in reserved:
- yield c
- for c1 in 'abcdefghijklmnopqrstuvwxyz':
- for c2 in 'abcdefghijklmnopqrstuvwxyz':
- name = c1 + c2
+ for char in _ALPHABET:
+ if char not in reserved:
+ yield char
+ for first_char in _ALPHABET:
+ for second_char in _ALPHABET:
+ name = first_char + second_char
if name not in reserved:
yield name
- for c1 in 'abcdefghijklmnopqrstuvwxyz':
- for c2 in 'abcdefghijklmnopqrstuvwxyz':
- for c3 in 'abcdefghijklmnopqrstuvwxyz':
- name = c1 + c2 + c3
+ for first_char in _ALPHABET:
+ for second_char in _ALPHABET:
+ for third_char in _ALPHABET:
+ name = first_char + second_char + third_char
if name not in reserved:
yield name
-def _dedupe_name(base, reserved):
+def _dedupe_name(base: str, reserved: set) -> str:
"""Return base or base2, base3, ... until a non-reserved name is found."""
if base not in reserved:
return base
- n = 2
+ counter = 2
while True:
- candidate = f'{base}{n}'
+ candidate = f'{base}{counter}'
if candidate not in reserved:
return candidate
- n += 1
+ counter += 1
-def _infer_from_init(init):
+def _infer_from_init(init: dict | None) -> str | None:
"""Infer a variable name from its initializer expression."""
if not isinstance(init, dict) or 'type' not in init:
return None
@@ -180,11 +201,11 @@ def _infer_from_init(init):
if is_identifier(callee) and callee.get('name') == 'require' and len(args) == 1:
arg = args[0]
if arg.get('type') == 'Literal' and isinstance(arg.get('value'), str):
- mod = arg['value']
- if mod in _REQUIRE_NAMES:
- return _REQUIRE_NAMES[mod]
+ module_name = arg['value']
+ if module_name in _REQUIRE_NAMES:
+ return _REQUIRE_NAMES[module_name]
# Derive name from module path, sanitized to valid identifier
- base = mod.split('/')[-1].split('\\')[-1]
+ base = module_name.split('/')[-1].split('\\')[-1]
base = base.split('.')[0] # strip file extension
base = re.sub(r'[^a-zA-Z0-9_]', '_', base)
base = re.sub(r'^[0-9]+', '', base) # can't start with digit
@@ -200,20 +221,19 @@ def _infer_from_init(init):
if is_identifier(obj) and is_identifier(prop):
obj_name = obj.get('name')
prop_name = prop.get('name')
- if obj_name == 'Buffer' and prop_name == 'from':
- return 'buf'
- if obj_name == 'JSON' and prop_name == 'parse':
- return 'data'
- if obj_name == 'JSON' and prop_name == 'stringify':
- return 'json'
- if obj_name == 'Object' and prop_name == 'keys':
- return 'keys'
- if obj_name == 'Object' and prop_name == 'values':
- return 'values'
- if obj_name == 'Object' and prop_name == 'entries':
- return 'entries'
- if obj_name == 'Object' and prop_name == 'getOwnPropertyNames':
- return 'keys'
+ match (obj_name, prop_name):
+ case ('Buffer', 'from'):
+ return 'buf'
+ case ('JSON', 'parse'):
+ return 'data'
+ case ('JSON', 'stringify'):
+ return 'json'
+ case ('Object', 'keys') | ('Object', 'getOwnPropertyNames'):
+ return 'keys'
+ case ('Object', 'values'):
+ return 'values'
+ case ('Object', 'entries'):
+ return 'entries'
# new Date() → "date"
if init_type == 'NewExpression':
@@ -226,34 +246,28 @@ def _infer_from_init(init):
if is_identifier(prop):
return _CONSTRUCTOR_NAMES.get(prop.get('name'))
- # [] → "arr"
- if init_type == 'ArrayExpression':
- return 'arr'
-
- # {} → "obj"
- if init_type == 'ObjectExpression':
- return 'obj'
-
- # "string" → "str"
- if init_type == 'Literal':
- val = init.get('value')
- if isinstance(val, str):
- return 'str'
- if isinstance(val, bool):
- return 'flag'
-
- # await expr → infer from the inner expression
- if init_type == 'AwaitExpression':
- return _infer_from_init(init.get('argument'))
+ match init_type:
+ case 'ArrayExpression':
+ return 'arr'
+ case 'ObjectExpression':
+ return 'obj'
+ case 'Literal':
+ value = init.get('value')
+ if isinstance(value, str):
+ return 'str'
+ if isinstance(value, bool):
+ return 'flag'
+ case 'AwaitExpression':
+ return _infer_from_init(init.get('argument'))
return None
-def _infer_from_usage(binding):
+def _infer_from_usage(binding: object) -> str | None:
"""Infer a variable name from how it's used at reference sites."""
# Check what methods are called on this variable
methods = set()
- for ref_node, ref_parent, ref_key, ref_index in binding.references:
+ for ref_node, ref_parent, ref_key, _ref_index in binding.references:
if not ref_parent:
continue
# x.method() — ref is object of MemberExpression
@@ -262,26 +276,9 @@ def _infer_from_usage(binding):
if is_identifier(prop) and not ref_parent.get('computed'):
methods.add(prop.get('name'))
- # fs-like methods
- _FS_METHODS = {
- 'readFileSync',
- 'writeFileSync',
- 'existsSync',
- 'mkdirSync',
- 'statSync',
- 'readdirSync',
- 'unlinkSync',
- 'createWriteStream',
- 'createReadStream',
- 'readFile',
- 'writeFile',
- 'appendFileSync',
- }
if methods & _FS_METHODS:
return 'fs'
- # path-like methods
- _PATH_METHODS = {'join', 'resolve', 'basename', 'dirname', 'extname', 'normalize'}
if methods & _PATH_METHODS and not (methods - _PATH_METHODS - {'sep'}):
return 'path'
@@ -308,7 +305,7 @@ def _infer_from_usage(binding):
return None
-def _infer_loop_var(binding):
+def _infer_loop_var(binding: object) -> bool | None:
"""Check if this binding is a for-loop counter."""
node = binding.node
if not isinstance(node, dict):
@@ -319,40 +316,40 @@ def _infer_loop_var(binding):
init = node.get('init')
if not init or init.get('type') != 'Literal':
return None
- val = init.get('value')
- if not isinstance(val, (int, float)):
+ value = init.get('value')
+ if not isinstance(value, (int, float)):
return None
# Check if any assignment is an UpdateExpression (i++, i--)
if binding.assignments:
- for assign in binding.assignments:
- if isinstance(assign, dict) and assign.get('type') == 'UpdateExpression':
+ for assignment in binding.assignments:
+ if isinstance(assignment, dict) and assignment.get('type') == 'UpdateExpression':
return True
# Also check references for UpdateExpression parents
- for ref_node, ref_parent, ref_key, ref_index in binding.references:
+ for _ref_node, ref_parent, _ref_key, _ref_index in binding.references:
if ref_parent and ref_parent.get('type') == 'UpdateExpression':
return True
return None
-def _collect_pattern_idents(pattern, result):
+def _collect_pattern_idents(pattern: dict | None, result: list) -> None:
"""Collect all Identifier nodes from a destructuring pattern."""
if not isinstance(pattern, dict):
return
- pat_type = pattern.get('type')
- if pat_type == 'Identifier':
+ pattern_type = pattern.get('type')
+ if pattern_type == 'Identifier':
result.append(pattern)
- elif pat_type == 'ArrayPattern':
- for elem in pattern.get('elements', []):
- if elem:
- _collect_pattern_idents(elem, result)
- elif pat_type == 'ObjectPattern':
+ elif pattern_type == 'ArrayPattern':
+ for element in pattern.get('elements', []):
+ if element:
+ _collect_pattern_idents(element, result)
+ elif pattern_type == 'ObjectPattern':
for prop in pattern.get('properties', []):
- val = prop.get('value', prop.get('argument'))
- if val:
- _collect_pattern_idents(val, result)
- elif pat_type == 'RestElement':
+ value = prop.get('value', prop.get('argument'))
+ if value:
+ _collect_pattern_idents(value, result)
+ elif pattern_type == 'RestElement':
_collect_pattern_idents(pattern.get('argument'), result)
- elif pat_type == 'AssignmentPattern':
+ elif pattern_type == 'AssignmentPattern':
_collect_pattern_idents(pattern.get('left'), result)
@@ -361,26 +358,29 @@ class VariableRenamer(Transform):
rebuild_scope = True
- def execute(self):
- scope_tree, _ = build_scope_tree(self.ast)
+ def execute(self) -> bool:
+ if self.scope_tree is not None:
+ scope_tree = self.scope_tree
+ else:
+ scope_tree, _ = build_scope_tree(self.ast)
# Collect all non-obfuscated names across the entire tree to avoid conflicts
reserved = set(_JS_RESERVED)
self._collect_reserved(scope_tree, reserved)
# Rename bindings scope by scope
- gen = _name_generator(reserved)
+ generator = _name_generator(reserved)
# Track loop var counter for i, j, k assignment
self._loop_letters = list('ijklmn')
self._loop_idx = 0
- self._rename_scope(scope_tree, gen, reserved)
+ self._rename_scope(scope_tree, generator, reserved)
# Fix duplicate names in destructuring patterns (can come from broken obfuscated input)
self._fix_destructuring_dupes(reserved)
return self.has_changed()
- def _collect_reserved(self, scope, reserved):
+ def _collect_reserved(self, scope: object, reserved: set) -> None:
"""Collect all non-_0x binding names so we never generate a conflict."""
for name in scope.bindings:
if not _OBF_RE.match(name):
@@ -388,21 +388,21 @@ def _collect_reserved(self, scope, reserved):
for child in scope.children:
self._collect_reserved(child, reserved)
- def _rename_scope(self, scope, gen, reserved):
+ def _rename_scope(self, scope: object, generator: object, reserved: set) -> None:
"""Rename all _0x bindings in this scope and its children."""
for name, binding in list(scope.bindings.items()):
if not _OBF_RE.match(name):
continue
- new_name = self._pick_name(binding, gen, reserved)
+ new_name = self._pick_name(binding, generator, reserved)
reserved.add(new_name)
self._apply_rename(binding, new_name)
self.set_changed()
for child in scope.children:
- self._rename_scope(child, gen, reserved)
+ self._rename_scope(child, generator, reserved)
- def _pick_name(self, binding, gen, reserved):
+ def _pick_name(self, binding: object, generator: object, reserved: set) -> str:
"""Pick the best name for a binding using heuristics, with fallback."""
# 1. Check if it's a loop counter → i, j, k
if _infer_loop_var(binding):
@@ -437,9 +437,9 @@ def _pick_name(self, binding, gen, reserved):
pass # Fall through to sequential
# 5. Fallback: sequential name from generator
- return next(gen)
+ return next(generator)
- def _apply_rename(self, binding, new_name):
+ def _apply_rename(self, binding: object, new_name: str) -> None:
"""Rename a binding at its declaration site and all reference sites."""
old_name = binding.name
@@ -469,14 +469,14 @@ def _apply_rename(self, binding, new_name):
arg['name'] = new_name
# 2. Rename at all reference sites
- for ref_node, ref_parent, ref_key, ref_index in binding.references:
+ for ref_node, _ref_parent, _ref_key, _ref_index in binding.references:
if ref_node.get('type') == 'Identifier' and ref_node.get('name') == old_name:
ref_node['name'] = new_name
# 3. Update binding.name
binding.name = new_name
- def _fix_destructuring_dupes(self, reserved):
+ def _fix_destructuring_dupes(self, reserved: set) -> None:
"""Fix duplicate identifier names in destructuring patterns.
Obfuscators sometimes produce invalid code like `const [a, a, a] = x;`.
@@ -487,15 +487,15 @@ def _fix_destructuring_dupes(self, reserved):
the last step of the renamer post-pass.
"""
- def enter(node, parent, key, index):
+ def enter(node: dict, parent: dict | None, key: str | None, index: int | None) -> None:
if node.get('type') != 'VariableDeclarator':
return
- pat = node.get('id')
- if not pat or pat.get('type') not in ('ArrayPattern', 'ObjectPattern'):
+ pattern = node.get('id')
+ if not pattern or pattern.get('type') not in ('ArrayPattern', 'ObjectPattern'):
return
# Collect all identifier nodes in the pattern
idents = []
- _collect_pattern_idents(pat, idents)
+ _collect_pattern_idents(pattern, idents)
seen = {}
for ident_node in idents:
name = ident_node.get('name')
diff --git a/pyjsclear/transforms/xor_string_decode.py b/pyjsclear/transforms/xor_string_decode.py
index 0dab5e8..2ddd2ba 100644
--- a/pyjsclear/transforms/xor_string_decode.py
+++ b/pyjsclear/transforms/xor_string_decode.py
@@ -26,23 +26,23 @@
from .base import Transform
-def _extract_numeric_array(node):
+def _extract_numeric_array(node: dict | None) -> list[int] | None:
"""Extract a list of integers from an ArrayExpression node."""
if not node or node.get('type') != 'ArrayExpression':
return None
elements = node.get('elements', [])
result = []
- for el in elements:
- if not is_numeric_literal(el):
+ for element in elements:
+ if not is_numeric_literal(element):
return None
- val = el['value']
- if not isinstance(val, (int, float)) or val != int(val) or val < 0 or val > 255:
+ value = element['value']
+ if not isinstance(value, (int, float)) or value != int(value) or value < 0 or value > 255:
return None
- result.append(int(val))
+ result.append(int(value))
return result
-def _xor_decode(byte_array, prefix_len=4):
+def _xor_decode(byte_array: list[int], prefix_len: int = 4) -> str | None:
"""Decode XOR-obfuscated byte array: prefix XOR'd against remaining data."""
if len(byte_array) <= prefix_len:
return None
@@ -56,7 +56,7 @@ def _xor_decode(byte_array, prefix_len=4):
return None
-def _is_xor_decoder_function(node):
+def _is_xor_decoder_function(node: dict | None) -> bool:
"""Heuristic: check if a function body contains XOR (^=) on array elements
and a slice/Buffer.from pattern typical of XOR string decoders."""
if not node:
@@ -69,17 +69,17 @@ def _is_xor_decoder_function(node):
has_slice = [False]
has_tostring = [False]
- def scan(n, parent):
- if not isinstance(n, dict):
+ def scan(ast_node: dict, parent: dict) -> None:
+ if not isinstance(ast_node, dict):
return
# Look for ^= operator
- if n.get('type') == 'AssignmentExpression' and n.get('operator') == '^=':
+ if ast_node.get('type') == 'AssignmentExpression' and ast_node.get('operator') == '^=':
has_xor[0] = True
# Look for .slice or .from
- if n.get('type') == 'MemberExpression':
- prop = n.get('property')
- if prop:
- name = prop.get('name') or (prop.get('value') if prop.get('type') == 'Literal' else None)
+ if ast_node.get('type') == 'MemberExpression':
+ property_node = ast_node.get('property')
+ if property_node:
+ name = property_node.get('name') or (property_node.get('value') if property_node.get('type') == 'Literal' else None)
if name in ('slice', 'from'):
has_slice[0] = True
if name in ('toString', 'decode'):
@@ -92,11 +92,11 @@ def scan(n, parent):
class XorStringDecoder(Transform):
"""Decode XOR-obfuscated string constants and inline them."""
- def execute(self):
+ def execute(self) -> bool:
# Phase 1: Find XOR decoder functions
- decoder_funcs = set()
+ decoder_funcs: set[str] = set()
- def find_decoders(node, parent):
+ def find_decoders(node: dict, parent: dict) -> None:
if node.get('type') not in ('FunctionDeclaration', 'FunctionExpression'):
return
params = node.get('params', [])
@@ -111,9 +111,9 @@ def find_decoders(node, parent):
if func_id and is_identifier(func_id):
decoder_funcs.add(func_id['name'])
elif parent and parent.get('type') == 'VariableDeclarator':
- decl_id = parent.get('id')
- if decl_id and is_identifier(decl_id):
- decoder_funcs.add(decl_id['name'])
+ declaration_id = parent.get('id')
+ if declaration_id and is_identifier(declaration_id):
+ decoder_funcs.add(declaration_id['name'])
simple_traverse(self.ast, find_decoders)
@@ -121,14 +121,14 @@ def find_decoders(node, parent):
return False
# Phase 2: Find calls like `var X = decoder([...bytes...])` and decode
- decoded_vars = {} # var_name → decoded_string
+ decoded_vars: dict[str, str] = {} # var_name → decoded_string
- def find_calls(node, parent):
+ def find_calls(node: dict, parent: dict) -> None:
if node.get('type') != 'VariableDeclarator':
return
- decl_id = node.get('id')
+ declaration_id = node.get('id')
init = node.get('init')
- if not is_identifier(decl_id) or not init:
+ if not is_identifier(declaration_id) or not init:
return
if init.get('type') != 'CallExpression':
return
@@ -143,7 +143,7 @@ def find_calls(node, parent):
return
decoded = _xor_decode(byte_array)
if decoded is not None:
- decoded_vars[decl_id['name']] = decoded
+ decoded_vars[declaration_id['name']] = decoded
simple_traverse(self.ast, find_calls)
@@ -152,12 +152,12 @@ def find_calls(node, parent):
# Phase 3: Replace computed member accesses obj[_0xVAR] → obj.decoded
# and standalone identifier refs with string literals
- def replace_refs(node, parent, key, index):
+ def replace_refs(node: dict, parent: dict, key: str, index: int | None) -> dict | None:
# Handle computed member: obj[_0xVAR] → obj.decoded or obj["decoded"]
if node.get('type') == 'MemberExpression' and node.get('computed'):
- prop = node.get('property')
- if is_identifier(prop) and prop['name'] in decoded_vars:
- decoded = decoded_vars[prop['name']]
+ property_node = node.get('property')
+ if is_identifier(property_node) and property_node['name'] in decoded_vars:
+ decoded = decoded_vars[property_node['name']]
if is_valid_identifier(decoded):
node['property'] = make_identifier(decoded)
node['computed'] = False
@@ -193,32 +193,32 @@ def replace_refs(node, parent, key, index):
return self.has_changed()
- def _remove_dead_declarations(self, decoded_vars):
+ def _remove_dead_declarations(self, decoded_vars: dict[str, str]) -> None:
"""Remove var X = decoder([...]) declarations that are now inlined."""
remaining_refs = {name: 0 for name in decoded_vars}
- def count(node, parent):
+ def count_refs(node: dict, parent: dict) -> None:
if is_identifier(node) and node['name'] in remaining_refs:
if parent and parent.get('type') == 'VariableDeclarator' and node is parent.get('id'):
return
remaining_refs[node['name']] = remaining_refs.get(node['name'], 0) + 1
- simple_traverse(self.ast, count)
+ simple_traverse(self.ast, count_refs)
- dead_vars = {name for name, count in remaining_refs.items() if count == 0}
+ dead_vars = {name for name, ref_count in remaining_refs.items() if ref_count == 0}
if not dead_vars:
return
- def remove_decls(node, parent, key, index):
+ def remove_decls(node: dict, parent: dict, key: str, index: int | None) -> dict | None:
if node.get('type') != 'VariableDeclaration':
return
decls = node.get('declarations', [])
remaining = []
- for d in decls:
- did = d.get('id')
- if is_identifier(did) and did['name'] in dead_vars:
+ for declaration in decls:
+ declaration_id = declaration.get('id')
+ if is_identifier(declaration_id) and declaration_id['name'] in dead_vars:
continue
- remaining.append(d)
+ remaining.append(declaration)
if len(remaining) == len(decls):
return
if not remaining:
diff --git a/pyjsclear/traverser.py b/pyjsclear/traverser.py
index a3337d6..583425a 100644
--- a/pyjsclear/traverser.py
+++ b/pyjsclear/traverser.py
@@ -1,5 +1,8 @@
"""ESTree AST traversal with visitor pattern."""
+from collections.abc import Callable
+from typing import Any
+
from .utils.ast_helpers import _CHILD_KEYS
from .utils.ast_helpers import get_child_keys
@@ -12,135 +15,368 @@
# Local aliases for hot-path performance (~15% faster traversal)
_dict = dict
_list = list
-_isinstance = isinstance
+_type = type
+# Maximum recursion depth before falling back to iterative traversal.
+# CPython default recursion limit is ~1000; we switch well before that.
+_MAX_RECURSIVE_DEPTH = 500
-def traverse(node, visitor):
- """Traverse an ESTree AST calling visitor callbacks.
+# Stack frame opcodes for iterative traverse
+_OP_ENTER = 0
+_OP_EXIT = 1
+_OP_LIST_START = 2
+_OP_LIST_RESUME = 3
- visitor should be a dict or object with optional 'enter' and 'exit' callables.
- Each callback receives (node, parent, key, index) and can return:
- - None: continue normally
- - REMOVE: remove this node from parent
- - SKIP: (enter only) skip traversing children
- - a dict (node): replace this node with the returned node
- """
- if _isinstance(visitor, _dict):
- enter_fn = visitor.get('enter')
- exit_fn = visitor.get('exit')
- else:
- enter_fn = getattr(visitor, 'enter', None)
- exit_fn = getattr(visitor, 'exit', None)
+def _traverse_iterative(node: dict, enter_fn: Callable | None, exit_fn: Callable | None) -> None:
+ """Iterative stack-based traverse. Handles both enter and exit callbacks."""
child_keys_map = _CHILD_KEYS
_REMOVE = REMOVE
_SKIP = SKIP
+ _get_child_keys = get_child_keys
- def _visit(current_node, parent, key, index):
- node_type = current_node.get('type')
- if node_type is None:
- return current_node
+ stack = [(_OP_ENTER, node, None, None, None)]
+ stack_pop = stack.pop
+ stack_append = stack.append
+
+ while stack:
+ frame = stack_pop()
+ op = frame[0]
+
+ if op == _OP_ENTER:
+ current_node = frame[1]
+ parent = frame[2]
+ key = frame[3]
+ index = frame[4]
+
+ node_type = current_node.get('type')
+ if node_type is None:
+ continue
- # Enter
- if enter_fn:
- result = enter_fn(current_node, parent, key, index)
+ if enter_fn:
+ result = enter_fn(current_node, parent, key, index)
+ if result is _REMOVE:
+ if parent is not None:
+ if index is not None:
+ parent[key].pop(index)
+ else:
+ parent[key] = None
+ continue
+ if result is _SKIP:
+ if exit_fn:
+ exit_result = exit_fn(current_node, parent, key, index)
+ if exit_result is _REMOVE:
+ if parent is not None:
+ if index is not None:
+ parent[key].pop(index)
+ else:
+ parent[key] = None
+ elif _type(exit_result) is _dict and 'type' in exit_result:
+ if parent is not None:
+ if index is not None:
+ parent[key][index] = exit_result
+ else:
+ parent[key] = exit_result
+ continue
+ if _type(result) is _dict and 'type' in result:
+ current_node = result
+ if parent is not None:
+ if index is not None:
+ parent[key][index] = current_node
+ else:
+ parent[key] = current_node
+ node_type = current_node.get('type')
+
+ if exit_fn:
+ stack_append((_OP_EXIT, current_node, parent, key, index))
+
+ child_keys = child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(current_node)
+
+ for index in range(len(child_keys) - 1, -1, -1):
+ child_key = child_keys[index]
+ child = current_node.get(child_key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ stack_append((_OP_LIST_START, current_node, child_key, 0, None))
+ elif _type(child) is _dict and 'type' in child:
+ stack_append((_OP_ENTER, child, current_node, child_key, None))
+
+ elif op == _OP_EXIT:
+ current_node = frame[1]
+ parent = frame[2]
+ key = frame[3]
+ index = frame[4]
+ result = exit_fn(current_node, parent, key, index)
if result is _REMOVE:
- return _REMOVE
- if result is _SKIP:
- if not exit_fn:
- return current_node
- exit_result = exit_fn(current_node, parent, key, index)
- if exit_result is _REMOVE:
- return _REMOVE
- if _isinstance(exit_result, _dict) and 'type' in exit_result:
- return exit_result
- return current_node
- if _isinstance(result, _dict) and 'type' in result:
- current_node = result
if parent is not None:
if index is not None:
- parent[key][index] = current_node
+ parent[key].pop(index)
+ else:
+ parent[key] = None
+ elif _type(result) is _dict and 'type' in result:
+ if parent is not None:
+ if index is not None:
+ parent[key][index] = result
else:
- parent[key] = current_node
+ parent[key] = result
+
+ elif op == _OP_LIST_START:
+ parent_node = frame[1]
+ child_key = frame[2]
+ idx = frame[3]
+ child_list = parent_node[child_key]
+ if idx >= len(child_list):
+ continue
+ item = child_list[idx]
+ if _type(item) is _dict and 'type' in item:
+ stack_append((_OP_LIST_RESUME, parent_node, child_key, idx, len(child_list)))
+ stack_append((_OP_ENTER, item, parent_node, child_key, idx))
+ else:
+ stack_append((_OP_LIST_START, parent_node, child_key, idx + 1, None))
+
+ else: # _OP_LIST_RESUME
+ parent_node = frame[1]
+ child_key = frame[2]
+ idx = frame[3]
+ pre_len = frame[4]
+ child_list = parent_node[child_key]
+ current_len = len(child_list)
+ if current_len < pre_len:
+ next_idx = idx
+ else:
+ next_idx = idx + 1
+ if next_idx < current_len:
+ stack_append((_OP_LIST_START, parent_node, child_key, next_idx, None))
+
+
+def _traverse_enter_only(node: dict, enter_fn: Callable) -> None:
+ """Recursive enter-only traverse with depth-limited fallback to iterative."""
+ child_keys_map = _CHILD_KEYS
+ _REMOVE = REMOVE
+ _SKIP = SKIP
+ _get_child_keys = get_child_keys
+ _max_depth = _MAX_RECURSIVE_DEPTH
+
+ def _visit(current_node: dict, parent: dict | None, key: str | None, index: int | None, depth: int) -> None:
+ node_type = current_node['type']
+ if node_type is None:
+ return
- # Visit children
- child_keys = child_keys_map.get(current_node.get('type'))
+ result = enter_fn(current_node, parent, key, index)
+ if result is _REMOVE:
+ if parent is not None:
+ if index is not None:
+ parent[key].pop(index)
+ else:
+ parent[key] = None
+ return
+ if result is _SKIP:
+ return
+ if _type(result) is _dict and 'type' in result:
+ current_node = result
+ if parent is not None:
+ if index is not None:
+ parent[key][index] = current_node
+ else:
+ parent[key] = current_node
+ node_type = current_node['type']
+
+ # Depth check: fall back to iterative for deep subtrees
+ if depth > _max_depth:
+ _traverse_iterative(current_node, enter_fn, None)
+ return
+
+ child_keys = child_keys_map.get(node_type)
if child_keys is None:
- child_keys = get_child_keys(current_node)
+ child_keys = _get_child_keys(current_node)
+
+ next_depth = depth + 1
for child_key in child_keys:
child = current_node.get(child_key)
if child is None:
continue
- if _isinstance(child, _list):
+ if _type(child) is _list:
+ child_len = len(child)
i = 0
- while i < len(child):
+ while i < child_len:
item = child[i]
- if _isinstance(item, _dict) and 'type' in item:
- result = _visit(item, current_node, child_key, i)
- if result is _REMOVE:
- child.pop(i)
+ if _type(item) is _dict and 'type' in item:
+ _visit(item, current_node, child_key, i, next_depth)
+ new_len = len(child)
+ if new_len < child_len:
+ child_len = new_len
continue
- elif result is not item:
- child[i] = result
+ child_len = new_len
i += 1
- elif _isinstance(child, _dict) and 'type' in child:
- result = _visit(child, current_node, child_key, None)
- if result is _REMOVE:
- current_node[child_key] = None
- elif result is not child:
- current_node[child_key] = result
+ elif _type(child) is _dict and 'type' in child:
+ _visit(child, current_node, child_key, None, next_depth)
- # Exit
- if exit_fn:
- result = exit_fn(current_node, parent, key, index)
- if result is _REMOVE:
- return _REMOVE
- if _isinstance(result, _dict) and 'type' in result:
- return result
+ if _type(node) is _dict and 'type' in node:
+ _visit(node, None, None, None, 0)
- return current_node
- _visit(node, None, None, None)
+def traverse(node: dict, visitor: dict | object) -> None:
+ """Traverse an ESTree AST calling visitor callbacks.
+ visitor should be a dict or object with optional 'enter' and 'exit' callables.
+ Each callback receives (node, parent, key, index) and can return:
+ - None: continue normally
+ - REMOVE: remove this node from parent
+ - SKIP: (enter only) skip traversing children
+ - a dict (node): replace this node with the returned node
-def simple_traverse(node, callback):
- """Simple traversal that calls callback(node, parent) for every node.
- No replacement support - just visiting.
+ Uses recursive traversal for enter-only visitors (fast path) with
+ automatic fallback to iterative for deep subtrees. Uses iterative
+ traversal when an exit callback is present.
"""
+ if isinstance(visitor, _dict):
+ enter_fn = visitor.get('enter')
+ exit_fn = visitor.get('exit')
+ else:
+ enter_fn = getattr(visitor, 'enter', None)
+ exit_fn = getattr(visitor, 'exit', None)
+
+ if exit_fn is None and enter_fn is not None:
+ _traverse_enter_only(node, enter_fn)
+ else:
+ _traverse_iterative(node, enter_fn, exit_fn)
+
+
+def _simple_traverse_iterative(node: dict, callback: Callable) -> None:
+ """Iterative stack-based simple traversal."""
+ child_keys_map = _CHILD_KEYS
+ _get_child_keys = get_child_keys
+
+ stack = [(node, None)]
+ stack_pop = stack.pop
+ stack_append = stack.append
+
+ while stack:
+ current_node, parent = stack_pop()
+ node_type = current_node['type']
+ if node_type is None:
+ continue
+ callback(current_node, parent)
+ child_keys = child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(current_node)
+ for key in reversed(child_keys):
+ child = current_node.get(key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for i in range(len(child) - 1, -1, -1):
+ item = child[i]
+ if _type(item) is _dict and 'type' in item:
+ stack_append((item, current_node))
+ elif _type(child) is _dict and 'type' in child:
+ stack_append((child, current_node))
+
+
+def _simple_traverse_recursive(node: dict, callback: Callable) -> None:
+ """Recursive simple traversal with depth-limited fallback to iterative."""
child_keys_map = _CHILD_KEYS
+ _get_child_keys = get_child_keys
+ _max_depth = _MAX_RECURSIVE_DEPTH
- def _visit(current_node, parent):
- node_type = current_node.get('type')
+ def _visit(current_node: dict, parent: dict | None, depth: int) -> None:
+ node_type = current_node['type']
if node_type is None:
return
callback(current_node, parent)
+
+ if depth > _max_depth:
+ # Fall back to iterative for this subtree's children
+ child_keys = child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(current_node)
+ for key in child_keys:
+ child = current_node.get(key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for item in child:
+ if _type(item) is _dict and 'type' in item:
+ _simple_traverse_iterative(item, callback)
+ elif _type(child) is _dict and 'type' in child:
+ _simple_traverse_iterative(child, callback)
+ return
+
child_keys = child_keys_map.get(node_type)
if child_keys is None:
- child_keys = get_child_keys(current_node)
+ child_keys = _get_child_keys(current_node)
+ next_depth = depth + 1
for key in child_keys:
child = current_node.get(key)
if child is None:
continue
- if _isinstance(child, _list):
+ if _type(child) is _list:
for item in child:
- if _isinstance(item, _dict) and 'type' in item:
- _visit(item, current_node)
- elif _isinstance(child, _dict) and 'type' in child:
- _visit(child, current_node)
+ if _type(item) is _dict and 'type' in item:
+ _visit(item, current_node, next_depth)
+ elif _type(child) is _dict and 'type' in child:
+ _visit(child, current_node, next_depth)
+
+ if _type(node) is _dict and 'type' in node:
+ _visit(node, None, 0)
+
+
+def simple_traverse(node: dict, callback: Callable) -> None:
+ """Simple traversal that calls callback(node, parent) for every node.
+ No replacement support - just visiting.
- _visit(node, None)
+ Uses recursive traversal with automatic fallback to iterative for deep subtrees.
+ """
+ _simple_traverse_recursive(node, callback)
-def collect_nodes(ast, node_type):
+def collect_nodes(ast: dict, node_type: str) -> list[dict]:
"""Collect all nodes of a given type."""
- result = []
+ collected = []
- def cb(node, parent):
+ def collect_callback(node: dict, parent: dict | None) -> None:
if node.get('type') == node_type:
- result.append(node)
+ collected.append(node)
+
+ simple_traverse(ast, collect_callback)
+ return collected
+
+
+def build_parent_map(ast: dict) -> dict:
+ """Build a map from id(node) -> (parent, key, index) for all nodes in the AST.
+
+ This allows O(1) parent lookups instead of O(n) find_parent() calls.
+ """
+ parent_map = {}
+ child_keys_map = _CHILD_KEYS
+ _get_child_keys = get_child_keys
+
+ stack = [(ast, None, None, None)]
+ while stack:
+ current_node, parent, key, index = stack.pop()
+ node_type = current_node['type']
+ if node_type is None:
+ continue
+ parent_map[id(current_node)] = (parent, key, index)
+ child_keys = child_keys_map.get(node_type)
+ if child_keys is None:
+ child_keys = _get_child_keys(current_node)
+ for child_key in child_keys:
+ child = current_node.get(child_key)
+ if child is None:
+ continue
+ if _type(child) is _list:
+ for i in range(len(child) - 1, -1, -1):
+ item = child[i]
+ if _type(item) is _dict and 'type' in item:
+ stack.append((item, current_node, child_key, i))
+ elif _type(child) is _dict and 'type' in child:
+ stack.append((child, current_node, child_key, None))
- simple_traverse(ast, cb)
- return result
+ return parent_map
class _FoundParent(Exception):
@@ -148,14 +384,17 @@ class _FoundParent(Exception):
__slots__ = ('value',)
- def __init__(self, value):
+ def __init__(self, value: tuple) -> None:
self.value = value
-def find_parent(ast, target_node):
- """Find the parent of a node in the AST. Returns (parent, key, index) or None."""
+def find_parent(ast: dict, target_node: dict) -> tuple | None:
+ """Find the parent of a node in the AST. Returns (parent, key, index) or None.
+
+ For multiple lookups, consider using build_parent_map() instead.
+ """
- def _visit(node):
+ def _visit(node: dict) -> None:
if not isinstance(node, dict) or 'type' not in node:
return
for child_key in get_child_keys(node):
@@ -163,9 +402,9 @@ def _visit(node):
if child is None:
continue
if isinstance(child, list):
- for i, item in enumerate(child):
+ for child_index, item in enumerate(child):
if item is target_node:
- raise _FoundParent((node, child_key, i))
+ raise _FoundParent((node, child_key, child_index))
_visit(item)
elif isinstance(child, dict):
if child is target_node:
@@ -174,12 +413,12 @@ def _visit(node):
try:
_visit(ast)
- except _FoundParent as found:
- return found.value
+ except _FoundParent as found_parent:
+ return found_parent.value
return None
-def replace_in_parent(parent, key, index, new_node):
+def replace_in_parent(parent: dict, key: str, index: int | None, new_node: dict) -> None:
"""Replace a node within its parent."""
if index is not None:
parent[key][index] = new_node
@@ -187,7 +426,7 @@ def replace_in_parent(parent, key, index, new_node):
parent[key] = new_node
-def remove_from_parent(parent, key, index):
+def remove_from_parent(parent: dict, key: str, index: int | None) -> None:
"""Remove a node from its parent."""
if index is not None:
parent[key].pop(index)
diff --git a/pyjsclear/utils/ast_helpers.py b/pyjsclear/utils/ast_helpers.py
index cb4e2b7..4c1d94b 100644
--- a/pyjsclear/utils/ast_helpers.py
+++ b/pyjsclear/utils/ast_helpers.py
@@ -4,42 +4,42 @@
import re
-def deep_copy(node):
+def deep_copy(node: dict) -> dict:
"""Deep copy an AST node."""
return copy.deepcopy(node)
-def is_literal(node):
+def is_literal(node: object) -> bool:
"""Check if node is a Literal."""
return isinstance(node, dict) and node.get('type') == 'Literal'
-def is_identifier(node):
+def is_identifier(node: object) -> bool:
"""Check if node is an Identifier."""
return isinstance(node, dict) and node.get('type') == 'Identifier'
-def is_string_literal(node):
+def is_string_literal(node: object) -> bool:
"""Check if node is a string Literal."""
return is_literal(node) and isinstance(node.get('value'), str)
-def is_numeric_literal(node):
+def is_numeric_literal(node: object) -> bool:
"""Check if node is a numeric Literal."""
return is_literal(node) and isinstance(node.get('value'), (int, float))
-def is_boolean_literal(node):
+def is_boolean_literal(node: object) -> bool:
"""Check if node is a boolean-ish literal (true/false or !0/!1)."""
return is_literal(node) and isinstance(node.get('value'), bool)
-def is_null_literal(node):
+def is_null_literal(node: object) -> bool:
"""Check if node is null literal."""
return is_literal(node) and node.get('value') is None and node.get('raw') == 'null'
-def is_undefined(node):
+def is_undefined(node: object) -> bool:
"""Check if node represents undefined (identifier or ``void 0``)."""
if is_identifier(node) and node.get('name') == 'undefined':
return True
@@ -55,14 +55,14 @@ def is_undefined(node):
return False
-def get_literal_value(node):
+def get_literal_value(node: object) -> tuple:
"""Extract the value from a literal node. Returns (value, True) or (None, False)."""
if not is_literal(node):
return None, False
return node.get('value'), True
-def make_literal(value, raw=None):
+def make_literal(value: object, raw: str | None = None) -> dict:
"""Create a Literal AST node."""
if raw is not None:
return {'type': 'Literal', 'value': value, 'raw': raw}
@@ -90,22 +90,22 @@ def make_literal(value, raw=None):
return {'type': 'Literal', 'value': value, 'raw': raw}
-def make_identifier(name):
+def make_identifier(name: str) -> dict:
"""Create an Identifier AST node."""
return {'type': 'Identifier', 'name': name}
-def make_expression_statement(expr):
+def make_expression_statement(expr: dict) -> dict:
"""Wrap an expression in an ExpressionStatement."""
return {'type': 'ExpressionStatement', 'expression': expr}
-def make_block_statement(body):
+def make_block_statement(body: list) -> dict:
"""Create a BlockStatement."""
return {'type': 'BlockStatement', 'body': body}
-def make_var_declaration(name, init=None, kind='var'):
+def make_var_declaration(name: str, init: dict | None = None, kind: str = 'var') -> dict:
"""Create a VariableDeclaration with a single declarator."""
return {
'type': 'VariableDeclaration',
@@ -117,7 +117,7 @@ def make_var_declaration(name, init=None, kind='var'):
_IDENT_RE = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$]*$')
-def is_valid_identifier(name):
+def is_valid_identifier(name: object) -> bool:
"""Check if a string is a valid JS identifier (for obj.prop access)."""
if not isinstance(name, str) or not name:
return False
@@ -207,7 +207,7 @@ def is_valid_identifier(name):
)
-def get_child_keys(node):
+def get_child_keys(node: object) -> tuple | list:
"""Get keys of a node that may contain child nodes/arrays."""
if not isinstance(node, dict) or 'type' not in node:
return ()
@@ -217,15 +217,15 @@ def get_child_keys(node):
return keys
# Fallback: return all keys that look like they might contain nodes
return [
- k
- for k, v in node.items()
- if k not in _SKIP_KEYS
- and not (k == 'expression' and node_type != 'ExpressionStatement')
- and isinstance(v, (dict, list))
+ key
+ for key, value in node.items()
+ if key not in _SKIP_KEYS
+ and not (key == 'expression' and node_type != 'ExpressionStatement')
+ and isinstance(value, (dict, list))
]
-def replace_identifiers(node, param_map):
+def replace_identifiers(node: dict, param_map: dict) -> None:
"""Replace Identifier nodes whose names are in param_map with deep copies.
Skips non-computed property names in MemberExpressions.
@@ -238,10 +238,10 @@ def replace_identifiers(node, param_map):
continue
is_noncomputed_prop = key == 'property' and node.get('type') == 'MemberExpression' and not node.get('computed')
if isinstance(child, list):
- for i, item in enumerate(child):
+ for index, item in enumerate(child):
if isinstance(item, dict) and item.get('type') == 'Identifier':
if not is_noncomputed_prop and item.get('name', '') in param_map:
- child[i] = copy.deepcopy(param_map[item['name']])
+ child[index] = copy.deepcopy(param_map[item['name']])
elif isinstance(item, dict) and 'type' in item:
replace_identifiers(item, param_map)
elif isinstance(child, dict):
@@ -252,12 +252,12 @@ def replace_identifiers(node, param_map):
replace_identifiers(child, param_map)
-def identifiers_match(a, b):
+def identifiers_match(node_a: object, node_b: object) -> bool:
"""Check if two nodes are the same identifier."""
- return is_identifier(a) and is_identifier(b) and a.get('name') == b.get('name')
+ return is_identifier(node_a) and is_identifier(node_b) and node_a.get('name') == node_b.get('name')
-def is_side_effect_free(node):
+def is_side_effect_free(node: object) -> bool:
"""Check if an expression node is side-effect-free (safe to discard)."""
if not isinstance(node, dict):
return False
@@ -288,7 +288,7 @@ def is_side_effect_free(node):
return False
-def get_member_names(node):
+def get_member_names(node: object) -> tuple[str, str] | tuple[None, None]:
"""Extract (object_name, property_name) from a MemberExpression.
Handles both computed (obj["prop"]) and non-computed (obj.prop) forms.
@@ -311,17 +311,17 @@ def get_member_names(node):
return None, None
-def nodes_equal(a, b):
+def nodes_equal(node_a: object, node_b: object) -> bool:
"""Check if two AST nodes are structurally equal (ignoring position info)."""
- if type(a) != type(b):
+ if type(node_a) != type(node_b):
return False
- match a:
+ match node_a:
case dict():
- keys_a = {k for k in a if k not in ('start', 'end', 'loc', 'range')}
- keys_b = {k for k in b if k not in ('start', 'end', 'loc', 'range')}
+ keys_a = {key for key in node_a if key not in ('start', 'end', 'loc', 'range')}
+ keys_b = {key for key in node_b if key not in ('start', 'end', 'loc', 'range')}
if keys_a != keys_b:
return False
- return all(nodes_equal(a[k], b[k]) for k in keys_a)
+ return all(nodes_equal(node_a[key], node_b[key]) for key in keys_a)
case list():
- return len(a) == len(b) and all(nodes_equal(x, y) for x, y in zip(a, b))
- return a == b
+ return len(node_a) == len(node_b) and all(nodes_equal(x, y) for x, y in zip(node_a, node_b))
+ return node_a == node_b
diff --git a/pyjsclear/utils/string_decoders.py b/pyjsclear/utils/string_decoders.py
index cb2b7fa..bbd428c 100644
--- a/pyjsclear/utils/string_decoders.py
+++ b/pyjsclear/utils/string_decoders.py
@@ -1,10 +1,9 @@
"""String decoder implementations for obfuscator.io patterns."""
-from enum import Enum
-from urllib.parse import unquote
+from enum import StrEnum
-class DecoderType(Enum):
+class DecoderType(StrEnum):
BASIC = 'basic'
BASE_64 = 'base64'
RC4 = 'rc4'
@@ -13,7 +12,7 @@ class DecoderType(Enum):
_BASE_64_ALPHABET = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/='
-def base64_transform(encoded_string):
+def base64_transform(encoded_string: str) -> str:
"""Decode obfuscator.io's custom base64 encoding."""
# Decode 4 base64 chars into 3 bytes using 6-bit groups.
# bit_buffer accumulates bits; every non-first char in a group yields a byte
@@ -39,19 +38,19 @@ def base64_transform(encoded_string):
class StringDecoder:
"""Base string decoder."""
- def __init__(self, string_array, index_offset):
+ def __init__(self, string_array: list[str], index_offset: int) -> None:
self.string_array = string_array
self.index_offset = index_offset
self.is_first_call = True
@property
- def type(self):
+ def type(self) -> DecoderType:
return DecoderType.BASIC
- def get_string(self, index, *args):
+ def get_string(self, index: int, *args) -> str | None:
raise NotImplementedError
- def get_string_for_rotation(self, index, *args, **kwargs):
+ def get_string_for_rotation(self, index: int, *args, **kwargs) -> str | None:
if self.is_first_call:
self.is_first_call = False
raise RuntimeError('First call')
@@ -62,10 +61,10 @@ class BasicStringDecoder(StringDecoder):
"""Simple array index + offset decoder."""
@property
- def type(self):
+ def type(self) -> DecoderType:
return DecoderType.BASIC
- def get_string(self, index, *args):
+ def get_string(self, index: int, *args) -> str | None:
array_index = index + self.index_offset
if 0 <= array_index < len(self.string_array):
return self.string_array[array_index]
@@ -75,15 +74,15 @@ def get_string(self, index, *args):
class Base64StringDecoder(StringDecoder):
"""Base64 string decoder."""
- def __init__(self, string_array, index_offset):
+ def __init__(self, string_array: list[str], index_offset: int) -> None:
super().__init__(string_array, index_offset)
- self._cache = {}
+ self._cache: dict[int, str] = {}
@property
- def type(self):
+ def type(self) -> DecoderType:
return DecoderType.BASE_64
- def get_string(self, index, *args):
+ def get_string(self, index: int, *args) -> str | None:
if index in self._cache:
return self._cache[index]
array_index = index + self.index_offset
@@ -97,15 +96,15 @@ def get_string(self, index, *args):
class Rc4StringDecoder(StringDecoder):
"""RC4 string decoder."""
- def __init__(self, string_array, index_offset):
+ def __init__(self, string_array: list[str], index_offset: int) -> None:
super().__init__(string_array, index_offset)
- self._cache = {}
+ self._cache: dict[tuple[int, str], str] = {}
@property
- def type(self):
+ def type(self) -> DecoderType:
return DecoderType.RC4
- def get_string(self, index, key=None):
+ def get_string(self, index: int, key: str | None = None) -> str | None:
if not key:
return None
# Include key in cache to avoid collisions with different RC4 keys
@@ -120,7 +119,7 @@ def get_string(self, index, key=None):
self._cache[cache_key] = decoded
return decoded
- def _rc4_decode(self, encoded_string, key):
+ def _rc4_decode(self, encoded_string: str, key: str) -> str:
"""RC4 decryption with base64 pre-processing."""
encoded_string = base64_transform(encoded_string)
# KSA
diff --git a/tests/conftest.py b/tests/conftest.py
index fed2159..7b08d41 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,9 @@
"""Root test configuration."""
+import pytest
-def pytest_addoption(parser):
+
+def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
'--update-snapshots',
action='store_true',
diff --git a/tests/fuzz/conftest_fuzz.py b/tests/fuzz/conftest_fuzz.py
index 99553d7..2ce846f 100644
--- a/tests/fuzz/conftest_fuzz.py
+++ b/tests/fuzz/conftest_fuzz.py
@@ -1,12 +1,14 @@
"""Shared helpers for fuzz targets."""
+import argparse
import os
import random
-import struct
import sys
+import time
+from typing import Any, Callable
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
MAX_INPUT_SIZE = 102_400 # 100KB cap
@@ -14,72 +16,71 @@
# ESTree node types for synthetic AST generation
_STATEMENT_TYPES = [
- "ExpressionStatement",
- "VariableDeclaration",
- "ReturnStatement",
- "IfStatement",
- "WhileStatement",
- "ForStatement",
- "BlockStatement",
- "EmptyStatement",
- "BreakStatement",
- "ContinueStatement",
- "ThrowStatement",
+ 'ExpressionStatement',
+ 'VariableDeclaration',
+ 'ReturnStatement',
+ 'IfStatement',
+ 'WhileStatement',
+ 'ForStatement',
+ 'BlockStatement',
+ 'EmptyStatement',
+ 'BreakStatement',
+ 'ContinueStatement',
+ 'ThrowStatement',
]
_EXPRESSION_TYPES = [
- "Literal",
- "Identifier",
- "BinaryExpression",
- "UnaryExpression",
- "CallExpression",
- "MemberExpression",
- "AssignmentExpression",
- "ConditionalExpression",
- "ArrayExpression",
- "ObjectExpression",
- "FunctionExpression",
- "ThisExpression",
+ 'Literal',
+ 'Identifier',
+ 'BinaryExpression',
+ 'UnaryExpression',
+ 'CallExpression',
+ 'MemberExpression',
+ 'AssignmentExpression',
+ 'ConditionalExpression',
+ 'ArrayExpression',
+ 'ObjectExpression',
+ 'FunctionExpression',
+ 'ThisExpression',
]
_BINARY_OPS = [
- "+",
- "-",
- "*",
- "/",
- "%",
- "==",
- "!=",
- "===",
- "!==",
- "<",
- ">",
- "<=",
- ">=",
- "&&",
- "||",
- "&",
- "|",
- "^",
- "<<",
- ">>",
- ">>>",
+ '+',
+ '-',
+ '*',
+ '/',
+ '%',
+ '==',
+ '!=',
+ '===',
+ '!==',
+ '<',
+ '>',
+ '<=',
+ '>=',
+ '&&',
+ '||',
+ '&',
+ '|',
+ '^',
+ '<<',
+ '>>',
+ '>>>',
]
-_UNARY_OPS = ["-", "+", "!", "~", "typeof", "void"]
+_UNARY_OPS = ['-', '+', '!', '~', 'typeof', 'void']
-def bytes_to_js(data):
+def bytes_to_js(data: bytes) -> str:
"""Decode bytes to a JS string with size limit."""
- text = data[:MAX_INPUT_SIZE].decode("utf-8", errors="replace")
- return text
+ return data[:MAX_INPUT_SIZE].decode('utf-8', errors='replace')
-def bytes_to_ast_dict(data, max_depth=5, max_children=4):
+def bytes_to_ast_dict(data: bytes, max_depth: int = 5, max_children: int = 4) -> dict:
"""Build a synthetic ESTree AST dict from bytes for testing generator/traverser."""
- rng = random.Random(int.from_bytes(data[:8].ljust(8, b"\x00"), "little"))
+ rng = random.Random(int.from_bytes(data[:8].ljust(8, b'\x00'), 'little'))
pos = 8
- def consume_byte():
+ def consume_byte() -> int:
nonlocal pos
if pos < len(data):
val = data[pos]
@@ -87,214 +88,218 @@ def consume_byte():
return val
return rng.randint(0, 255)
- def make_literal():
- kind = consume_byte() % 4
- if kind == 0:
- return {"type": "Literal", "value": consume_byte(), "raw": str(consume_byte())}
- elif kind == 1:
- return {"type": "Literal", "value": True, "raw": "true"}
- elif kind == 2:
- return {"type": "Literal", "value": None, "raw": "null"}
- else:
- return {"type": "Literal", "value": "fuzz", "raw": '"fuzz"'}
-
- def make_identifier():
- names = ["a", "b", "c", "x", "y", "foo", "bar", "_", "$"]
- return {"type": "Identifier", "name": names[consume_byte() % len(names)]}
-
- def make_node(depth=0):
+ def make_literal() -> dict:
+ match consume_byte() % 4:
+ case 0:
+ return {'type': 'Literal', 'value': consume_byte(), 'raw': str(consume_byte())}
+ case 1:
+ return {'type': 'Literal', 'value': True, 'raw': 'true'}
+ case 2:
+ return {'type': 'Literal', 'value': None, 'raw': 'null'}
+ case _:
+ return {'type': 'Literal', 'value': 'fuzz', 'raw': '"fuzz"'}
+
+ def make_identifier() -> dict:
+ names = ['a', 'b', 'c', 'x', 'y', 'foo', 'bar', '_', '$']
+ return {'type': 'Identifier', 'name': names[consume_byte() % len(names)]}
+
+ def make_node(depth: int = 0) -> dict:
if depth >= max_depth:
return make_literal() if consume_byte() % 2 == 0 else make_identifier()
type_idx = consume_byte()
- if type_idx % 3 == 0:
- # Expression
- expr_type = _EXPRESSION_TYPES[consume_byte() % len(_EXPRESSION_TYPES)]
- if expr_type == "Literal":
+ if type_idx % 3 != 0:
+ return make_statement(depth)
+
+ expr_type = _EXPRESSION_TYPES[consume_byte() % len(_EXPRESSION_TYPES)]
+ match expr_type:
+ case 'Literal':
return make_literal()
- elif expr_type == "Identifier":
+ case 'Identifier':
return make_identifier()
- elif expr_type == "BinaryExpression":
+ case 'BinaryExpression':
return {
- "type": "BinaryExpression",
- "operator": _BINARY_OPS[consume_byte() % len(_BINARY_OPS)],
- "left": make_node(depth + 1),
- "right": make_node(depth + 1),
+ 'type': 'BinaryExpression',
+ 'operator': _BINARY_OPS[consume_byte() % len(_BINARY_OPS)],
+ 'left': make_node(depth + 1),
+ 'right': make_node(depth + 1),
}
- elif expr_type == "UnaryExpression":
+ case 'UnaryExpression':
return {
- "type": "UnaryExpression",
- "operator": _UNARY_OPS[consume_byte() % len(_UNARY_OPS)],
- "argument": make_node(depth + 1),
- "prefix": True,
+ 'type': 'UnaryExpression',
+ 'operator': _UNARY_OPS[consume_byte() % len(_UNARY_OPS)],
+ 'argument': make_node(depth + 1),
+ 'prefix': True,
}
- elif expr_type == "CallExpression":
+ case 'CallExpression':
num_args = consume_byte() % max_children
return {
- "type": "CallExpression",
- "callee": make_node(depth + 1),
- "arguments": [make_node(depth + 1) for _ in range(num_args)],
+ 'type': 'CallExpression',
+ 'callee': make_node(depth + 1),
+ 'arguments': [make_node(depth + 1) for _ in range(num_args)],
}
- elif expr_type == "MemberExpression":
+ case 'MemberExpression':
computed = consume_byte() % 2 == 0
return {
- "type": "MemberExpression",
- "object": make_node(depth + 1),
- "property": make_node(depth + 1),
- "computed": computed,
+ 'type': 'MemberExpression',
+ 'object': make_node(depth + 1),
+ 'property': make_node(depth + 1),
+ 'computed': computed,
}
- elif expr_type == "AssignmentExpression":
+ case 'AssignmentExpression':
return {
- "type": "AssignmentExpression",
- "operator": "=",
- "left": make_identifier(),
- "right": make_node(depth + 1),
+ 'type': 'AssignmentExpression',
+ 'operator': '=',
+ 'left': make_identifier(),
+ 'right': make_node(depth + 1),
}
- elif expr_type == "ConditionalExpression":
+ case 'ConditionalExpression':
return {
- "type": "ConditionalExpression",
- "test": make_node(depth + 1),
- "consequent": make_node(depth + 1),
- "alternate": make_node(depth + 1),
+ 'type': 'ConditionalExpression',
+ 'test': make_node(depth + 1),
+ 'consequent': make_node(depth + 1),
+ 'alternate': make_node(depth + 1),
}
- elif expr_type == "ArrayExpression":
+ case 'ArrayExpression':
num = consume_byte() % max_children
return {
- "type": "ArrayExpression",
- "elements": [make_node(depth + 1) for _ in range(num)],
+ 'type': 'ArrayExpression',
+ 'elements': [make_node(depth + 1) for _ in range(num)],
}
- elif expr_type == "ObjectExpression":
+ case 'ObjectExpression':
num = consume_byte() % max_children
return {
- "type": "ObjectExpression",
- "properties": [
+ 'type': 'ObjectExpression',
+ 'properties': [
{
- "type": "Property",
- "key": make_identifier(),
- "value": make_node(depth + 1),
- "kind": "init",
- "computed": False,
- "method": False,
- "shorthand": False,
+ 'type': 'Property',
+ 'key': make_identifier(),
+ 'value': make_node(depth + 1),
+ 'kind': 'init',
+ 'computed': False,
+ 'method': False,
+ 'shorthand': False,
}
for _ in range(num)
],
}
- elif expr_type == "FunctionExpression":
+ case 'FunctionExpression':
return {
- "type": "FunctionExpression",
- "id": None,
- "params": [],
- "body": {
- "type": "BlockStatement",
- "body": [make_statement(depth + 1) for _ in range(consume_byte() % 3)],
+ 'type': 'FunctionExpression',
+ 'id': None,
+ 'params': [],
+ 'body': {
+ 'type': 'BlockStatement',
+ 'body': [make_statement(depth + 1) for _ in range(consume_byte() % 3)],
},
- "generator": False,
- "async": False,
+ 'generator': False,
+ 'async': False,
}
- else:
- return {"type": "ThisExpression"}
- else:
- return make_statement(depth)
+ case _:
+ return {'type': 'ThisExpression'}
- def make_statement(depth=0):
+ def make_statement(depth: int = 0) -> dict:
if depth >= max_depth:
- return {
- "type": "ExpressionStatement",
- "expression": make_literal(),
- }
+ return {'type': 'ExpressionStatement', 'expression': make_literal()}
stmt_type = _STATEMENT_TYPES[consume_byte() % len(_STATEMENT_TYPES)]
- if stmt_type == "ExpressionStatement":
- return {"type": "ExpressionStatement", "expression": make_node(depth + 1)}
- elif stmt_type == "VariableDeclaration":
- return {
- "type": "VariableDeclaration",
- "declarations": [
- {
- "type": "VariableDeclarator",
- "id": make_identifier(),
- "init": make_node(depth + 1) if consume_byte() % 2 == 0 else None,
- }
- ],
- "kind": ["var", "let", "const"][consume_byte() % 3],
- }
- elif stmt_type == "ReturnStatement":
- return {"type": "ReturnStatement", "argument": make_node(depth + 1) if consume_byte() % 2 == 0 else None}
- elif stmt_type == "IfStatement":
- return {
- "type": "IfStatement",
- "test": make_node(depth + 1),
- "consequent": {"type": "BlockStatement", "body": [make_statement(depth + 1)]},
- "alternate": (
- {"type": "BlockStatement", "body": [make_statement(depth + 1)]} if consume_byte() % 2 == 0 else None
- ),
- }
- elif stmt_type == "WhileStatement":
- return {
- "type": "WhileStatement",
- "test": make_node(depth + 1),
- "body": {"type": "BlockStatement", "body": [make_statement(depth + 1)]},
- }
- elif stmt_type == "ForStatement":
- return {
- "type": "ForStatement",
- "init": None,
- "test": make_node(depth + 1),
- "update": None,
- "body": {"type": "BlockStatement", "body": [make_statement(depth + 1)]},
- }
- elif stmt_type == "BlockStatement":
- num = consume_byte() % max_children
- return {"type": "BlockStatement", "body": [make_statement(depth + 1) for _ in range(num)]}
- elif stmt_type == "EmptyStatement":
- return {"type": "EmptyStatement"}
- elif stmt_type == "BreakStatement":
- return {"type": "BreakStatement", "label": None}
- elif stmt_type == "ContinueStatement":
- return {"type": "ContinueStatement", "label": None}
- elif stmt_type == "ThrowStatement":
- return {"type": "ThrowStatement", "argument": make_node(depth + 1)}
- return {"type": "EmptyStatement"}
+ match stmt_type:
+ case 'ExpressionStatement':
+ return {'type': 'ExpressionStatement', 'expression': make_node(depth + 1)}
+ case 'VariableDeclaration':
+ return {
+ 'type': 'VariableDeclaration',
+ 'declarations': [
+ {
+ 'type': 'VariableDeclarator',
+ 'id': make_identifier(),
+ 'init': make_node(depth + 1) if consume_byte() % 2 == 0 else None,
+ }
+ ],
+ 'kind': ['var', 'let', 'const'][consume_byte() % 3],
+ }
+ case 'ReturnStatement':
+ return {
+ 'type': 'ReturnStatement',
+ 'argument': make_node(depth + 1) if consume_byte() % 2 == 0 else None,
+ }
+ case 'IfStatement':
+ return {
+ 'type': 'IfStatement',
+ 'test': make_node(depth + 1),
+ 'consequent': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]},
+ 'alternate': (
+ {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]}
+ if consume_byte() % 2 == 0
+ else None
+ ),
+ }
+ case 'WhileStatement':
+ return {
+ 'type': 'WhileStatement',
+ 'test': make_node(depth + 1),
+ 'body': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]},
+ }
+ case 'ForStatement':
+ return {
+ 'type': 'ForStatement',
+ 'init': None,
+ 'test': make_node(depth + 1),
+ 'update': None,
+ 'body': {'type': 'BlockStatement', 'body': [make_statement(depth + 1)]},
+ }
+ case 'BlockStatement':
+ num = consume_byte() % max_children
+ return {'type': 'BlockStatement', 'body': [make_statement(depth + 1) for _ in range(num)]}
+ case 'EmptyStatement':
+ return {'type': 'EmptyStatement'}
+ case 'BreakStatement':
+ return {'type': 'BreakStatement', 'label': None}
+ case 'ContinueStatement':
+ return {'type': 'ContinueStatement', 'label': None}
+ case 'ThrowStatement':
+ return {'type': 'ThrowStatement', 'argument': make_node(depth + 1)}
+ case _:
+ return {'type': 'EmptyStatement'}
num_stmts = max(1, consume_byte() % 6)
return {
- "type": "Program",
- "body": [make_statement(0) for _ in range(num_stmts)],
- "sourceType": "script",
+ 'type': 'Program',
+ 'body': [make_statement(0) for _ in range(num_stmts)],
+ 'sourceType': 'script',
}
class SimpleFuzzedDataProvider:
"""Minimal FuzzedDataProvider for when atheris is not available."""
- def __init__(self, data):
+ def __init__(self, data: bytes) -> None:
self._data = data
self._pos = 0
- def ConsumeUnicode(self, max_length):
+ def ConsumeUnicode(self, max_length: int) -> str:
end = min(self._pos + max_length, len(self._data))
- chunk = self._data[self._pos : end]
+ chunk = self._data[self._pos:end]
self._pos = end
- return chunk.decode("utf-8", errors="replace")
+ return chunk.decode('utf-8', errors='replace')
- def ConsumeBytes(self, max_length):
+ def ConsumeBytes(self, max_length: int) -> bytes:
end = min(self._pos + max_length, len(self._data))
- chunk = self._data[self._pos : end]
+ chunk = self._data[self._pos:end]
self._pos = end
return chunk
- def ConsumeIntInRange(self, min_val, max_val):
+ def ConsumeIntInRange(self, min_val: int, max_val: int) -> int:
if self._pos < len(self._data):
val = self._data[self._pos]
self._pos += 1
return min_val + (val % (max_val - min_val + 1))
return min_val
- def ConsumeBool(self):
+ def ConsumeBool(self) -> bool:
return self.ConsumeIntInRange(0, 1) == 1
- def remaining_bytes(self):
+ def remaining_bytes(self) -> int:
return len(self._data) - self._pos
@@ -307,7 +312,11 @@ def remaining_bytes(self):
FuzzedDataProvider = SimpleFuzzedDataProvider
-def run_fuzzer(target_fn, argv=None, custom_setup=None):
+def run_fuzzer(
+ target_fn: Callable[[bytes], None],
+ argv: list[str] | None = None,
+ custom_setup: Callable | None = None,
+) -> None:
"""Run a fuzz target with atheris if available, otherwise with random inputs."""
if atheris is not None:
if custom_setup:
@@ -315,50 +324,46 @@ def run_fuzzer(target_fn, argv=None, custom_setup=None):
custom_setup()
atheris.Setup(argv or sys.argv, target_fn)
atheris.Fuzz()
- else:
- # Standalone random-based fuzzing
- import argparse
-
- parser = argparse.ArgumentParser()
- parser.add_argument("corpus_dirs", nargs="*", default=[])
- parser.add_argument("-max_total_time", type=int, default=10)
- parser.add_argument("-max_len", type=int, default=MAX_INPUT_SIZE)
- parser.add_argument("-timeout", type=int, default=30)
- parser.add_argument("-rss_limit_mb", type=int, default=2048)
- parser.add_argument("-runs", type=int, default=0)
- args = parser.parse_args(argv[1:] if argv else sys.argv[1:])
-
- # First, run seed corpus files
- seeds_run = 0
- for corpus_dir in args.corpus_dirs:
- if os.path.isdir(corpus_dir):
- for fname in sorted(os.listdir(corpus_dir)):
- fpath = os.path.join(corpus_dir, fname)
- if os.path.isfile(fpath):
- with open(fpath, "rb") as f:
- data = f.read()
- try:
- target_fn(data)
- except Exception as e:
- if not isinstance(e, SAFE_EXCEPTIONS):
- print(f"FINDING in seed {fname}: {type(e).__name__}: {e}")
- seeds_run += 1
-
- # Then random inputs
- import time
-
- rng = random.Random(42)
- start = time.time()
- runs = 0
- max_runs = args.runs if args.runs > 0 else float("inf")
- while time.time() - start < args.max_total_time and runs < max_runs:
- length = rng.randint(0, min(args.max_len, 4096))
- data = bytes(rng.randint(0, 255) for _ in range(length))
+ return
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('corpus_dirs', nargs='*', default=[])
+ parser.add_argument('-max_total_time', type=int, default=10)
+ parser.add_argument('-max_len', type=int, default=MAX_INPUT_SIZE)
+ parser.add_argument('-timeout', type=int, default=30)
+ parser.add_argument('-rss_limit_mb', type=int, default=2048)
+ parser.add_argument('-runs', type=int, default=0)
+ args = parser.parse_args(argv[1:] if argv else sys.argv[1:])
+
+ seeds_run = 0
+ for corpus_dir in args.corpus_dirs:
+ if not os.path.isdir(corpus_dir):
+ continue
+ for fname in sorted(os.listdir(corpus_dir)):
+ fpath = os.path.join(corpus_dir, fname)
+ if not os.path.isfile(fpath):
+ continue
+ with open(fpath, 'rb') as file_handle:
+ seed_data = file_handle.read()
try:
- target_fn(data)
- except Exception as e:
- if not isinstance(e, SAFE_EXCEPTIONS):
- print(f"FINDING at run {runs}: {type(e).__name__}: {e}")
- runs += 1
-
- print(f"Fuzzing complete: {seeds_run} seeds + {runs} random inputs in {time.time() - start:.1f}s")
+ target_fn(seed_data)
+ except Exception as exc:
+ if not isinstance(exc, SAFE_EXCEPTIONS):
+ print(f'FINDING in seed {fname}: {type(exc).__name__}: {exc}')
+ seeds_run += 1
+
+ rng = random.Random(42)
+ start = time.time()
+ runs = 0
+ max_runs = args.runs if args.runs > 0 else float('inf')
+ while time.time() - start < args.max_total_time and runs < max_runs:
+ length = rng.randint(0, min(args.max_len, 4096))
+ random_data = bytes(rng.randint(0, 255) for _ in range(length))
+ try:
+ target_fn(random_data)
+ except Exception as exc:
+ if not isinstance(exc, SAFE_EXCEPTIONS):
+ print(f'FINDING at run {runs}: {type(exc).__name__}: {exc}')
+ runs += 1
+
+ print(f'Fuzzing complete: {seeds_run} seeds + {runs} random inputs in {time.time() - start:.1f}s')
diff --git a/tests/fuzz/fuzz_deobfuscate.py b/tests/fuzz/fuzz_deobfuscate.py
index 5a950f3..5d2e243 100755
--- a/tests/fuzz/fuzz_deobfuscate.py
+++ b/tests/fuzz/fuzz_deobfuscate.py
@@ -9,7 +9,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import bytes_to_js
@@ -18,7 +18,7 @@
from pyjsclear import deobfuscate
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 2:
return
@@ -29,11 +29,9 @@ def TestOneInput(data):
except SAFE_EXCEPTIONS:
return
- # Core safety guarantee: result must never be None
- assert result is not None, "deobfuscate() returned None"
- # Result must be a string
- assert isinstance(result, str), f"deobfuscate() returned {type(result)}, expected str"
+ assert result is not None, 'deobfuscate() returned None'
+ assert isinstance(result, str), f'deobfuscate() returned {type(result)}, expected str'
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_expression_simplifier.py b/tests/fuzz/fuzz_expression_simplifier.py
index 001ca9f..f3b8438 100755
--- a/tests/fuzz/fuzz_expression_simplifier.py
+++ b/tests/fuzz/fuzz_expression_simplifier.py
@@ -9,7 +9,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import bytes_to_js
@@ -20,7 +20,7 @@
from pyjsclear.transforms.expression_simplifier import ExpressionSimplifier
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 2:
return
@@ -42,8 +42,8 @@ def TestOneInput(data):
except SAFE_EXCEPTIONS:
return
- assert isinstance(result, str), f"generate() returned {type(result)} after ExpressionSimplifier"
+ assert isinstance(result, str), f'generate() returned {type(result)} after ExpressionSimplifier'
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_generator.py b/tests/fuzz/fuzz_generator.py
index be2a47a..1f3d228 100755
--- a/tests/fuzz/fuzz_generator.py
+++ b/tests/fuzz/fuzz_generator.py
@@ -10,7 +10,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import FuzzedDataProvider
@@ -22,7 +22,7 @@
from pyjsclear.parser import parse
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 4:
return
@@ -30,7 +30,7 @@ def TestOneInput(data):
mode = fdp.ConsumeIntInRange(0, 1)
if mode == 0:
- # Roundtrip mode: parse valid JS then generate
+ # Roundtrip: parse then generate
code = bytes_to_js(fdp.ConsumeBytes(fdp.remaining_bytes()))
try:
ast = parse(code)
@@ -42,22 +42,21 @@ def TestOneInput(data):
except SAFE_EXCEPTIONS:
return
- assert isinstance(result, str), f"generate() returned {type(result)}, expected str"
+ assert isinstance(result, str), f'generate() returned {type(result)}, expected str'
else:
- # Synthetic AST mode: test with malformed input
+ # Synthetic AST: test with malformed input
remaining = fdp.ConsumeBytes(fdp.remaining_bytes())
ast = bytes_to_ast_dict(remaining)
try:
result = generate(ast)
except (KeyError, TypeError, AttributeError, ValueError):
- # Expected for malformed ASTs
return
except SAFE_EXCEPTIONS:
return
- assert isinstance(result, str), f"generate() returned {type(result)}, expected str"
+ assert isinstance(result, str), f'generate() returned {type(result)}, expected str'
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_parser.py b/tests/fuzz/fuzz_parser.py
index 40aa615..1ff4275 100755
--- a/tests/fuzz/fuzz_parser.py
+++ b/tests/fuzz/fuzz_parser.py
@@ -5,7 +5,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import bytes_to_js
@@ -14,7 +14,7 @@
from pyjsclear.parser import parse
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 1:
return
@@ -23,15 +23,13 @@ def TestOneInput(data):
try:
result = parse(code)
except SyntaxError:
- # Expected for invalid JS
return
except SAFE_EXCEPTIONS:
return
- # Successful parse must return a Program or Module dict
- assert isinstance(result, dict), f"parse() returned {type(result)}, expected dict"
- assert result.get("type") in ("Program", "Module"), f"Unexpected root type: {result.get('type')}"
+ assert isinstance(result, dict), f'parse() returned {type(result)}, expected dict'
+ assert result.get('type') in ('Program', 'Module'), f"Unexpected root type: {result.get('type')}"
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_scope.py b/tests/fuzz/fuzz_scope.py
index 21cb853..8abb477 100755
--- a/tests/fuzz/fuzz_scope.py
+++ b/tests/fuzz/fuzz_scope.py
@@ -5,7 +5,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import bytes_to_js
@@ -15,7 +15,7 @@
from pyjsclear.scope import build_scope_tree
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 2:
return
@@ -31,9 +31,9 @@ def TestOneInput(data):
except SAFE_EXCEPTIONS:
return
- assert root_scope is not None, "build_scope_tree returned None root_scope"
- assert isinstance(node_scope_map, dict), f"node_scope_map is {type(node_scope_map)}, expected dict"
+ assert root_scope is not None, 'build_scope_tree returned None root_scope'
+ assert isinstance(node_scope_map, dict), f'node_scope_map is {type(node_scope_map)}, expected dict'
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_string_decoders.py b/tests/fuzz/fuzz_string_decoders.py
index 25647a3..4b45754 100755
--- a/tests/fuzz/fuzz_string_decoders.py
+++ b/tests/fuzz/fuzz_string_decoders.py
@@ -9,7 +9,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import FuzzedDataProvider
@@ -21,71 +21,65 @@
from pyjsclear.utils.string_decoders import base64_transform
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 8:
return
fdp = FuzzedDataProvider(data)
decoder_choice = fdp.ConsumeIntInRange(0, 3)
- if decoder_choice == 0:
- # Test base64_transform with random strings
- encoded = fdp.ConsumeUnicode(1024)
- try:
- result = base64_transform(encoded)
- except SAFE_EXCEPTIONS:
- return
- assert isinstance(result, str), f"base64_transform returned {type(result)}"
-
- elif decoder_choice == 1:
- # Test BasicStringDecoder
- num_strings = fdp.ConsumeIntInRange(0, 20)
- string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
- offset = fdp.ConsumeIntInRange(-5, 5)
- decoder = BasicStringDecoder(string_array, offset)
- if num_strings > 0:
- idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
+ match decoder_choice:
+ case 0:
+ encoded = fdp.ConsumeUnicode(1024)
try:
- result = decoder.get_string(idx)
+ result = base64_transform(encoded)
except SAFE_EXCEPTIONS:
return
- # None is valid for out-of-range indices
- if result is not None:
- assert isinstance(result, str), f"BasicStringDecoder returned {type(result)}"
+ assert isinstance(result, str), f'base64_transform returned {type(result)}'
- elif decoder_choice == 2:
- # Test Base64StringDecoder
- num_strings = fdp.ConsumeIntInRange(0, 20)
- string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
- offset = fdp.ConsumeIntInRange(-5, 5)
- decoder = Base64StringDecoder(string_array, offset)
- if num_strings > 0:
- idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
- try:
- result = decoder.get_string(idx)
- except SAFE_EXCEPTIONS:
- return
- # None is valid for out-of-range indices
- if result is not None:
- assert isinstance(result, str), f"Base64StringDecoder returned {type(result)}"
+ case 1:
+ num_strings = fdp.ConsumeIntInRange(0, 20)
+ string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
+ offset = fdp.ConsumeIntInRange(-5, 5)
+ decoder = BasicStringDecoder(string_array, offset)
+ if num_strings > 0:
+ idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
+ try:
+ result = decoder.get_string(idx)
+ except SAFE_EXCEPTIONS:
+ return
+ if result is not None:
+ assert isinstance(result, str), f'BasicStringDecoder returned {type(result)}'
- elif decoder_choice == 3:
- # Test Rc4StringDecoder - potential ZeroDivisionError with empty key
- num_strings = fdp.ConsumeIntInRange(0, 20)
- string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
- offset = fdp.ConsumeIntInRange(-5, 5)
- decoder = Rc4StringDecoder(string_array, offset)
- if num_strings > 0:
- idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
- key = fdp.ConsumeUnicode(32) # May be empty - tests empty key guard
- try:
- result = decoder.get_string(idx, key=key)
- except SAFE_EXCEPTIONS:
- return
- # None is valid for out-of-range or None key
- if result is not None:
- assert isinstance(result, str), f"Rc4StringDecoder returned {type(result)}"
+ case 2:
+ num_strings = fdp.ConsumeIntInRange(0, 20)
+ string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
+ offset = fdp.ConsumeIntInRange(-5, 5)
+ decoder = Base64StringDecoder(string_array, offset)
+ if num_strings > 0:
+ idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
+ try:
+ result = decoder.get_string(idx)
+ except SAFE_EXCEPTIONS:
+ return
+ if result is not None:
+ assert isinstance(result, str), f'Base64StringDecoder returned {type(result)}'
+
+ case 3:
+ num_strings = fdp.ConsumeIntInRange(0, 20)
+ string_array = [fdp.ConsumeUnicode(64) for _ in range(num_strings)]
+ offset = fdp.ConsumeIntInRange(-5, 5)
+ decoder = Rc4StringDecoder(string_array, offset)
+ if num_strings > 0:
+ idx = fdp.ConsumeIntInRange(-2, num_strings * 2)
+ key = fdp.ConsumeUnicode(32) # may be empty - tests empty key guard
+ try:
+ result = decoder.get_string(idx, key=key)
+ except SAFE_EXCEPTIONS:
+ return
+ if result is not None:
+ assert isinstance(result, str), f'Rc4StringDecoder returned {type(result)}'
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_transforms.py b/tests/fuzz/fuzz_transforms.py
index f5b0eeb..ac9378e 100755
--- a/tests/fuzz/fuzz_transforms.py
+++ b/tests/fuzz/fuzz_transforms.py
@@ -9,7 +9,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import FuzzedDataProvider
@@ -54,7 +54,7 @@
]
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 4:
return
@@ -84,5 +84,5 @@ def TestOneInput(data):
return
-if __name__ == "__main__":
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/fuzz/fuzz_traverser.py b/tests/fuzz/fuzz_traverser.py
index 64ed65e..b076fd1 100755
--- a/tests/fuzz/fuzz_traverser.py
+++ b/tests/fuzz/fuzz_traverser.py
@@ -9,7 +9,7 @@
import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from conftest_fuzz import SAFE_EXCEPTIONS
from conftest_fuzz import FuzzedDataProvider
@@ -25,7 +25,7 @@
MAX_VISITED = 10_000
-def TestOneInput(data):
+def TestOneInput(data: bytes) -> None:
if len(data) < 8:
return
@@ -36,55 +36,53 @@ def TestOneInput(data):
visited = 0
- if mode == 0:
- # traverse with enter that sometimes returns SKIP
- action_byte = remaining[0] if remaining else 0
-
- def enter(node, parent, key, index):
- nonlocal visited
- visited += 1
- if visited > MAX_VISITED:
- return SKIP
- if isinstance(node, dict) and node.get("type") == "Literal" and action_byte % 3 == 0:
- return SKIP
- return None
-
- try:
- traverse(ast, {"enter": enter})
- except SAFE_EXCEPTIONS:
- return
-
- elif mode == 1:
- # traverse with enter that returns REMOVE for some nodes
- action_byte = remaining[1] if len(remaining) > 1 else 0
-
- def enter(node, parent, key, index):
- nonlocal visited
- visited += 1
- if visited > MAX_VISITED:
- return SKIP
- if isinstance(node, dict) and node.get("type") == "EmptyStatement" and action_byte % 2 == 0:
- return REMOVE
- return None
-
- try:
- traverse(ast, {"enter": enter})
- except SAFE_EXCEPTIONS:
- return
-
- elif mode == 2:
- # simple_traverse - just visit all nodes
- def callback(node, parent):
- nonlocal visited
- visited += 1
- if visited > MAX_VISITED:
- raise StopIteration("too many nodes")
-
- try:
- simple_traverse(ast, callback)
- except (StopIteration, SAFE_EXCEPTIONS[0]):
- return
-
-
-if __name__ == "__main__":
+ match mode:
+ case 0:
+ action_byte = remaining[0] if remaining else 0
+
+ def enter(node, parent, key, index):
+ nonlocal visited
+ visited += 1
+ if visited > MAX_VISITED:
+ return SKIP
+ if isinstance(node, dict) and node.get('type') == 'Literal' and action_byte % 3 == 0:
+ return SKIP
+ return None
+
+ try:
+ traverse(ast, {'enter': enter})
+ except SAFE_EXCEPTIONS:
+ return
+
+ case 1:
+ action_byte = remaining[1] if len(remaining) > 1 else 0
+
+ def enter(node, parent, key, index):
+ nonlocal visited
+ visited += 1
+ if visited > MAX_VISITED:
+ return SKIP
+ if isinstance(node, dict) and node.get('type') == 'EmptyStatement' and action_byte % 2 == 0:
+ return REMOVE
+ return None
+
+ try:
+ traverse(ast, {'enter': enter})
+ except SAFE_EXCEPTIONS:
+ return
+
+ case 2:
+ def callback(node, parent):
+ nonlocal visited
+ visited += 1
+ if visited > MAX_VISITED:
+ raise StopIteration('too many nodes')
+
+ try:
+ simple_traverse(ast, callback)
+ except (StopIteration, SAFE_EXCEPTIONS[0]):
+ return
+
+
+if __name__ == '__main__':
run_fuzzer(TestOneInput)
diff --git a/tests/test_regression.py b/tests/test_regression.py
index 304abed..6fb8703 100644
--- a/tests/test_regression.py
+++ b/tests/test_regression.py
@@ -13,7 +13,6 @@
- Cross-cutting quality invariants
"""
-import os
import re
from pathlib import Path
@@ -31,17 +30,17 @@
RE_HEX_NUMERIC = re.compile(r'\b0x[0-9a-fA-F]+\b')
-def _deobfuscate(filename):
+def _deobfuscate(filename: str) -> tuple[str, str]:
code = (SAMPLES_DIR / filename).read_text()
result = pyjsclear.deobfuscate(code)
return code, result
-def _count_0x(text):
+def _count_0x(text: str) -> int:
return len(RE_0X.findall(text))
-def _count_hex(text):
+def _count_hex(text: str) -> int:
return len(RE_HEX.findall(text))
@@ -176,7 +175,7 @@ def test_animatedflatlist_property_simplifier(self):
assert result != code, 'PropertySimplifier should transform the code'
assert len(result) < len(code), 'Output should be smaller (bracket -> dot)'
# String decode should NOT fire — the array literal should still be present
- assert "createAnimatedComponent" in result
+ assert 'createAnimatedComponent' in result
def test_animatedimage_2element_array_no_decode(self):
"""AnimatedImage: 2-element array — below Strategy 2b threshold.
@@ -187,8 +186,8 @@ def test_animatedimage_2element_array_no_decode(self):
code, result = _deobfuscate('AnimatedImage-obfuscated.js')
assert result != code, 'PropertySimplifier should still fire'
# String decode should NOT fire — the array literal should still be present
- assert "createAnimatedComponent" in result
- assert "exports" in result
+ assert 'createAnimatedComponent' in result
+ assert 'exports' in result
# ================================================================
@@ -525,7 +524,7 @@ def test_multiple_decoders_clean_output(self):
# ================================================================
-def _deobfuscate_resource(filename):
+def _deobfuscate_resource(filename: str) -> tuple[str, str]:
"""Load and deobfuscate a file from tests/resources/."""
code = (RESOURCES_DIR / filename).read_text()
result = pyjsclear.deobfuscate(code)
@@ -579,7 +578,7 @@ def test_no_extremely_long_lines(self, sample_result):
def test_few_long_lines(self, sample_result):
_, result = sample_result
- long_lines = sum(1 for l in result.splitlines() if len(l) > 500)
+ long_lines = sum(1 for line in result.splitlines() if len(line) > 500)
assert long_lines <= 5, f'{long_lines} lines > 500 chars'
@@ -698,7 +697,7 @@ class TestSampleRegressionGuards:
def test_no_proxy_inliner_blowup(self, sample_result):
"""OptionalChaining + ProxyFunctionInliner interaction guard."""
_, result = sample_result
- max_line_len = max(len(l) for l in result.splitlines())
+ max_line_len = max(len(line) for line in result.splitlines())
assert max_line_len < 2000, f'Max line length {max_line_len} suggests proxy inliner blowup'
def test_helper_functions_preserved(self, sample_result):
@@ -756,19 +755,19 @@ class TestQualityInvariants:
def test_no_empty_output(self):
"""No sample should produce empty output."""
- for f in SAMPLES_DIR.glob('*.js'):
- code = f.read_text()
+ for sample_file in SAMPLES_DIR.glob('*.js'):
+ code = sample_file.read_text()
result = pyjsclear.deobfuscate(code)
- assert len(result.strip()) > 0, f'{f.name} produced empty output'
+ assert len(result.strip()) > 0, f'{sample_file.name} produced empty output'
def test_no_hex_increase(self):
"""Deobfuscation should never introduce new hex escapes."""
- for f in SAMPLES_DIR.glob('*.js'):
- code = f.read_text()
+ for sample_file in SAMPLES_DIR.glob('*.js'):
+ code = sample_file.read_text()
result = pyjsclear.deobfuscate(code)
assert _count_hex(result) <= _count_hex(
code
- ), f'{f.name}: hex escapes increased from {_count_hex(code)} to {_count_hex(result)}'
+ ), f'{sample_file.name}: hex escapes increased from {_count_hex(code)} to {_count_hex(result)}'
def test_output_not_larger_than_input(self):
"""Deobfuscated output should never be larger than the input.
@@ -776,11 +775,11 @@ def test_output_not_larger_than_input(self):
Deobfuscation removes string arrays, dead code, and infrastructure.
If output grows, something is wrong (e.g., proxy inlining blowup).
"""
- for f in SAMPLES_DIR.glob('*.js'):
- code = f.read_text()
+ for sample_file in SAMPLES_DIR.glob('*.js'):
+ code = sample_file.read_text()
result = pyjsclear.deobfuscate(code)
assert len(result) <= len(code) * 1.1, (
- f'{f.name}: output ({len(result)}) > 110% of input ({len(code)}). '
+ f'{sample_file.name}: output ({len(result)}) > 110% of input ({len(code)}). '
f'Ratio: {len(result)/len(code):.2f}'
)
@@ -790,10 +789,8 @@ def test_output_parseable(self):
If the output doesn't parse, a transform likely corrupted the AST.
We skip files whose input doesn't parse (ES modules with import).
"""
- from pyjsclear.parser import parse
-
- for f in SAMPLES_DIR.glob('*.js'):
- code = f.read_text()
+ for sample_file in SAMPLES_DIR.glob('*.js'):
+ code = sample_file.read_text()
# Skip files that don't parse as input (ES modules)
try:
parse(code)
@@ -802,8 +799,8 @@ def test_output_parseable(self):
result = pyjsclear.deobfuscate(code)
try:
parse(result)
- except SyntaxError as e:
- pytest.fail(f'{f.name}: output does not parse: {e}')
+ except SyntaxError as error:
+ pytest.fail(f'{sample_file.name}: output does not parse: {error}')
def test_no_extremely_long_lines(self):
"""No output line should exceed 5000 chars.
@@ -812,10 +809,10 @@ def test_no_extremely_long_lines(self):
or other expansion bugs. The limit is generous (5000) to accommodate
files with legitimately long array literals.
"""
- for f in SAMPLES_DIR.glob('*.js'):
- code = f.read_text()
+ for sample_file in SAMPLES_DIR.glob('*.js'):
+ code = sample_file.read_text()
result = pyjsclear.deobfuscate(code)
for i, line in enumerate(result.splitlines(), 1):
assert len(line) <= 5000, (
- f'{f.name} line {i}: {len(line)} chars (max 5000). ' f'Preview: {line[:80]}...'
+ f'{sample_file.name} line {i}: {len(line)} chars (max 5000). ' f'Preview: {line[:80]}...'
)
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 2b16919..d188766 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,10 +1,12 @@
"""Shared test helpers for pyjsclear unit tests."""
+from typing import Any
+
from pyjsclear.generator import generate
from pyjsclear.parser import parse
-def roundtrip(js_code, transform_class):
+def roundtrip(js_code: str, transform_class: type) -> tuple[str, bool]:
"""Parse JS, apply a transform, return (generated_code, changed)."""
ast = parse(js_code)
t = transform_class(ast)
@@ -12,12 +14,12 @@ def roundtrip(js_code, transform_class):
return generate(ast), changed
-def parse_expr(js_expr):
+def parse_expr(js_expr: str) -> dict[str, Any]:
"""Parse a JS expression and return the expression AST node."""
ast = parse(js_expr + ';')
return ast['body'][0]['expression']
-def normalize(js_code):
+def normalize(js_code: str) -> str:
"""Collapse all whitespace to single spaces for comparison."""
return ' '.join(js_code.split())
diff --git a/tests/unit/deobfuscator_test.py b/tests/unit/deobfuscator_test.py
index c5e4bfe..3c4e9d5 100644
--- a/tests/unit/deobfuscator_test.py
+++ b/tests/unit/deobfuscator_test.py
@@ -10,6 +10,9 @@
from pyjsclear.deobfuscator import TRANSFORM_CLASSES
from pyjsclear.deobfuscator import Deobfuscator
from pyjsclear.deobfuscator import _count_nodes
+from pyjsclear.transforms.control_flow import ControlFlowRecoverer
+from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner
+from pyjsclear.transforms.string_revealer import StringRevealer
class TestTransformClasses:
@@ -20,14 +23,10 @@ def test_transform_classes_length(self):
assert len(TRANSFORM_CLASSES) == 37
def test_string_revealer_appears_twice(self):
- from pyjsclear.transforms.string_revealer import StringRevealer
-
occurrences = [cls for cls in TRANSFORM_CLASSES if cls is StringRevealer]
assert len(occurrences) == 2
def test_string_revealer_is_first_and_last(self):
- from pyjsclear.transforms.string_revealer import StringRevealer
-
assert TRANSFORM_CLASSES[0] is StringRevealer
assert TRANSFORM_CLASSES[-1] is StringRevealer
@@ -261,8 +260,6 @@ def test_lite_mode_strips_expensive_transforms(self, mock_transforms, mock_gener
mock_ast = MagicMock()
mock_parse.return_value = mock_ast
- from pyjsclear.transforms.control_flow import ControlFlowRecoverer
-
# One cheap transform that changes, one expensive that should be skipped
cheap_instance = MagicMock()
cheap_instance.execute.return_value = False
@@ -292,8 +289,6 @@ def test_large_node_count_strips_expensive_transforms(self, mock_transforms, moc
cheap_instance.execute.return_value = False
cheap_transform = MagicMock(return_value=cheap_instance)
- from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner
-
mock_transforms.__iter__ = lambda self: iter([cheap_transform, ProxyFunctionInliner])
# Code > _LARGE_FILE_SIZE to trigger node counting, but < _MAX_CODE_SIZE (not lite mode)
diff --git a/tests/unit/generator_test.py b/tests/unit/generator_test.py
index 79a6e75..17f35a4 100644
--- a/tests/unit/generator_test.py
+++ b/tests/unit/generator_test.py
@@ -2,6 +2,9 @@
import pytest
+from pyjsclear.generator import _expr_precedence
+from pyjsclear.generator import _gen_property
+from pyjsclear.generator import _gen_stmt
from pyjsclear.generator import generate
from pyjsclear.parser import parse
@@ -1497,8 +1500,6 @@ class TestGenStmtNoneNode:
"""Line 113: _gen_stmt with None node returns ''."""
def test_gen_stmt_none(self):
- from pyjsclear.generator import _gen_stmt
-
assert _gen_stmt(None, 0) == ''
@@ -1509,8 +1510,6 @@ def test_statement_ending_with_semicolon(self):
# EmptyStatement generates ';' and is in _NO_SEMI_TYPES,
# but we can construct a node whose generate() output ends with ';'
# that is NOT in _NO_SEMI_TYPES. Use a manual approach.
- from pyjsclear.generator import _gen_stmt
-
# Create a fake node type that generates code ending with ';'
# A VariableDeclaration ending with ';' (by appending manually)
# Actually, let's just test the path: _gen_stmt appends ';' only if code doesn't already end with it
@@ -1625,8 +1624,6 @@ class TestRestElementInProperty:
def test_gen_property_rest_element(self):
# _gen_property generates key: value for a Property node
- from pyjsclear.generator import _gen_property
-
node = {
'type': 'Property',
'key': {'type': 'Identifier', 'name': 'a'},
@@ -1822,8 +1819,6 @@ def test_assignment_expression_precedence(self):
def test_yield_expression_precedence_in_binary(self):
"""YieldExpression has precedence 2, needs parens in binary context."""
- from pyjsclear.generator import _expr_precedence
-
yield_node = {'type': 'YieldExpression', 'argument': _id('x'), 'delegate': False}
assert _expr_precedence(yield_node) == 2
@@ -1849,19 +1844,13 @@ def test_nested_precedence_conditional_in_assignment(self):
assert 'x = a ? b : c' in result
def test_arrow_function_precedence(self):
- from pyjsclear.generator import _expr_precedence
-
arrow_node = {'type': 'ArrowFunctionExpression'}
assert _expr_precedence(arrow_node) == 3
def test_unknown_type_precedence(self):
- from pyjsclear.generator import _expr_precedence
-
node = {'type': 'SomeUnknownExpression'}
assert _expr_precedence(node) == 0
def test_non_dict_precedence(self):
- from pyjsclear.generator import _expr_precedence
-
assert _expr_precedence(42) == 20
assert _expr_precedence('str') == 20
diff --git a/tests/unit/scope_test.py b/tests/unit/scope_test.py
index 67ea10b..c17aee2 100644
--- a/tests/unit/scope_test.py
+++ b/tests/unit/scope_test.py
@@ -695,3 +695,95 @@ def test_unbound_identifier_reference(self):
root_scope, _ = build_scope_tree(ast)
# 'x' is not declared, so no binding should exist
assert root_scope.get_own_binding('x') is None
+
+
+# ---------------------------------------------------------------------------
+# Deep AST test (exercises iterative fallback at depth > 500)
+# ---------------------------------------------------------------------------
+
+
+class TestDeepASTScope:
+ """Verify that build_scope_tree handles ASTs deeper than 500."""
+
+ def test_deep_nested_functions(self):
+ """Build deeply nested function scopes and verify bindings resolve."""
+ # Build a chain: Program -> func0 body -> func1 body -> ... -> funcN body -> var x = 1;
+ var_decl = {
+ 'type': 'VariableDeclaration',
+ 'kind': 'var',
+ 'declarations': [
+ {
+ 'type': 'VariableDeclarator',
+ 'id': {'type': 'Identifier', 'name': 'x'},
+ 'init': {'type': 'Literal', 'value': 1, 'raw': '1'},
+ }
+ ],
+ }
+ # Reference to x
+ ref_stmt = {
+ 'type': 'ExpressionStatement',
+ 'expression': {'type': 'Identifier', 'name': 'x'},
+ }
+ node = {'type': 'BlockStatement', 'body': [var_decl, ref_stmt]}
+ depth = 600
+ for i in range(depth):
+ node = {
+ 'type': 'FunctionDeclaration',
+ 'id': {'type': 'Identifier', 'name': f'f{i}'},
+ 'params': [],
+ 'body': node,
+ }
+ ast = {'type': 'Program', 'sourceType': 'script', 'body': [node]}
+
+ root_scope, node_scope = build_scope_tree(ast)
+ # Should have created many scopes without stack overflow
+ assert len(node_scope) > depth
+ # The innermost x binding should be resolvable
+ # Walk to the deepest function scope
+ scope = root_scope
+ for _ in range(depth):
+ assert len(scope.children) >= 1
+ scope = scope.children[0]
+ x_binding = scope.get_own_binding('x')
+ assert x_binding is not None
+ assert x_binding.kind == 'var'
+
+ def test_deep_block_statements(self):
+ """Build deeply nested block statements and verify scope creation."""
+ # Innermost has a let binding
+ inner = {
+ 'type': 'BlockStatement',
+ 'body': [
+ {
+ 'type': 'VariableDeclaration',
+ 'kind': 'let',
+ 'declarations': [
+ {
+ 'type': 'VariableDeclarator',
+ 'id': {'type': 'Identifier', 'name': 'deep'},
+ 'init': {'type': 'Literal', 'value': 42, 'raw': '42'},
+ }
+ ],
+ }
+ ],
+ }
+ node = inner
+ for _ in range(600):
+ node = {'type': 'BlockStatement', 'body': [node]}
+ ast = {'type': 'Program', 'sourceType': 'script', 'body': [node]}
+
+ root_scope, node_scope = build_scope_tree(ast)
+ # Should complete without stack overflow
+ assert root_scope is not None
+ # The 'deep' binding should exist somewhere in the scope tree
+ found = False
+
+ def _check(scope):
+ nonlocal found
+ if scope.get_own_binding('deep'):
+ found = True
+ for child in scope.children:
+ _check(child)
+
+ _check(root_scope)
+ assert found
diff --git a/tests/unit/transforms/aa_decode_test.py b/tests/unit/transforms/aa_decode_test.py
index 8404858..24c0f52 100644
--- a/tests/unit/transforms/aa_decode_test.py
+++ b/tests/unit/transforms/aa_decode_test.py
@@ -2,8 +2,8 @@
import pytest
-from pyjsclear.transforms.aa_decode import is_aa_encoded
from pyjsclear.transforms.aa_decode import aa_decode
+from pyjsclear.transforms.aa_decode import is_aa_encoded
class TestIsAAEncoded:
@@ -42,12 +42,12 @@ def test_synthetic_simple(self):
"""Synthetic AAEncode for 'Hi' (H=110 octal, i=151 octal).
This builds a minimal AAEncoded payload that the decoder can parse.
- Note: real AAEncode uses U+FF70 (\uff70 halfwidth), NOT U+30FC (fullwidth).
+ Note: real AAEncode uses U+FF70 (\\uff70 halfwidth), NOT U+30FC (fullwidth).
"""
# H = 0x48 = 110 octal, i = 0x69 = 151 octal
- # Digit 1 = (\uff9f\uff70\uff9f), Digit 0 = (c^_^o),
- # Digit 5 = ((\uff9f\uff70\uff9f) + (\uff9f\uff70\uff9f) + (\uff9f\u0398\uff9f))
- sep = '(\uff9f\u0414\uff9f)[\uff9f\u03b5\uff9f]+'
+ # Digit 1 = (\\uff9f\\uff70\\uff9f), Digit 0 = (c^_^o),
+ # Digit 5 = ((\\uff9f\\uff70\\uff9f) + (\\uff9f\\uff70\\uff9f) + (\\uff9f\\u0398\\uff9f))
+ separator = '(\uff9f\u0414\uff9f)[\uff9f\u03b5\uff9f]+'
h_digits = '(\uff9f\uff70\uff9f)+(\uff9f\uff70\uff9f)+(c^_^o)' # 1 1 0
i_digits = (
'(\uff9f\uff70\uff9f)+'
@@ -55,7 +55,7 @@ def test_synthetic_simple(self):
'(\uff9f\uff70\uff9f)'
) # 1 5 1
- data = sep + h_digits + sep + i_digits
+ data = separator + h_digits + separator + i_digits
# Add execution wrapper with the signature
code = data + "(\uff9f\u0414\uff9f)['_'](\uff9f\u0398\uff9f)"
diff --git a/tests/unit/transforms/anti_tamper_test.py b/tests/unit/transforms/anti_tamper_test.py
index ed51f40..06c19b6 100644
--- a/tests/unit/transforms/anti_tamper_test.py
+++ b/tests/unit/transforms/anti_tamper_test.py
@@ -2,6 +2,7 @@
import pytest
+from pyjsclear.parser import parse
from pyjsclear.transforms.anti_tamper import AntiTamperRemover
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -101,10 +102,10 @@ def test_multiple_statements_only_anti_tamper_removed(self):
code = 'var a = 1;(function() { console["log"] = function(){}; })();var b = 2;'
result, changed = roundtrip(code, AntiTamperRemover)
assert changed is True
- norm = normalize(result)
- assert 'var a = 1' in norm
- assert 'var b = 2' in norm
- assert 'console' not in norm
+ normalized = normalize(result)
+ assert 'var a = 1' in normalized
+ assert 'var b = 2' in normalized
+ assert 'console' not in normalized
class TestAntiTamperRemoverEdgeCases:
@@ -139,8 +140,7 @@ def test_debug_protection_setInterval(self):
def test_exception_during_generate(self):
"""Lines 87-88: Exception during generate() should be caught gracefully."""
- # This is hard to trigger directly via roundtrip since we'd need a malformed AST.
- # We test indirectly that normal IIFE processing doesn't crash.
+ # Hard to trigger directly; test that normal IIFE processing doesn't crash
code = '(function() { var x = 1; var y = 2; })();'
result, changed = roundtrip(code, AntiTamperRemover)
assert changed is False
@@ -154,32 +154,26 @@ def test_debugger_with_setInterval_removed(self):
def test_expression_statement_no_expression(self):
"""Line 70: ExpressionStatement with expression set to None."""
- from pyjsclear.parser import parse
-
ast = parse('a();')
# Manually set expression to None to trigger early return
ast['body'][0]['expression'] = None
- t = AntiTamperRemover(ast)
- changed = t.execute()
+ transform = AntiTamperRemover(ast)
+ changed = transform.execute()
assert changed is False
- def test_call_without_callee(self):
+ def test_call_without_callee_ast(self):
"""Line 78: Call node without callee."""
- from pyjsclear.parser import parse
-
ast = parse('(function() { x(); })();')
# Find the outer CallExpression and remove its callee
call = ast['body'][0]['expression']
if call.get('type') == 'CallExpression':
call['callee'] = None
- t = AntiTamperRemover(ast)
- changed = t.execute()
+ transform = AntiTamperRemover(ast)
+ changed = transform.execute()
assert changed is False
def test_exception_during_generate_malformed_callee(self):
"""Lines 87-88: Exception during generate() with malformed AST."""
- from pyjsclear.parser import parse
-
ast = parse('(function() { x(); })();')
# Find the IIFE callee (FunctionExpression) and corrupt its body
call = ast['body'][0]['expression']
@@ -188,7 +182,7 @@ def test_exception_during_generate_malformed_callee(self):
if callee and callee.get('type') == 'FunctionExpression':
# Corrupt the body to make generate() raise
callee['body'] = 'not_a_valid_body'
- t = AntiTamperRemover(ast)
- changed = t.execute()
+ transform = AntiTamperRemover(ast)
+ changed = transform.execute()
# Should not crash, just skip the node
assert changed is False
diff --git a/tests/unit/transforms/base_test.py b/tests/unit/transforms/base_test.py
index 4668bfe..5752f3b 100644
--- a/tests/unit/transforms/base_test.py
+++ b/tests/unit/transforms/base_test.py
@@ -6,50 +6,50 @@
class TestTransformInit:
def test_stores_ast(self):
ast = {'type': 'Program'}
- t = Transform(ast)
- assert t.ast is ast
+ transform = Transform(ast)
+ assert transform.ast is ast
def test_stores_scope_tree(self):
scope_tree = {'root': True}
- t = Transform('ast', scope_tree=scope_tree)
- assert t.scope_tree is scope_tree
+ transform = Transform('ast', scope_tree=scope_tree)
+ assert transform.scope_tree is scope_tree
def test_stores_node_scope(self):
node_scope = {'node': 'scope'}
- t = Transform('ast', node_scope=node_scope)
- assert t.node_scope is node_scope
+ transform = Transform('ast', node_scope=node_scope)
+ assert transform.node_scope is node_scope
def test_scope_tree_defaults_to_none(self):
- t = Transform('ast')
- assert t.scope_tree is None
+ transform = Transform('ast')
+ assert transform.scope_tree is None
def test_node_scope_defaults_to_none(self):
- t = Transform('ast')
- assert t.node_scope is None
+ transform = Transform('ast')
+ assert transform.node_scope is None
class TestTransformExecute:
def test_raises_not_implemented(self):
- t = Transform('ast')
+ transform = Transform('ast')
with pytest.raises(NotImplementedError):
- t.execute()
+ transform.execute()
class TestTransformChangedTracking:
def test_has_changed_initially_false(self):
- t = Transform('ast')
- assert t.has_changed() is False
+ transform = Transform('ast')
+ assert transform.has_changed() is False
def test_set_changed_makes_has_changed_true(self):
- t = Transform('ast')
- t.set_changed()
- assert t.has_changed() is True
+ transform = Transform('ast')
+ transform.set_changed()
+ assert transform.has_changed() is True
def test_set_changed_is_idempotent(self):
- t = Transform('ast')
- t.set_changed()
- t.set_changed()
- assert t.has_changed() is True
+ transform = Transform('ast')
+ transform.set_changed()
+ transform.set_changed()
+ assert transform.has_changed() is True
class TestTransformRebuildScope:
@@ -57,8 +57,8 @@ def test_class_default_is_false(self):
assert Transform.rebuild_scope is False
def test_instance_inherits_default(self):
- t = Transform('ast')
- assert t.rebuild_scope is False
+ transform = Transform('ast')
+ assert transform.rebuild_scope is False
def test_subclass_can_override(self):
class MyTransform(Transform):
@@ -68,5 +68,5 @@ def execute(self):
pass
assert MyTransform.rebuild_scope is True
- t = MyTransform('ast')
- assert t.rebuild_scope is True
+ transform = MyTransform('ast')
+ assert transform.rebuild_scope is True
diff --git a/tests/unit/transforms/class_string_decoder_test.py b/tests/unit/transforms/class_string_decoder_test.py
index e2bc2e4..4926283 100644
--- a/tests/unit/transforms/class_string_decoder_test.py
+++ b/tests/unit/transforms/class_string_decoder_test.py
@@ -1,6 +1,8 @@
"""Tests for the ClassStringDecoder transform."""
from pyjsclear.transforms.class_string_decoder import ClassStringDecoder
+from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
+from pyjsclear.utils.ast_helpers import get_member_names
from tests.unit.conftest import roundtrip
@@ -26,7 +28,7 @@ class TestClassPropCollection:
def test_string_prop_collected(self):
"""String assignments on class vars should be tracked internally."""
- # This is an indirect test — if props aren't collected, resolution won't work
+ # Indirect test — if props aren't collected, resolution won't work
code = '''
var Cls = class {};
Cls.p1 = "foo";
@@ -41,15 +43,11 @@ class TestDeadClassPropRemover:
"""Tests for the DeadClassPropRemover transform."""
def test_no_classes_returns_false(self):
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
result, changed = roundtrip('var x = 1;', DeadClassPropRemover)
assert changed is False
def test_dead_prop_removed(self):
"""Property written but never read should be removed."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls = class {};
Cls.deadProp = "never_used";
@@ -60,8 +58,6 @@ def test_dead_prop_removed(self):
def test_read_prop_preserved(self):
"""Property that is read should NOT be removed."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls = class {};
Cls.liveProp = "used";
@@ -73,8 +69,6 @@ def test_read_prop_preserved(self):
def test_fully_dead_class(self):
"""Class that is only used via property assignments — all props dead."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls = class {};
Cls.a = "x";
@@ -87,8 +81,6 @@ def test_fully_dead_class(self):
def test_assignment_class_detected(self):
"""Class assigned via `X = class {}` (not var) should be detected."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls;
Cls = class {};
@@ -100,8 +92,6 @@ def test_assignment_class_detected(self):
def test_sequence_expression_dead_props(self):
"""Dead props inside SequenceExpression should be stripped."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls = class {};
Cls.dead1 = "a", Cls.dead2 = "b";
@@ -113,8 +103,6 @@ def test_sequence_expression_dead_props(self):
def test_sequence_expression_partial_removal(self):
"""Only dead props removed from a sequence; live ones kept."""
- from pyjsclear.transforms.dead_class_props import DeadClassPropRemover
-
code = '''
var Cls = class {};
Cls.dead = "a", Cls.live = "b";
@@ -130,51 +118,41 @@ class TestClassStringDecoderHelpers:
"""Tests for ClassStringDecoder helper functions."""
def test_get_member_names_computed_string(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'obj'},
'property': {'type': 'Literal', 'value': 'prop'},
'computed': True,
}
- assert _get_member_names(node) == ('obj', 'prop')
+ assert get_member_names(node) == ('obj', 'prop')
def test_get_member_names_non_identifier_prop(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'obj'},
'property': {'type': 'Literal', 'value': 42},
'computed': True,
}
- assert _get_member_names(node) == (None, None)
+ assert get_member_names(node) == (None, None)
def test_get_member_names_dot_notation(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'obj'},
'property': {'type': 'Identifier', 'name': 'prop'},
'computed': False,
}
- assert _get_member_names(node) == ('obj', 'prop')
+ assert get_member_names(node) == ('obj', 'prop')
def test_get_member_names_no_prop(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
- assert _get_member_names(None) == (None, None)
- assert _get_member_names({'type': 'Literal'}) == (None, None)
+ assert get_member_names(None) == (None, None)
+ assert get_member_names({'type': 'Literal'}) == (None, None)
def test_get_member_names_non_identifier_object(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Literal', 'value': 1},
'property': {'type': 'Identifier', 'name': 'prop'},
'computed': False,
}
- assert _get_member_names(node) == (None, None)
+ assert get_member_names(node) == (None, None)
diff --git a/tests/unit/transforms/constant_prop_test.py b/tests/unit/transforms/constant_prop_test.py
index 6be27d8..a132f68 100644
--- a/tests/unit/transforms/constant_prop_test.py
+++ b/tests/unit/transforms/constant_prop_test.py
@@ -3,6 +3,7 @@
import pytest
from pyjsclear.transforms.constant_prop import ConstantProp
+from pyjsclear.transforms.constant_prop import _should_skip_reference
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -60,8 +61,7 @@ def test_skip_update_target(self):
"""Line 20: Reference used as update target should be skipped."""
code = 'var x = 5; x++;'
result, changed = roundtrip(code, ConstantProp)
- # x++ should not become 5++
- # x has writes (x++) so it's not constant, so no propagation
+ # x++ should not become 5++; x has writes so it's not constant
assert changed is False
def test_skip_declarator_id(self):
@@ -112,42 +112,30 @@ class TestConstantPropSkipReferenceEdgeCases:
def test_ref_parent_is_none(self):
"""Line 15: ref_parent is None → return True (skip)."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
assert _should_skip_reference(None, 'left') is True
def test_update_expression_parent(self):
"""Line 20: UpdateExpression parent → return True."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
parent = {'type': 'UpdateExpression', 'operator': '++', 'argument': {}}
assert _should_skip_reference(parent, 'argument') is True
def test_variable_declarator_id(self):
"""Line 22: VariableDeclarator id parent → return True."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
parent = {'type': 'VariableDeclarator', 'id': {}, 'init': None}
assert _should_skip_reference(parent, 'id') is True
def test_variable_declarator_init_not_skipped(self):
"""VariableDeclarator with key='init' should NOT be skipped."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
parent = {'type': 'VariableDeclarator', 'id': {}, 'init': {}}
assert _should_skip_reference(parent, 'init') is False
def test_assignment_expression_right_not_skipped(self):
"""AssignmentExpression with key='right' should NOT be skipped."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
parent = {'type': 'AssignmentExpression', 'operator': '=', 'left': {}, 'right': {}}
assert _should_skip_reference(parent, 'right') is False
def test_normal_parent_not_skipped(self):
"""Normal parent (e.g., CallExpression) should NOT be skipped."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
parent = {'type': 'CallExpression', 'callee': {}, 'arguments': []}
assert _should_skip_reference(parent, 'arguments') is False
@@ -163,10 +151,7 @@ def test_binding_with_assignments_not_removed(self):
assert changed is False
def test_non_dict_decl_node_skip(self):
- """Line 82: decl_node not a dict — should be skipped."""
- from pyjsclear.transforms.constant_prop import _should_skip_reference
-
- # This is a defensive check; test it doesn't crash in normal flow
+ """Line 82: decl_node not a dict — defensive check doesn't crash."""
code = 'const a = 1; y(a);'
result, changed = roundtrip(code, ConstantProp)
assert changed is True
diff --git a/tests/unit/transforms/control_flow_test.py b/tests/unit/transforms/control_flow_test.py
index e3f3fac..247c1c2 100644
--- a/tests/unit/transforms/control_flow_test.py
+++ b/tests/unit/transforms/control_flow_test.py
@@ -7,8 +7,8 @@
from tests.unit.conftest import roundtrip
-def rt(js_code):
- """Shorthand roundtrip for ControlFlowRecoverer."""
+def roundtrip_cff(js_code: str) -> tuple[str, bool]:
+ """Run a roundtrip through ControlFlowRecoverer and normalize output."""
code, changed = roundtrip(js_code, ControlFlowRecoverer)
return normalize(code), changed
@@ -161,50 +161,50 @@ class TestIsSplitCall:
"""Test the _is_split_call detection method."""
def setup_method(self):
- self.t = ControlFlowRecoverer(_program([]))
+ self.transform = ControlFlowRecoverer(_program([]))
def test_valid_split_call(self):
node = _split_call('1|0|3|2')
- assert self.t._is_split_call(node) is True
+ assert self.transform._is_split_call(node) is True
def test_non_dict_returns_false(self):
- assert self.t._is_split_call(None) is False
- assert self.t._is_split_call('string') is False
+ assert self.transform._is_split_call(None) is False
+ assert self.transform._is_split_call('string') is False
def test_non_call_expression(self):
- assert self.t._is_split_call({'type': 'Identifier', 'name': 'x'}) is False
+ assert self.transform._is_split_call({'type': 'Identifier', 'name': 'x'}) is False
def test_callee_not_member_expression(self):
node = _call_expr(_identifier('split'), [_literal('|')])
- assert self.t._is_split_call(node) is False
+ assert self.transform._is_split_call(node) is False
def test_object_not_string_literal(self):
node = _call_expr(
callee=_member_expr(_identifier('arr'), _identifier('split')),
arguments=[_literal('|')],
)
- assert self.t._is_split_call(node) is False
+ assert self.transform._is_split_call(node) is False
def test_property_not_split(self):
node = _call_expr(
callee=_member_expr(_literal('1|2'), _identifier('join')),
arguments=[_literal('|')],
)
- assert self.t._is_split_call(node) is False
+ assert self.transform._is_split_call(node) is False
def test_no_arguments(self):
node = _call_expr(
callee=_member_expr(_literal('1|2'), _identifier('split')),
arguments=[],
)
- assert self.t._is_split_call(node) is False
+ assert self.transform._is_split_call(node) is False
def test_argument_not_string(self):
node = _call_expr(
callee=_member_expr(_literal('1|2'), _identifier('split')),
arguments=[_literal(1)],
)
- assert self.t._is_split_call(node) is False
+ assert self.transform._is_split_call(node) is False
# ---------------------------------------------------------------------------
@@ -216,23 +216,23 @@ class TestExtractSplitStates:
"""Test the _extract_split_states method."""
def setup_method(self):
- self.t = ControlFlowRecoverer(_program([]))
+ self.transform = ControlFlowRecoverer(_program([]))
def test_basic_extraction(self):
node = _split_call('1|0|3|2')
- assert self.t._extract_split_states(node) == ['1', '0', '3', '2']
+ assert self.transform._extract_split_states(node) == ['1', '0', '3', '2']
def test_single_state(self):
node = _split_call('0')
- assert self.t._extract_split_states(node) == ['0']
+ assert self.transform._extract_split_states(node) == ['0']
def test_five_states(self):
node = _split_call('4|2|0|1|3')
- assert self.t._extract_split_states(node) == ['4', '2', '0', '1', '3']
+ assert self.transform._extract_split_states(node) == ['4', '2', '0', '1', '3']
def test_custom_separator(self):
node = _split_call('a-b-c', separator='-')
- assert self.t._extract_split_states(node) == ['a', 'b', 'c']
+ assert self.transform._extract_split_states(node) == ['a', 'b', 'c']
# ---------------------------------------------------------------------------
@@ -250,8 +250,8 @@ def test_two_states_reordered(self):
'1': [_expr_stmt(_call_expr(_identifier('a'), []))],
}
ast = _make_cff_ast_var_pattern('1|0', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -270,8 +270,8 @@ def test_three_states(self):
'2': [_expr_stmt(_call_expr(_identifier('c'), []))],
}
ast = _make_cff_ast_var_pattern('2|0|1', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -288,8 +288,8 @@ def test_sequential_order(self):
'2': [_expr_stmt(_call_expr(_identifier('c'), []))],
}
ast = _make_cff_ast_var_pattern('0|1|2', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -304,8 +304,8 @@ def test_continue_statements_filtered(self):
'0': [_expr_stmt(_call_expr(_identifier('a'), []))],
}
ast = _make_cff_ast_var_pattern('0', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- t.execute()
+ transform = ControlFlowRecoverer(ast)
+ transform.execute()
body = ast['body']
for stmt in body:
@@ -318,8 +318,8 @@ def test_return_statement_preserved(self):
'1': [_return_stmt(_literal(42))],
}
ast = _make_cff_ast_var_pattern('0|1', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -344,8 +344,8 @@ def test_expression_pattern_two_states(self):
'1': [_expr_stmt(_call_expr(_identifier('a'), []))],
}
ast = _make_cff_ast_expr_pattern('1|0', '_a', '_i', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -360,8 +360,8 @@ def test_expression_pattern_three_states(self):
'2': [_expr_stmt(_call_expr(_identifier('z'), []))],
}
ast = _make_cff_ast_expr_pattern('2|1|0', '_arr', '_idx', cases)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is True
body = ast['body']
@@ -386,8 +386,8 @@ def test_plain_statements_unchanged(self):
_expr_stmt(_call_expr(_identifier('bar'), [])),
]
)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
assert len(ast['body']) == 2
@@ -395,16 +395,16 @@ def test_plain_statements_unchanged(self):
def test_while_without_switch_unchanged(self):
loop = _while_true([_expr_stmt(_call_expr(_identifier('doStuff'), []))])
ast = _program([loop])
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
def test_var_decl_without_split_unchanged(self):
decl = _var_declaration([_var_declarator('x', _literal(10))])
ast = _program([decl])
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
@@ -424,7 +424,7 @@ def test_basic_cff_roundtrip(self):
' switch (_a[_i++]) { case "0": b(); continue; case "1": a(); continue; }'
' break; }'
)
- code, changed = rt(js)
+ code, changed = roundtrip_cff(js)
assert changed is True
assert 'a();' in code
assert 'b();' in code
@@ -438,7 +438,7 @@ def test_three_state_roundtrip(self):
' switch (_s[_c++]) { case "0": first(); continue; case "1": second(); continue; case "2": third(); continue; }'
' break; }'
)
- code, changed = rt(js)
+ code, changed = roundtrip_cff(js)
assert changed is True
# Order: "2|0|1" => third, first, second
assert code.index('third()') < code.index('first()')
@@ -495,35 +495,35 @@ class TestIsTruthy:
"""Test the _is_truthy helper method."""
def setup_method(self):
- self.t = ControlFlowRecoverer(_program([]))
+ self.transform = ControlFlowRecoverer(_program([]))
def test_literal_true(self):
- assert self.t._is_truthy(_literal(True)) is True
+ assert self.transform._is_truthy(_literal(True)) is True
def test_literal_1(self):
- assert self.t._is_truthy(_literal(1)) is True
+ assert self.transform._is_truthy(_literal(1)) is True
def test_literal_false(self):
- assert self.t._is_truthy(_literal(False)) is False
+ assert self.transform._is_truthy(_literal(False)) is False
def test_literal_0(self):
- assert self.t._is_truthy(_literal(0)) is False
+ assert self.transform._is_truthy(_literal(0)) is False
def test_not_zero_is_truthy(self):
"""!0 should be recognized as truthy."""
node = {'type': 'UnaryExpression', 'operator': '!', 'argument': _literal(0), 'prefix': True}
- assert self.t._is_truthy(node) is True
+ assert self.transform._is_truthy(node) is True
def test_not_dict(self):
- assert self.t._is_truthy(None) is False
- assert self.t._is_truthy('string') is False
+ assert self.transform._is_truthy(None) is False
+ assert self.transform._is_truthy('string') is False
def test_double_not_array_is_truthy(self):
"""!![] should be truthy."""
inner = {'type': 'ArrayExpression', 'elements': []}
not_inner = {'type': 'UnaryExpression', 'operator': '!', 'argument': inner, 'prefix': True}
double_not = {'type': 'UnaryExpression', 'operator': '!', 'argument': not_inner, 'prefix': True}
- assert self.t._is_truthy(double_not) is True
+ assert self.transform._is_truthy(double_not) is True
# ---------------------------------------------------------------------------
@@ -537,8 +537,8 @@ class TestNonDictInBody:
def test_non_dict_in_body_skipped(self):
"""Non-dict items in body should be skipped without crashing."""
ast = _program([None, 'not_a_dict', _expr_stmt(_call_expr(_identifier('a'), []))])
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
@@ -553,8 +553,8 @@ def test_expression_pattern_no_split(self):
_while_true([_break_stmt()]),
]
)
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
def test_expression_pattern_missing_loop(self):
@@ -563,8 +563,8 @@ def test_expression_pattern_missing_loop(self):
counter_stmt = _expr_stmt(_assignment('_i', _literal(0)))
# No loop after counter, just ends
ast = _program([assign_stmt, counter_stmt])
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
@@ -572,28 +572,28 @@ class TestFindCounterInit:
"""Lines 175-183, 194: _find_counter_init with VariableDeclaration and ExpressionStatement."""
def setup_method(self):
- self.t = ControlFlowRecoverer(_program([]))
+ self.transform = ControlFlowRecoverer(_program([]))
def test_variable_declaration_counter(self):
stmt = _var_declaration([_var_declarator('_i', _literal(0))])
- result = self.t._find_counter_init(stmt)
+ result = self.transform._find_counter_init(stmt)
assert result == '_i'
def test_expression_statement_counter(self):
stmt = _expr_stmt(_assignment('_i', _literal(0)))
- result = self.t._find_counter_init(stmt)
+ result = self.transform._find_counter_init(stmt)
assert result == '_i'
def test_non_numeric_init_ignored(self):
stmt = _var_declaration([_var_declarator('_i', _literal('hello'))])
- result = self.t._find_counter_init(stmt)
+ result = self.transform._find_counter_init(stmt)
assert result is None
def test_non_dict_returns_none(self):
- result = self.t._find_counter_init(None)
+ result = self.transform._find_counter_init(None)
assert result is None
- result = self.t._find_counter_init('not a dict')
+ result = self.transform._find_counter_init('not a dict')
assert result is None
@@ -608,7 +608,7 @@ def test_for_statement_recovery(self):
' switch (_a[_j++]) { case "0": b(); continue; case "1": a(); continue; }'
' break; }'
)
- code, changed = rt(js)
+ code, changed = roundtrip_cff(js)
assert changed is True
assert 'a()' in code
assert 'b()' in code
@@ -655,25 +655,25 @@ class TestExtractSwitchFromLoopBody:
"""Lines 302, 308-310: _extract_switch_from_loop_body edge cases."""
def setup_method(self):
- self.t = ControlFlowRecoverer(_program([]))
+ self.transform = ControlFlowRecoverer(_program([]))
def test_non_block_statement(self):
"""Non-BlockStatement body returns None."""
- result = self.t._extract_switch_from_loop_body(_expr_stmt(_call_expr(_identifier('a'), [])))
+ result = self.transform._extract_switch_from_loop_body(_expr_stmt(_call_expr(_identifier('a'), [])))
assert result is None
def test_switch_directly_as_body(self):
"""SwitchStatement directly as loop body."""
switch = _switch_stmt(_identifier('x'), [])
- result = self.t._extract_switch_from_loop_body(switch)
+ result = self.transform._extract_switch_from_loop_body(switch)
assert result is not None
assert result['type'] == 'SwitchStatement'
def test_non_dict_body(self):
- result = self.t._extract_switch_from_loop_body(None)
+ result = self.transform._extract_switch_from_loop_body(None)
assert result is None
- result = self.t._extract_switch_from_loop_body('not a dict')
+ result = self.transform._extract_switch_from_loop_body('not a dict')
assert result is None
@@ -688,7 +688,7 @@ def test_while_not_zero(self):
' switch(_a[_i++]) { case "0": b(); continue; case "1": a(); continue; }'
' break; }'
)
- result, changed = rt(code)
+ result, changed = roundtrip_cff(code)
assert changed
assert 'a()' in result
assert 'b()' in result
@@ -701,7 +701,7 @@ def test_while_double_not_array(self):
' switch(_a[_i++]) { case "0": a(); continue; case "1": b(); continue; }'
' break; }'
)
- result, changed = rt(code)
+ result, changed = roundtrip_cff(code)
assert changed
assert 'a()' in result
assert 'b()' in result
@@ -714,27 +714,27 @@ def test_case_with_return_in_roundtrip(self):
' while(true) { switch(_a[_i++]) { case "0": a(); continue; case "1": return b(); } break; }'
' }'
)
- result, changed = rt(code)
+ result, changed = roundtrip_cff(code)
assert changed
assert 'return' in result
def test_is_truthy_not_array_is_false(self):
"""![] is falsy (line 324)."""
- t = ControlFlowRecoverer(_program([]))
+ transform = ControlFlowRecoverer(_program([]))
node = {
'type': 'UnaryExpression',
'operator': '!',
'argument': {'type': 'ArrayExpression', 'elements': []},
'prefix': True,
}
- assert t._is_truthy(node) is False
+ assert transform._is_truthy(node) is False
def test_is_truthy_literal_non_bool(self):
"""Literal with non-bool truthy value (line 317)."""
- t = ControlFlowRecoverer(_program([]))
- assert t._is_truthy(_literal(42)) is True
- assert t._is_truthy(_literal('')) is False
- assert t._is_truthy(_literal('hello')) is True
+ transform = ControlFlowRecoverer(_program([]))
+ assert transform._is_truthy(_literal(42)) is True
+ assert transform._is_truthy(_literal('')) is False
+ assert transform._is_truthy(_literal('hello')) is True
def test_visited_set_dedup(self):
"""Lines 36-37: visited set prevents re-processing the same node."""
@@ -742,8 +742,8 @@ def test_visited_set_dedup(self):
shared = _expr_stmt(_call_expr(_identifier('a'), []))
block = {'type': 'BlockStatement', 'body': [shared]}
ast = _program([block])
- t = ControlFlowRecoverer(ast)
- changed = t.execute()
+ transform = ControlFlowRecoverer(ast)
+ changed = transform.execute()
assert changed is False
diff --git a/tests/unit/transforms/expression_simplifier_test.py b/tests/unit/transforms/expression_simplifier_test.py
index 5469b49..f2e6fed 100644
--- a/tests/unit/transforms/expression_simplifier_test.py
+++ b/tests/unit/transforms/expression_simplifier_test.py
@@ -1,13 +1,16 @@
"""Unit tests for ExpressionSimplifier transform."""
+import math
+
import pytest
from pyjsclear.transforms.expression_simplifier import ExpressionSimplifier
+from pyjsclear.transforms.expression_simplifier import _JS_NULL
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
-def rt(js_code):
+def rt(js_code: str) -> tuple:
"""Shorthand roundtrip for ExpressionSimplifier."""
code, changed = roundtrip(js_code, ExpressionSimplifier)
return normalize(code), changed
@@ -265,7 +268,7 @@ def test_non_dict_returns_false(self):
val, ok = es._get_resolvable_value(None)
assert ok is False
- val, ok = es._get_resolvable_value("string")
+ val, ok = es._get_resolvable_value('string')
assert ok is False
@@ -462,14 +465,10 @@ def test_string_numeric(self):
assert es._js_to_number('5') == 5
def test_null_to_zero(self):
- from pyjsclear.transforms.expression_simplifier import _JS_NULL
-
es = ExpressionSimplifier({'type': 'Program', 'body': []})
assert es._js_to_number(_JS_NULL) == 0
def test_undefined_to_nan(self):
- import math
-
es = ExpressionSimplifier({'type': 'Program', 'body': []})
result = es._js_to_number(None)
assert math.isnan(result)
@@ -483,8 +482,6 @@ class TestJsToString:
"""Lines 343-358: _js_to_string for various types."""
def test_null_to_string(self):
- from pyjsclear.transforms.expression_simplifier import _JS_NULL
-
es = ExpressionSimplifier({'type': 'Program', 'body': []})
assert es._js_to_string(_JS_NULL) == 'null'
@@ -516,8 +513,6 @@ def test_string_comparison(self):
assert es._js_compare('a', 'a') == 0
def test_nan_comparison(self):
- import math
-
es = ExpressionSimplifier({'type': 'Program', 'body': []})
result = es._js_compare(float('nan'), 1)
assert math.isnan(result)
@@ -530,8 +525,6 @@ class TestValueToNode:
"""Lines 384-403: _value_to_node for various types."""
def test_null_value(self):
- from pyjsclear.transforms.expression_simplifier import _JS_NULL
-
es = ExpressionSimplifier({'type': 'Program', 'body': []})
result = es._value_to_node(_JS_NULL)
assert result['type'] == 'Literal'
@@ -622,7 +615,6 @@ def test_unary_minus_undefined_returns_nan_no_node(self):
"""-undefined → NaN, which _value_to_node returns None for."""
code, changed = rt('-undefined;')
# -undefined produces NaN, which can't be represented as a literal
- # so it stays unchanged
assert changed is False
diff --git a/tests/unit/transforms/hex_numerics_test.py b/tests/unit/transforms/hex_numerics_test.py
index cb797ba..1342487 100644
--- a/tests/unit/transforms/hex_numerics_test.py
+++ b/tests/unit/transforms/hex_numerics_test.py
@@ -1,5 +1,7 @@
"""Tests for the HexNumerics transform."""
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.hex_numerics import HexNumerics
from tests.unit.conftest import roundtrip
@@ -47,9 +49,6 @@ def test_float_unchanged(self):
def test_no_raw_field_unchanged(self):
"""Numeric literal without a raw field should not be transformed."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
# Manually remove raw to simulate a synthetic node
for stmt in ast['body']:
@@ -62,9 +61,6 @@ def test_no_raw_field_unchanged(self):
def test_negative_hex_value(self):
"""Hex literal with negative value uses plain str() path (line 26)."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('var x = 0x1;')
# Manually set the value to a negative to test the else branch
decl = ast['body'][0]['declarations'][0]
diff --git a/tests/unit/transforms/logical_to_if_test.py b/tests/unit/transforms/logical_to_if_test.py
index 98791f5..bab9130 100644
--- a/tests/unit/transforms/logical_to_if_test.py
+++ b/tests/unit/transforms/logical_to_if_test.py
@@ -1,5 +1,7 @@
-import pytest
+"""Tests for the LogicalToIf transform."""
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.logical_to_if import LogicalToIf
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -150,9 +152,6 @@ def test_return_single_element_sequence(self):
"""Line 104: Return with single-element sequence (len <= 1) returns None."""
# Manually constructing is tricky; a single-element SequenceExpression
# is unusual. We test via the AST directly.
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('function f() { return a; }')
# Manually make the return argument a SequenceExpression with 1 element
ret_stmt = ast['body'][0]['body']['body'][0]
@@ -181,9 +180,6 @@ def test_return_logical_right_not_sequence(self):
def test_return_logical_right_sequence_single_element(self):
"""Line 123: Return logical where right side is sequence with <=1 elements."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('function f() { return a || b; }')
ret_stmt = ast['body'][0]['body']['body'][0]
ret_stmt['argument'] = {
@@ -201,9 +197,6 @@ def test_return_logical_right_sequence_single_element(self):
def test_nullish_coalescing_not_converted(self):
"""Lines 147-148: _logical_to_if with unknown operator (e.g. '??') returns None."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('a ?? b();')
# Esprima may not parse ?? as LogicalExpression, so force it
expr_stmt = ast['body'][0]
@@ -219,8 +212,6 @@ def test_nullish_coalescing_not_converted(self):
def test_expression_stmt_non_dict_expression(self):
"""Line 75: ExpressionStatement with non-dict expression returns None."""
- from pyjsclear.parser import parse
-
ast = parse('a();')
ast['body'][0]['expression'] = 42
t = LogicalToIf(ast)
diff --git a/tests/unit/transforms/member_chain_resolver_test.py b/tests/unit/transforms/member_chain_resolver_test.py
index afad3b3..1d43c72 100644
--- a/tests/unit/transforms/member_chain_resolver_test.py
+++ b/tests/unit/transforms/member_chain_resolver_test.py
@@ -1,6 +1,7 @@
"""Tests for the MemberChainResolver transform."""
from pyjsclear.transforms.member_chain_resolver import MemberChainResolver
+from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -90,13 +91,9 @@ class TestHelperFunctions:
"""Direct tests for _get_member_names helper."""
def test_get_member_names_none(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
assert _get_member_names(None) == (None, None)
def test_get_member_names_no_prop(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'x'},
@@ -106,8 +103,6 @@ def test_get_member_names_no_prop(self):
assert _get_member_names(node) == (None, None)
def test_get_member_names_computed_non_string(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'x'},
@@ -117,8 +112,6 @@ def test_get_member_names_computed_non_string(self):
assert _get_member_names(node) == (None, None)
def test_get_member_names_non_identifier_obj(self):
- from pyjsclear.utils.ast_helpers import get_member_names as _get_member_names
-
node = {
'type': 'MemberExpression',
'object': {'type': 'Literal', 'value': 1},
diff --git a/tests/unit/transforms/object_packer_test.py b/tests/unit/transforms/object_packer_test.py
index d9afd53..118dc1b 100644
--- a/tests/unit/transforms/object_packer_test.py
+++ b/tests/unit/transforms/object_packer_test.py
@@ -1,5 +1,8 @@
+"""Tests for the ObjectPacker transform."""
+
import pytest
+from pyjsclear.parser import parse
from pyjsclear.transforms.object_packer import ObjectPacker
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -128,7 +131,6 @@ def test_object_name_mismatch(self):
def test_property_node_is_none(self):
"""Line 87: Property node is None stops packing."""
from pyjsclear.generator import generate
- from pyjsclear.parser import parse
ast = parse('var o = {}; o.x = 1;')
# Manually set the property of the MemberExpression to None
@@ -165,8 +167,6 @@ def test_references_name_identifier_match(self):
def test_non_dict_in_body_direct_ast(self):
"""Line 22/57: non-dict in body triggers skip in _process_bodies and _try_pack_body."""
- from pyjsclear.parser import parse
-
ast = parse('var o = {}; o.x = 1;')
ast['body'].append(42)
t = ObjectPacker(ast)
diff --git a/tests/unit/transforms/object_simplifier_test.py b/tests/unit/transforms/object_simplifier_test.py
index feec01a..b2ea715 100644
--- a/tests/unit/transforms/object_simplifier_test.py
+++ b/tests/unit/transforms/object_simplifier_test.py
@@ -1,5 +1,9 @@
+"""Tests for the ObjectSimplifier transform."""
+
import pytest
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.object_simplifier import ObjectSimplifier
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -182,9 +186,6 @@ def test_recurse_into_child_scopes(self):
def test_is_proxy_object_non_property_type(self):
"""Lines 121, 124: _is_proxy_object with non-Property type or missing value."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('const o = {x: 1}; y(o.x);')
t = ObjectSimplifier(ast)
# SpreadElement is not a Property
@@ -194,8 +195,6 @@ def test_is_proxy_object_non_property_type(self):
def test_get_property_key_no_key(self):
"""Lines 136: _get_property_key returns None when key is missing."""
- from pyjsclear.parser import parse
-
ast = parse('const o = {x: 1};')
t = ObjectSimplifier(ast)
assert t._get_property_key({}) is None
@@ -203,8 +202,6 @@ def test_get_property_key_no_key(self):
def test_computed_property_key_ignored(self):
"""Line 46: property key returns None for computed key (not Identifier or string Literal)."""
- from pyjsclear.parser import parse
-
ast = parse('const o = {x: 1}; var y = o.x;')
t = ObjectSimplifier(ast)
# A property with a computed (non-string, non-identifier) key returns None
@@ -212,8 +209,6 @@ def test_computed_property_key_ignored(self):
def test_has_property_assignment_me_parent_info_none(self):
"""Line 97: _has_property_assignment where find_parent returns None for me."""
- from pyjsclear.parser import parse
-
# Build a scenario with a detached member expression
ast = parse('const o = {x: 1}; var y = o.x;')
t = ObjectSimplifier(ast)
@@ -223,8 +218,6 @@ def test_has_property_assignment_me_parent_info_none(self):
def test_try_inline_function_call_me_parent_info_none(self):
"""Line 107: _try_inline_function_call where find_parent returns None."""
- from pyjsclear.parser import parse
-
ast = parse('const o = {fn: function(a) { return a; }}; o.fn(1);')
t = ObjectSimplifier(ast)
changed = t.execute()
@@ -232,8 +225,6 @@ def test_try_inline_function_call_me_parent_info_none(self):
def test_get_member_prop_name_no_property(self):
"""Line 148: _get_member_prop_name with no property returns None."""
- from pyjsclear.parser import parse
-
ast = parse('const o = {x: 1};')
t = ObjectSimplifier(ast)
assert t._get_member_prop_name({}) is None
@@ -241,12 +232,9 @@ def test_get_member_prop_name_no_property(self):
def test_body_not_block_not_expression(self):
"""Line 183: body that's not BlockStatement and not expression for non-arrow."""
- from pyjsclear.parser import parse
-
ast = parse('const o = {fn: function(a) { return a; }}; o.fn(1);')
t = ObjectSimplifier(ast)
# Manually set the function body to something that is not a BlockStatement
- # Find the function in the object
props = ast['body'][0]['declarations'][0]['init']['properties']
func = props[0]['value']
func['body'] = {'type': 'Literal', 'value': 1} # Not a BlockStatement
diff --git a/tests/unit/transforms/optional_chaining_test.py b/tests/unit/transforms/optional_chaining_test.py
index b0b931c..387f739 100644
--- a/tests/unit/transforms/optional_chaining_test.py
+++ b/tests/unit/transforms/optional_chaining_test.py
@@ -1,5 +1,7 @@
"""Tests for the OptionalChaining transform."""
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.optional_chaining import OptionalChaining
from pyjsclear.transforms.optional_chaining import _nodes_match
from pyjsclear.utils.ast_helpers import is_null_literal
@@ -10,7 +12,7 @@
class TestSimplePattern:
"""Tests for X === null || X === undefined ? undefined : X.prop → X?.prop."""
- def test_basic_member_access(self):
+ def test_basic_member_access(self) -> None:
code, changed = roundtrip(
'var y = x === null || x === undefined ? undefined : x.foo;',
OptionalChaining,
@@ -18,7 +20,7 @@ def test_basic_member_access(self):
assert changed is True
assert 'x?.foo' in code
- def test_computed_member_access(self):
+ def test_computed_member_access(self) -> None:
code, changed = roundtrip(
'var y = x === null || x === undefined ? undefined : x["foo"];',
OptionalChaining,
@@ -26,7 +28,7 @@ def test_computed_member_access(self):
assert changed is True
assert '?.["foo"]' in code
- def test_reversed_null_undefined(self):
+ def test_reversed_null_undefined(self) -> None:
code, changed = roundtrip(
'var y = x === undefined || x === null ? undefined : x.foo;',
OptionalChaining,
@@ -34,7 +36,7 @@ def test_reversed_null_undefined(self):
assert changed is True
assert 'x?.foo' in code
- def test_void_0_as_undefined(self):
+ def test_void_0_as_undefined(self) -> None:
code, changed = roundtrip(
'var y = x === null || x === void 0 ? void 0 : x.foo;',
OptionalChaining,
@@ -46,7 +48,7 @@ def test_void_0_as_undefined(self):
class TestTempAssignmentPattern:
"""Tests for (_tmp = expr) === null || _tmp === undefined ? undefined : _tmp.prop."""
- def test_temp_var_member(self):
+ def test_temp_var_member(self) -> None:
code, changed = roundtrip(
'var y = (_tmp = obj.prop) === null || _tmp === undefined ? undefined : _tmp.nested;',
OptionalChaining,
@@ -54,7 +56,7 @@ def test_temp_var_member(self):
assert changed is True
assert 'obj.prop?.nested' in code
- def test_temp_var_eliminates_temp(self):
+ def test_temp_var_eliminates_temp(self) -> None:
"""The temp variable should not appear in the output."""
code, changed = roundtrip(
'var y = (_tmp = obj.a) === null || _tmp === undefined ? undefined : _tmp.b;',
@@ -67,21 +69,21 @@ def test_temp_var_eliminates_temp(self):
class TestNoTransform:
"""Cases that should NOT trigger the transform."""
- def test_consequent_not_undefined(self):
+ def test_consequent_not_undefined(self) -> None:
code, changed = roundtrip(
'var y = x === null || x === undefined ? 0 : x.foo;',
OptionalChaining,
)
assert changed is False
- def test_different_variables_in_checks(self):
+ def test_different_variables_in_checks(self) -> None:
code, changed = roundtrip(
'var y = x === null || z === undefined ? undefined : x.foo;',
OptionalChaining,
)
assert changed is False
- def test_alternate_is_plain_identifier(self):
+ def test_alternate_is_plain_identifier(self) -> None:
"""x?. would require the alternate to be a member/call, not just x."""
code, changed = roundtrip(
'var y = x === null || x === undefined ? undefined : x;',
@@ -89,7 +91,7 @@ def test_alternate_is_plain_identifier(self):
)
assert changed is False
- def test_and_operator_not_or(self):
+ def test_and_operator_not_or(self) -> None:
"""&& instead of || should not match (that's the nullish coalescing pattern)."""
code, changed = roundtrip(
'var y = x === null && x === undefined ? undefined : x.foo;',
@@ -97,7 +99,7 @@ def test_and_operator_not_or(self):
)
assert changed is False
- def test_not_equality_check(self):
+ def test_not_equality_check(self) -> None:
"""!== instead of === should not match."""
code, changed = roundtrip(
'var y = x !== null || x !== undefined ? undefined : x.foo;',
@@ -109,7 +111,7 @@ def test_not_equality_check(self):
class TestOptionalCall:
"""Tests for X === null || X === undefined ? undefined : X() → X?.()."""
- def test_optional_call(self):
+ def test_optional_call(self) -> None:
code, changed = roundtrip(
'var y = fn === null || fn === undefined ? undefined : fn();',
OptionalChaining,
@@ -117,7 +119,7 @@ def test_optional_call(self):
assert changed is True
assert 'fn?.()' in code
- def test_optional_call_with_args(self):
+ def test_optional_call_with_args(self) -> None:
code, changed = roundtrip(
'var y = fn === null || fn === undefined ? undefined : fn(1, 2);',
OptionalChaining,
@@ -129,7 +131,7 @@ def test_optional_call_with_args(self):
class TestYodaStyle:
"""Tests for null/undefined on the left side of comparison."""
- def test_null_on_left(self):
+ def test_null_on_left(self) -> None:
"""null === x || undefined === x ? undefined : x.foo → x?.foo."""
code, changed = roundtrip(
'var y = null === x || undefined === x ? undefined : x.foo;',
@@ -138,7 +140,7 @@ def test_null_on_left(self):
assert changed is True
assert 'x?.foo' in code
- def test_void_0_on_consequent(self):
+ def test_void_0_on_consequent(self) -> None:
"""void 0 as consequent should be recognized as undefined."""
code, changed = roundtrip(
'var y = x === null || x === void 0 ? void 0 : x.foo;',
@@ -150,11 +152,11 @@ def test_void_0_on_consequent(self):
class TestHelperFunctions:
"""Direct tests for helper functions to cover edge cases."""
- def test_is_undefined_with_non_dict(self):
+ def test_is_undefined_with_non_dict(self) -> None:
assert is_undefined(None) is False
assert is_undefined('string') is False
- def test_is_undefined_with_void_0(self):
+ def test_is_undefined_with_void_0(self) -> None:
node = {
'type': 'UnaryExpression',
'operator': 'void',
@@ -162,7 +164,7 @@ def test_is_undefined_with_void_0(self):
}
assert is_undefined(node) is True
- def test_is_undefined_with_void_non_zero(self):
+ def test_is_undefined_with_void_non_zero(self) -> None:
node = {
'type': 'UnaryExpression',
'operator': 'void',
@@ -170,23 +172,23 @@ def test_is_undefined_with_void_non_zero(self):
}
assert is_undefined(node) is False
- def test_is_null_literal_true(self):
+ def test_is_null_literal_true(self) -> None:
assert is_null_literal({'type': 'Literal', 'value': None, 'raw': 'null'}) is True
- def test_is_null_literal_false(self):
+ def test_is_null_literal_false(self) -> None:
assert is_null_literal({'type': 'Literal', 'value': 0}) is False
assert is_null_literal(None) is False
- def test_nodes_match_non_dict(self):
+ def test_nodes_match_non_dict(self) -> None:
assert _nodes_match(None, None) is False
assert _nodes_match({}, None) is False
- def test_nodes_match_different_types(self):
+ def test_nodes_match_different_types(self) -> None:
a = {'type': 'Identifier', 'name': 'x'}
b = {'type': 'Literal', 'value': 1}
assert _nodes_match(a, b) is False
- def test_nodes_match_member_expression(self):
+ def test_nodes_match_member_expression(self) -> None:
a = {
'type': 'MemberExpression',
'object': {'type': 'Identifier', 'name': 'obj'},
@@ -201,7 +203,7 @@ def test_nodes_match_member_expression(self):
}
assert _nodes_match(a, b) is True
- def test_nodes_match_unknown_type_returns_false(self):
+ def test_nodes_match_unknown_type_returns_false(self) -> None:
a = {'type': 'CallExpression'}
b = {'type': 'CallExpression'}
assert _nodes_match(a, b) is False
@@ -210,34 +212,22 @@ def test_nodes_match_unknown_type_returns_false(self):
class TestGeneratorOutput:
"""Tests that ?. roundtrips through parse → generate correctly."""
- def test_optional_member_roundtrip(self):
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
+ def test_optional_member_roundtrip(self) -> None:
code = 'var x = a?.b;'
result = generate(parse(code))
assert 'a?.b' in result
- def test_optional_computed_roundtrip(self):
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
+ def test_optional_computed_roundtrip(self) -> None:
code = 'var x = a?.[0];'
result = generate(parse(code))
assert 'a?.[0]' in result
- def test_optional_call_roundtrip(self):
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
+ def test_optional_call_roundtrip(self) -> None:
code = 'var x = fn?.();'
result = generate(parse(code))
assert 'fn?.()' in result
- def test_chained_optional(self):
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
+ def test_chained_optional(self) -> None:
code = 'var x = a?.b?.c;'
result = generate(parse(code))
assert 'a?.b?.c' in result
diff --git a/tests/unit/transforms/property_simplifier_test.py b/tests/unit/transforms/property_simplifier_test.py
index ff2132f..ac44eef 100644
--- a/tests/unit/transforms/property_simplifier_test.py
+++ b/tests/unit/transforms/property_simplifier_test.py
@@ -1,5 +1,3 @@
-import pytest
-
from pyjsclear.transforms.property_simplifier import PropertySimplifier
from tests.unit.conftest import roundtrip
@@ -7,27 +5,27 @@
class TestBracketToDot:
"""Tests for converting obj["prop"] to obj.prop."""
- def test_simple_bracket_to_dot(self):
+ def test_simple_bracket_to_dot(self) -> None:
code, changed = roundtrip('obj["foo"];', PropertySimplifier)
assert changed is True
assert code == 'obj.foo;'
- def test_invalid_identifier_unchanged(self):
+ def test_invalid_identifier_unchanged(self) -> None:
code, changed = roundtrip('obj["0abc"];', PropertySimplifier)
assert changed is False
assert '"0abc"' in code or "'0abc'" in code
- def test_reserved_word_is_valid_identifier(self):
+ def test_reserved_word_is_valid_identifier(self) -> None:
code, changed = roundtrip('obj["class"];', PropertySimplifier)
assert changed is True
assert code == 'obj.class;'
- def test_non_string_computed_unchanged(self):
+ def test_non_string_computed_unchanged(self) -> None:
code, changed = roundtrip('obj[x];', PropertySimplifier)
assert changed is False
assert 'obj[x]' in code
- def test_already_dot_notation_unchanged(self):
+ def test_already_dot_notation_unchanged(self) -> None:
code, changed = roundtrip('obj.foo;', PropertySimplifier)
assert changed is False
assert 'obj.foo' in code
@@ -36,18 +34,18 @@ def test_already_dot_notation_unchanged(self):
class TestObjectLiteralKeys:
"""Tests for simplifying string literal keys in object literals."""
- def test_string_key_to_identifier(self):
+ def test_string_key_to_identifier(self) -> None:
"""String key becomes Identifier."""
code, changed = roundtrip('var x = {"foo": 1};', PropertySimplifier)
assert changed is True
assert 'foo: 1' in code or 'foo:' in code
- def test_invalid_identifier_key_unchanged(self):
+ def test_invalid_identifier_key_unchanged(self) -> None:
code, changed = roundtrip('var x = {"0abc": 1};', PropertySimplifier)
assert changed is False
assert '0abc' in code
- def test_already_identifier_key_no_change(self):
+ def test_already_identifier_key_no_change(self) -> None:
code, changed = roundtrip('var x = {foo: 1};', PropertySimplifier)
assert changed is False
assert 'foo' in code
@@ -56,14 +54,14 @@ def test_already_identifier_key_no_change(self):
class TestMultipleProperties:
"""Tests for mixed property access patterns."""
- def test_mixed_bracket_and_dot(self):
+ def test_mixed_bracket_and_dot(self) -> None:
code, changed = roundtrip('obj["foo"]; obj.bar; obj["0bad"];', PropertySimplifier)
assert changed is True
assert 'obj.foo' in code
assert 'obj.bar' in code
assert '0bad' in code
- def test_multiple_object_literal_keys(self):
+ def test_multiple_object_literal_keys(self) -> None:
code, changed = roundtrip('var x = {"good": 1, "0bad": 2, ok: 3};', PropertySimplifier)
assert changed is True
# "good" converted to identifier key
diff --git a/tests/unit/transforms/proxy_functions_test.py b/tests/unit/transforms/proxy_functions_test.py
index 660b0a6..89567e8 100644
--- a/tests/unit/transforms/proxy_functions_test.py
+++ b/tests/unit/transforms/proxy_functions_test.py
@@ -1,14 +1,14 @@
"""Unit tests for ProxyFunctionInliner transform."""
-import pytest
-
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.proxy_functions import ProxyFunctionInliner
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
class TestProxyFunctionInlinerBasic:
- def test_binary_proxy_inlined(self):
+ def test_binary_proxy_inlined(self) -> None:
code, changed = roundtrip(
'function p(a, b) { return a + b; } p(1, 2);',
ProxyFunctionInliner,
@@ -19,7 +19,7 @@ def test_binary_proxy_inlined(self):
assert '1 + 2;' in norm
assert 'p(1, 2)' not in norm
- def test_call_proxy_inlined(self):
+ def test_call_proxy_inlined(self) -> None:
code, changed = roundtrip(
'function p(a, b) { return a(b); } p(foo, 1);',
ProxyFunctionInliner,
@@ -29,7 +29,7 @@ def test_call_proxy_inlined(self):
assert 'foo(1);' in norm
assert 'p(foo, 1)' not in norm
- def test_arrow_proxy_inlined(self):
+ def test_arrow_proxy_inlined(self) -> None:
code, changed = roundtrip(
'var p = (a) => a * 2; p(3);',
ProxyFunctionInliner,
@@ -41,7 +41,7 @@ def test_arrow_proxy_inlined(self):
class TestProxyFunctionInlinerSkips:
- def test_multi_statement_body_not_inlined(self):
+ def test_multi_statement_body_not_inlined(self) -> None:
code, changed = roundtrip(
'function p(a) { var x = 1; return a + x; } p(5);',
ProxyFunctionInliner,
@@ -49,7 +49,7 @@ def test_multi_statement_body_not_inlined(self):
assert changed is False
assert 'p(5)' in normalize(code)
- def test_non_constant_binding_not_inlined(self):
+ def test_non_constant_binding_not_inlined(self) -> None:
code, changed = roundtrip(
'var p = (a) => a * 2; p = something; p(3);',
ProxyFunctionInliner,
@@ -57,7 +57,7 @@ def test_non_constant_binding_not_inlined(self):
assert changed is False
assert 'p(3)' in normalize(code)
- def test_no_proxy_functions_returns_false(self):
+ def test_no_proxy_functions_returns_false(self) -> None:
code, changed = roundtrip(
'var x = 1 + 2;',
ProxyFunctionInliner,
@@ -66,7 +66,7 @@ def test_no_proxy_functions_returns_false(self):
class TestProxyFunctionInlinerEdgeCases:
- def test_missing_args_substitutes_undefined(self):
+ def test_missing_args_substitutes_undefined(self) -> None:
code, changed = roundtrip(
'function p(a, b) { return a + b; } p(1);',
ProxyFunctionInliner,
@@ -75,7 +75,7 @@ def test_missing_args_substitutes_undefined(self):
norm = normalize(code)
assert '1 + undefined' in norm
- def test_function_with_no_return_value_gives_undefined(self):
+ def test_function_with_no_return_value_gives_undefined(self) -> None:
code, changed = roundtrip(
'function p() { return; } var x = p();',
ProxyFunctionInliner,
@@ -83,7 +83,7 @@ def test_function_with_no_return_value_gives_undefined(self):
assert changed is True
assert 'undefined' in code
- def test_nested_calls_processed_innermost_first(self):
+ def test_nested_calls_processed_innermost_first(self) -> None:
code, changed = roundtrip(
'function p(a, b) { return a + b; } p(p(1, 2), 3);',
ProxyFunctionInliner,
@@ -98,21 +98,21 @@ def test_nested_calls_processed_innermost_first(self):
class TestProxyFunctionInlinerDisallowed:
- def test_function_expression_in_return_not_inlined(self):
+ def test_function_expression_in_return_not_inlined(self) -> None:
code, changed = roundtrip(
'function p() { return function() {}; } p();',
ProxyFunctionInliner,
)
assert changed is False
- def test_assignment_expression_in_return_not_inlined(self):
+ def test_assignment_expression_in_return_not_inlined(self) -> None:
code, changed = roundtrip(
'function p(a) { return a = 1; } p(x);',
ProxyFunctionInliner,
)
assert changed is False
- def test_sequence_expression_in_return_not_inlined(self):
+ def test_sequence_expression_in_return_not_inlined(self) -> None:
code, changed = roundtrip(
'function p(a) { return (1, a); } p(x);',
ProxyFunctionInliner,
@@ -123,7 +123,7 @@ def test_sequence_expression_in_return_not_inlined(self):
class TestCoverageGaps:
"""Tests targeting uncovered lines in proxy_functions.py."""
- def test_callee_not_identifier(self):
+ def test_callee_not_identifier(self) -> None:
"""Line 42: CallExpression with non-identifier callee (e.g., member expression)."""
code, changed = roundtrip(
'function p(a) { return a; } obj.p(1);',
@@ -132,7 +132,7 @@ def test_callee_not_identifier(self):
# obj.p(1) callee is MemberExpression, not Identifier — should not inline
assert 'obj.p(1)' in normalize(code)
- def test_destructuring_params_not_proxy(self):
+ def test_destructuring_params_not_proxy(self) -> None:
"""Line 109: Function with non-identifier params (destructuring) is not a proxy."""
code, changed = roundtrip(
'function f({a, b}) { return a + b; } f({a: 1, b: 2});',
@@ -140,11 +140,8 @@ def test_destructuring_params_not_proxy(self):
)
assert changed is False
- def test_body_is_none(self):
+ def test_body_is_none(self) -> None:
"""Line 113: Function body is None — not a proxy."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('function f() { return 1; } f();')
# Manually remove the body
func = ast['body'][0]
@@ -153,7 +150,7 @@ def test_body_is_none(self):
changed = t.execute()
assert changed is False
- def test_block_body_non_return_statement(self):
+ def test_block_body_non_return_statement(self) -> None:
"""Line 126: Block body with a non-return statement (e.g., expression)."""
code, changed = roundtrip(
'function f() { console.log(1); } f();',
@@ -161,7 +158,7 @@ def test_block_body_non_return_statement(self):
)
assert changed is False
- def test_arrow_in_return_not_proxy(self):
+ def test_arrow_in_return_not_proxy(self) -> None:
"""Lines 158-159: _is_proxy_value rejects ArrowFunctionExpression."""
code, changed = roundtrip(
'function f() { return () => 1; } f();',
@@ -169,7 +166,7 @@ def test_arrow_in_return_not_proxy(self):
)
assert changed is False
- def test_list_child_with_disallowed_type(self):
+ def test_list_child_with_disallowed_type(self) -> None:
"""Line 157: _is_proxy_value rejects disallowed type in list child."""
# Array with a function expression as element
code, changed = roundtrip(
@@ -178,10 +175,8 @@ def test_list_child_with_disallowed_type(self):
)
assert changed is False
- def test_get_replacement_body_none(self):
+ def test_get_replacement_body_none(self) -> None:
"""Line 166: _get_replacement when body is None returns undefined."""
- from pyjsclear.parser import parse
-
ast = parse('function f() { return 1; } f();')
t = ProxyFunctionInliner(ast)
# Manually create a func_node with no body
@@ -190,10 +185,8 @@ def test_get_replacement_body_none(self):
assert result is not None
assert result.get('name') == 'undefined'
- def test_get_replacement_block_empty(self):
+ def test_get_replacement_block_empty(self) -> None:
"""Line 174: _get_replacement with empty block body returns None."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
func_node = {
@@ -204,10 +197,8 @@ def test_get_replacement_block_empty(self):
result = t._get_replacement(func_node, [])
assert result is None
- def test_get_replacement_block_non_return(self):
+ def test_get_replacement_block_non_return(self) -> None:
"""Line 174: _get_replacement with block body that starts with non-return."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
func_node = {
@@ -221,10 +212,8 @@ def test_get_replacement_block_non_return(self):
result = t._get_replacement(func_node, [])
assert result is None
- def test_get_replacement_not_block_not_arrow(self):
+ def test_get_replacement_not_block_not_arrow(self) -> None:
"""Line 180: _get_replacement with non-block, non-arrow body returns None."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
func_node = {
@@ -235,10 +224,8 @@ def test_get_replacement_not_block_not_arrow(self):
result = t._get_replacement(func_node, [])
assert result is None
- def test_get_replacement_return_no_argument(self):
+ def test_get_replacement_return_no_argument(self) -> None:
"""Line 177: _get_replacement with return but no argument gives undefined."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
func_node = {
@@ -253,20 +240,16 @@ def test_get_replacement_return_no_argument(self):
assert result is not None
assert result.get('name') == 'undefined'
- def test_is_proxy_value_non_dict(self):
+ def test_is_proxy_value_non_dict(self) -> None:
"""Line 148: _is_proxy_value with non-dict returns False."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
assert t._is_proxy_value('not_a_dict') is False
assert t._is_proxy_value(None) is False
assert t._is_proxy_value(42) is False
- def test_function_non_block_body_not_arrow(self):
+ def test_function_non_block_body_not_arrow(self) -> None:
"""Line 132: body not BlockStatement and func is not ArrowFunction."""
- from pyjsclear.parser import parse
-
ast = parse('function f(a) { return a; } f(1);')
func = ast['body'][0]
# Replace body with a non-block to trigger line 132
@@ -276,10 +259,8 @@ def test_function_non_block_body_not_arrow(self):
# Should not crash; function is not a proxy since body is not block or arrow expr
assert not changed
- def test_is_proxy_value_child_dict_disallowed(self):
+ def test_is_proxy_value_child_dict_disallowed(self) -> None:
"""Line 158-159: _is_proxy_value child dict with disallowed type."""
- from pyjsclear.parser import parse
-
ast = parse('var x = 1;')
t = ProxyFunctionInliner(ast)
# A node containing a child dict with a disallowed type
diff --git a/tests/unit/transforms/reassignment_test.py b/tests/unit/transforms/reassignment_test.py
index 5bcece6..2e7c253 100644
--- a/tests/unit/transforms/reassignment_test.py
+++ b/tests/unit/transforms/reassignment_test.py
@@ -1,7 +1,5 @@
"""Unit tests for ReassignmentRemover transform."""
-import pytest
-
from pyjsclear.transforms.reassignment import ReassignmentRemover
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -10,37 +8,37 @@
class TestReassignmentRemoverDeclaratorInline:
"""Tests for var x = y; declarator-based inlining."""
- def test_well_known_global_json(self):
+ def test_well_known_global_json(self) -> None:
code = 'var x = JSON; x.parse(y);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'JSON.parse(y)' in normalize(result)
assert 'var x' not in normalize(result) or 'x.parse' not in normalize(result)
- def test_well_known_global_console(self):
+ def test_well_known_global_console(self) -> None:
code = 'var x = console; x.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'console.log("hi")' in normalize(result)
- def test_constant_binding_inline(self):
+ def test_constant_binding_inline(self) -> None:
code = 'const a = 1; var b = a; c(b);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'c(a)' in normalize(result)
- def test_unknown_target_unchanged(self):
+ def test_unknown_target_unchanged(self) -> None:
code = 'var x = y; x.foo();'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
assert 'x.foo()' in normalize(result)
- def test_self_assignment_skipped(self):
+ def test_self_assignment_skipped(self) -> None:
code = 'var x = x;'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
- def test_non_constant_target_unchanged(self):
+ def test_non_constant_target_unchanged(self) -> None:
code = 'var y = 1; y = 2; var x = y; x.foo();'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
@@ -50,7 +48,7 @@ def test_non_constant_target_unchanged(self):
class TestReassignmentRemoverAssignmentAlias:
"""Tests for var x; ... x = y; assignment alias inlining."""
- def test_assignment_alias_console(self):
+ def test_assignment_alias_console(self) -> None:
code = 'var x; x = console; x.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
@@ -58,19 +56,19 @@ def test_assignment_alias_console(self):
# The assignment statement should be removed
assert 'x = console' not in normalize(result)
- def test_assignment_alias_json(self):
+ def test_assignment_alias_json(self) -> None:
code = 'var x; x = JSON; x.parse(s);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'JSON.parse(s)' in normalize(result)
- def test_assignment_alias_unknown_target_unchanged(self):
+ def test_assignment_alias_unknown_target_unchanged(self) -> None:
code = 'var x; x = y; x.foo();'
result, changed = roundtrip(code, ReassignmentRemover)
# y is not well-known and not a constant binding, so no change
assert 'x.foo()' in normalize(result)
- def test_assignment_alias_non_constant_target_unchanged(self):
+ def test_assignment_alias_non_constant_target_unchanged(self) -> None:
code = 'var y = 1; y = 2; var x; x = y; x.foo();'
result, changed = roundtrip(code, ReassignmentRemover)
assert 'x.foo()' in normalize(result)
@@ -79,12 +77,12 @@ def test_assignment_alias_non_constant_target_unchanged(self):
class TestReassignmentRemoverNoChange:
"""Tests for cases where nothing changes."""
- def test_no_reassignments_returns_false(self):
+ def test_no_reassignments_returns_false(self) -> None:
code = 'var a = 1; console.log(a);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
- def test_param_not_inlined(self):
+ def test_param_not_inlined(self) -> None:
code = 'function f(x) { var y = x; return y; }'
result, changed = roundtrip(code, ReassignmentRemover)
# x is a param binding; it is constant, so inlining y=x should work
@@ -92,7 +90,7 @@ def test_param_not_inlined(self):
# Let's just verify it doesn't crash
assert isinstance(changed, bool)
- def test_no_variable_declarations(self):
+ def test_no_variable_declarations(self) -> None:
code = 'console.log(1);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
@@ -101,28 +99,28 @@ def test_no_variable_declarations(self):
class TestReassignmentRemoverEdgeCases:
"""Edge case tests."""
- def test_multiple_well_known_globals(self):
+ def test_multiple_well_known_globals(self) -> None:
code = 'var a = Object; var b = Array; a.keys(x); b.isArray(y);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'Object.keys(x)' in normalize(result)
assert 'Array.isArray(y)' in normalize(result)
- def test_rebuild_scope_flag(self):
+ def test_rebuild_scope_flag(self) -> None:
assert ReassignmentRemover.rebuild_scope is True
class TestReassignmentRemoverSkipConditions:
"""Tests for skip conditions on reference replacement."""
- def test_reference_as_assignment_lhs_skipped(self):
+ def test_reference_as_assignment_lhs_skipped(self) -> None:
"""Line 93: Reference used as assignment left-hand side should be skipped."""
code = 'var x = console; x = something; x.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
# x has writes so it is not constant, no inlining should happen
assert isinstance(changed, bool)
- def test_reference_as_declarator_id_skipped(self):
+ def test_reference_as_declarator_id_skipped(self) -> None:
"""Line 95: Reference used as VariableDeclarator id should be skipped."""
# When a variable is reassigned via declarator pattern, the id ref should be skipped
code = 'var x = JSON; x.parse(s);'
@@ -134,7 +132,7 @@ def test_reference_as_declarator_id_skipped(self):
class TestReassignmentRemoverAssignmentAliasEdgeCases:
"""Tests for assignment alias edge cases."""
- def test_assignment_alias_with_init_skipped(self):
+ def test_assignment_alias_with_init_skipped(self) -> None:
"""Line 131: VariableDeclarator with init should be skipped for assignment alias."""
code = 'var _0x1 = undefined; _0x1 = console; _0x1.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
@@ -142,34 +140,34 @@ def test_assignment_alias_with_init_skipped(self):
# but it might be handled by the declarator path instead
assert isinstance(changed, bool)
- def test_assignment_alias_multiple_writes_skipped(self):
+ def test_assignment_alias_multiple_writes_skipped(self) -> None:
"""Line 151: Assignment alias with != 1 writes should be skipped."""
code = 'var _0x1; _0x1 = console; _0x1 = JSON; _0x1.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
# Two writes means it won't be inlined via assignment alias
assert '_0x1' in result or changed is False
- def test_assignment_alias_rhs_not_identifier(self):
+ def test_assignment_alias_rhs_not_identifier(self) -> None:
"""Line 157: Assignment alias where right side is not an identifier."""
code = 'var _0x1; _0x1 = 123; console.log(_0x1);'
result, changed = roundtrip(code, ReassignmentRemover)
# RHS is a literal, not an identifier — should be skipped
assert '_0x1' in result
- def test_assignment_alias_self_assignment_skipped(self):
+ def test_assignment_alias_self_assignment_skipped(self) -> None:
"""Line 161: target_name equals name should be skipped."""
code = 'var _0x1; _0x1 = _0x1;'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is False
- def test_assignment_alias_replacement_with_index(self):
+ def test_assignment_alias_replacement_with_index(self) -> None:
"""Line 175: Assignment alias replacement in array position (with index)."""
code = 'var _0x1; _0x1 = console; foo([_0x1]);'
result, changed = roundtrip(code, ReassignmentRemover)
assert changed is True
assert 'console' in result
- def test_assignment_alias_pattern(self):
+ def test_assignment_alias_pattern(self) -> None:
"""Assignment alias: var x; x = console; log(x);"""
code = 'var _0x1 = undefined; var _0x2; _0x2 = console; _0x2.log("hi");'
result, changed = roundtrip(code, ReassignmentRemover)
diff --git a/tests/unit/transforms/require_inliner_test.py b/tests/unit/transforms/require_inliner_test.py
index ed17858..cb892b0 100644
--- a/tests/unit/transforms/require_inliner_test.py
+++ b/tests/unit/transforms/require_inliner_test.py
@@ -7,7 +7,7 @@
class TestRequirePolyfillDetection:
"""Tests for detecting and inlining require polyfill wrappers."""
- def test_typeof_require_polyfill(self):
+ def test_typeof_require_polyfill(self) -> None:
code = '''
var _0x544bfe = (function() { return typeof require !== "undefined" ? require : null; })();
_0x544bfe("fs");
@@ -18,17 +18,17 @@ def test_typeof_require_polyfill(self):
assert 'require("fs")' in result
assert 'require("path")' in result
- def test_preserves_non_polyfill_calls(self):
+ def test_preserves_non_polyfill_calls(self) -> None:
"""Regular function calls should not be changed."""
result, changed = roundtrip('myFunc("fs");', RequireInliner)
assert changed is False
assert 'require' not in result
- def test_no_polyfills_returns_false(self):
+ def test_no_polyfills_returns_false(self) -> None:
result, changed = roundtrip('var x = 1;', RequireInliner)
assert changed is False
- def test_multi_arg_call_unchanged(self):
+ def test_multi_arg_call_unchanged(self) -> None:
"""Polyfill calls with != 1 arg should not be replaced."""
code = '''
var _0x544bfe = (function() { return typeof require !== "undefined" ? require : null; })();
diff --git a/tests/unit/transforms/sequence_splitter_test.py b/tests/unit/transforms/sequence_splitter_test.py
index c91aaa1..43a6b50 100644
--- a/tests/unit/transforms/sequence_splitter_test.py
+++ b/tests/unit/transforms/sequence_splitter_test.py
@@ -1,18 +1,18 @@
-import pytest
-
+from pyjsclear.generator import generate
+from pyjsclear.parser import parse
from pyjsclear.transforms.sequence_splitter import SequenceSplitter
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
class TestSequenceSplittingInExpressionStatements:
- def test_splits_sequence_into_separate_statements(self):
+ def test_splits_sequence_into_separate_statements(self) -> None:
code = '(a(), b(), c());'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('a(); b(); c();')
- def test_splits_two_element_sequence(self):
+ def test_splits_two_element_sequence(self) -> None:
code = '(a(), b());'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -20,25 +20,25 @@ def test_splits_two_element_sequence(self):
class TestMultiVarSplitting:
- def test_splits_multi_declarator_var(self):
+ def test_splits_multi_declarator_var(self) -> None:
code = 'var a = 1, b = 2;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('var a = 1; var b = 2;')
- def test_splits_three_declarators(self):
+ def test_splits_three_declarators(self) -> None:
code = 'var a = 1, b = 2, c = 3;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('var a = 1; var b = 2; var c = 3;')
- def test_preserves_kind_let(self):
+ def test_preserves_kind_let(self) -> None:
code = 'let a = 1, b = 2;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('let a = 1; let b = 2;')
- def test_preserves_kind_const(self):
+ def test_preserves_kind_const(self) -> None:
code = 'const a = 1, b = 2;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -46,13 +46,13 @@ def test_preserves_kind_const(self):
class TestSingleDeclaratorSequenceInitSplitting:
- def test_splits_sequence_in_var_init(self):
+ def test_splits_sequence_in_var_init(self) -> None:
code = 'var x = (a, b, expr());'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('a; b; var x = expr();')
- def test_splits_two_element_sequence_in_init(self):
+ def test_splits_two_element_sequence_in_init(self) -> None:
code = 'var x = (a, expr());'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -60,13 +60,13 @@ def test_splits_two_element_sequence_in_init(self):
class TestIndirectCallPrefixExtraction:
- def test_extracts_zero_prefix(self):
+ def test_extracts_zero_prefix(self) -> None:
code = '(0, fn)(args);'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('0; fn(args);')
- def test_extracts_multiple_prefixes(self):
+ def test_extracts_multiple_prefixes(self) -> None:
code = '(0, 1, fn)(args);'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -74,19 +74,19 @@ def test_extracts_multiple_prefixes(self):
class TestBodyNormalization:
- def test_if_body_normalized_to_block(self):
+ def test_if_body_normalized_to_block(self) -> None:
code = 'if (x) y();'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('if (x) { y(); }')
- def test_while_body_normalized_to_block(self):
+ def test_while_body_normalized_to_block(self) -> None:
code = 'while (x) y();'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('while (x) { y(); }')
- def test_for_body_normalized_to_block(self):
+ def test_for_body_normalized_to_block(self) -> None:
code = 'for (;;) y();'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -94,13 +94,13 @@ def test_for_body_normalized_to_block(self):
class TestIfBranchNormalization:
- def test_alternate_normalized(self):
+ def test_alternate_normalized(self) -> None:
code = 'if (x) { y(); } else z();'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('if (x) { y(); } else { z(); }')
- def test_else_if_not_wrapped(self):
+ def test_else_if_not_wrapped(self) -> None:
code = 'if (x) { y(); } else if (z) { w(); }'
result, changed = roundtrip(code, SequenceSplitter)
# else-if should not be wrapped in a block
@@ -108,38 +108,38 @@ def test_else_if_not_wrapped(self):
class TestNoChangeWhenNothingToSplit:
- def test_single_expression_statement(self):
+ def test_single_expression_statement(self) -> None:
code = 'a();'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
assert normalize(result) == normalize('a();')
- def test_single_var_declaration(self):
+ def test_single_var_declaration(self) -> None:
code = 'var a = 1;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
assert normalize(result) == normalize('var a = 1;')
- def test_already_block_if(self):
+ def test_already_block_if(self) -> None:
code = 'if (x) { y(); }'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
assert normalize(result) == normalize('if (x) { y(); }')
- def test_empty_program(self):
+ def test_empty_program(self) -> None:
code = ''
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
class TestAwaitWrappedSequenceSplitting:
- def test_splits_await_sequence_in_var_init(self):
+ def test_splits_await_sequence_in_var_init(self) -> None:
code = 'async function f() { var x = await (a, b, expr()); }'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
assert normalize(result) == normalize('async function f() { a; b; var x = await expr(); }')
- def test_splits_await_two_element_sequence(self):
+ def test_splits_await_two_element_sequence(self) -> None:
code = 'async function f() { var x = await (a, expr()); }'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is True
@@ -149,14 +149,14 @@ def test_splits_await_two_element_sequence(self):
class TestSequenceSplitterEdgeCases:
"""Tests for uncovered edge cases."""
- def test_non_dict_in_body_arrays(self):
+ def test_non_dict_in_body_arrays(self) -> None:
"""Line 59: Non-dict in _split_in_body_arrays should be skipped gracefully."""
# This is handled internally; just ensure no crash with normal code
code = 'a(); b();'
result, changed = roundtrip(code, SequenceSplitter)
assert isinstance(changed, bool)
- def test_return_indirect_call(self):
+ def test_return_indirect_call(self) -> None:
"""Line 107: ReturnStatement with indirect call: return (0, fn)(args)."""
code = 'function f() { return (0, g)("hello"); }'
result, changed = roundtrip(code, SequenceSplitter)
@@ -165,7 +165,7 @@ def test_return_indirect_call(self):
assert '0' in norm
assert 'g("hello")' in norm
- def test_assignment_indirect_call(self):
+ def test_assignment_indirect_call(self) -> None:
"""Line 113: Assignment with indirect call: x = (0, fn)(args)."""
code = 'var x; x = (0, g)("hello");'
result, changed = roundtrip(code, SequenceSplitter)
@@ -174,21 +174,21 @@ def test_assignment_indirect_call(self):
assert '0' in norm
assert 'g("hello")' in norm
- def test_non_dict_statement_in_process_stmt_array(self):
+ def test_non_dict_statement_in_process_stmt_array(self) -> None:
"""Lines 123-124: Non-dict statement in _process_stmt_array should be skipped."""
# Internal handling; verify no crash with various patterns
code = 'var a = 1;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
- def test_non_dict_init_in_single_declarator(self):
+ def test_non_dict_init_in_single_declarator(self) -> None:
"""Line 189: _try_split_single_declarator_init with non-dict init."""
# Variable with no init (init is None, not a dict)
code = 'var x;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
- def test_sequence_callee_with_single_expression(self):
+ def test_sequence_callee_with_single_expression(self) -> None:
"""Line 96: SequenceExpression callee with <=1 expression should not be extracted."""
# Construct a case where there's a single-element sequence callee
# This is a degenerate case; normal JS won't produce it, but we can test
@@ -197,28 +197,28 @@ def test_sequence_callee_with_single_expression(self):
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
- def test_direct_sequence_init_single_expression(self):
+ def test_direct_sequence_init_single_expression(self) -> None:
"""Line 195: Direct SequenceExpression init with <=1 expression should return None."""
# A single-element sequence is degenerate; just test normal single-init var
code = 'var x = 1;'
result, changed = roundtrip(code, SequenceSplitter)
assert changed is False
- def test_await_wrapped_sequence_single_expression(self):
+ def test_await_wrapped_sequence_single_expression(self) -> None:
"""Line 209: Await-wrapped sequence with <=1 expression should return None."""
code = 'async function f() { var x = await expr(); }'
result, changed = roundtrip(code, SequenceSplitter)
# No sequence to split, just a simple await
assert normalize(result) == normalize('async function f() { var x = await expr(); }')
- def test_extract_from_call_non_dict(self):
+ def test_extract_from_call_non_dict(self) -> None:
"""Line 85: Non-dict in extract_from_call should be skipped."""
# When expression is not a dict (e.g., expression missing), no crash
code = ';'
result, changed = roundtrip(code, SequenceSplitter)
assert isinstance(changed, bool)
- def test_var_declaration_indirect_call_in_init(self):
+ def test_var_declaration_indirect_call_in_init(self) -> None:
"""VariableDeclaration path for extracting indirect call prefixes."""
code = 'var x = (0, fn)("hello");'
result, changed = roundtrip(code, SequenceSplitter)
@@ -231,11 +231,8 @@ def test_var_declaration_indirect_call_in_init(self):
class TestSequenceSplitterDirectASTCoverage:
"""Tests using direct AST manipulation to hit remaining uncovered lines."""
- def test_sequence_callee_with_single_expression(self):
+ def test_sequence_callee_with_single_expression(self) -> None:
"""Line 95-96: SequenceExpression callee with <=1 expression."""
- from pyjsclear.generator import generate
- from pyjsclear.parser import parse
-
ast = parse('fn("hello");')
# Manually wrap callee in a SequenceExpression with 1 element
call_expr = ast['body'][0]['expression']
@@ -249,10 +246,8 @@ def test_sequence_callee_with_single_expression(self):
# Single-element sequence callee should not be extracted
# (but body normalization may still trigger changes)
- def test_single_element_await_sequence(self):
+ def test_single_element_await_sequence(self) -> None:
"""Line 208-209: single-element await sequence returns None."""
- from pyjsclear.parser import parse
-
ast = parse('async function f() { var x = await expr(); }')
# Find the var declaration inside the function body
func_body = ast['body'][0]['body']['body']
diff --git a/tests/unit/transforms/single_use_vars_test.py b/tests/unit/transforms/single_use_vars_test.py
index e229d33..cb4bbc5 100644
--- a/tests/unit/transforms/single_use_vars_test.py
+++ b/tests/unit/transforms/single_use_vars_test.py
@@ -7,7 +7,7 @@
class TestRequireInlining:
"""Tests for single-use require() inlining."""
- def test_simple_require_inlined(self):
+ def test_simple_require_inlined(self) -> None:
code = '''
function f() {
const x = require("fs");
@@ -19,7 +19,7 @@ def test_simple_require_inlined(self):
assert 'require("fs").readFileSync' in result
assert 'const x' not in result
- def test_require_member_access(self):
+ def test_require_member_access(self) -> None:
code = '''
function f() {
const proc = require("process");
@@ -30,7 +30,7 @@ def test_require_member_access(self):
assert changed is True
assert 'require("process").env.HOME' in result
- def test_var_require_inlined(self):
+ def test_var_require_inlined(self) -> None:
code = '''
function f() {
var x = require("path");
@@ -46,7 +46,7 @@ def test_var_require_inlined(self):
class TestExpressionInlining:
"""Tests for single-use non-require expression inlining."""
- def test_property_access_inlined(self):
+ def test_property_access_inlined(self) -> None:
code = '''
function f() {
const x = obj.prop;
@@ -58,7 +58,7 @@ def test_property_access_inlined(self):
assert 'obj.prop.foo' in result
assert 'const x' not in result
- def test_method_call_inlined(self):
+ def test_method_call_inlined(self) -> None:
code = '''
function f(arr) {
const x = Buffer.from(arr);
@@ -69,7 +69,7 @@ def test_method_call_inlined(self):
assert changed is True
assert 'Buffer.from(arr).toString()' in result
- def test_new_expression_inlined(self):
+ def test_new_expression_inlined(self) -> None:
code = '''
function f() {
const d = new Date();
@@ -80,7 +80,7 @@ def test_new_expression_inlined(self):
assert changed is True
assert 'new Date().getTime()' in result
- def test_string_literal_inlined(self):
+ def test_string_literal_inlined(self) -> None:
code = '''
function f() {
const url = "https://example.com";
@@ -92,7 +92,7 @@ def test_string_literal_inlined(self):
assert 'fetch("https://example.com")' in result
assert 'const url' not in result
- def test_simple_call_inlined(self):
+ def test_simple_call_inlined(self) -> None:
code = '''
function f(x) {
const n = parseInt(x);
@@ -107,7 +107,7 @@ def test_simple_call_inlined(self):
class TestNoInlining:
"""Tests where inlining should NOT occur."""
- def test_multi_use_not_inlined(self):
+ def test_multi_use_not_inlined(self) -> None:
code = '''
function f() {
const fs = require("fs");
@@ -118,7 +118,7 @@ def test_multi_use_not_inlined(self):
result, changed = roundtrip(code, SingleUseVarInliner)
assert changed is False
- def test_reassigned_var_not_inlined(self):
+ def test_reassigned_var_not_inlined(self) -> None:
code = '''
function f() {
let x = require("fs");
@@ -129,11 +129,11 @@ def test_reassigned_var_not_inlined(self):
result, changed = roundtrip(code, SingleUseVarInliner)
assert changed is False
- def test_no_init_returns_false(self):
+ def test_no_init_returns_false(self) -> None:
result, changed = roundtrip('var x;', SingleUseVarInliner)
assert changed is False
- def test_large_init_not_inlined(self):
+ def test_large_init_not_inlined(self) -> None:
"""Init expressions with too many AST nodes should not be inlined."""
# Build a deeply nested expression that exceeds the node limit
code = '''
@@ -145,7 +145,7 @@ def test_large_init_not_inlined(self):
result, changed = roundtrip(code, SingleUseVarInliner)
assert changed is False
- def test_assignment_target_not_inlined(self):
+ def test_assignment_target_not_inlined(self) -> None:
code = '''
function f() {
const x = obj.prop;
@@ -155,7 +155,7 @@ def test_assignment_target_not_inlined(self):
result, changed = roundtrip(code, SingleUseVarInliner)
assert changed is False
- def test_mutated_member_not_inlined(self):
+ def test_mutated_member_not_inlined(self) -> None:
"""var x = {}; x[key] = val; should NOT inline to {}[key] = val."""
code = '''
function f() {
@@ -167,7 +167,7 @@ def test_mutated_member_not_inlined(self):
assert changed is False
assert 'var x' in result
- def test_mutated_dot_member_not_inlined(self):
+ def test_mutated_dot_member_not_inlined(self) -> None:
"""var x = {}; x.foo = val; should NOT inline."""
code = '''
function f() {
@@ -182,7 +182,7 @@ def test_mutated_dot_member_not_inlined(self):
class TestNestedContexts:
"""Tests for inlining in nested scopes (classes, closures)."""
- def test_class_method_inlined(self):
+ def test_class_method_inlined(self) -> None:
code = '''
var Cls = class {
static method() {
@@ -195,7 +195,7 @@ def test_class_method_inlined(self):
assert changed is True
assert 'require("process").env.HOME' in result
- def test_nested_function_inlined(self):
+ def test_nested_function_inlined(self) -> None:
code = '''
function outer() {
function inner() {
@@ -208,7 +208,7 @@ def test_nested_function_inlined(self):
assert changed is True
assert 'require("fs").existsSync' in result
- def test_multiple_scopes_inlined(self):
+ def test_multiple_scopes_inlined(self) -> None:
code = '''
function a() {
const x = obj.foo;
diff --git a/tests/unit/transforms/string_revealer_test.py b/tests/unit/transforms/string_revealer_test.py
index 7931105..2f1ad00 100644
--- a/tests/unit/transforms/string_revealer_test.py
+++ b/tests/unit/transforms/string_revealer_test.py
@@ -11,6 +11,8 @@
from pyjsclear.transforms.string_revealer import _js_parse_int
from pyjsclear.transforms.string_revealer import _resolve_arg_value
from pyjsclear.transforms.string_revealer import _resolve_string_arg
+from pyjsclear.utils.string_decoders import Base64StringDecoder
+from pyjsclear.utils.string_decoders import BasicStringDecoder
from tests.unit.conftest import normalize
from tests.unit.conftest import parse_expr
from tests.unit.conftest import roundtrip
@@ -24,59 +26,59 @@
class TestEvalNumeric:
"""Tests for the _eval_numeric helper."""
- def test_integer_literal(self):
+ def test_integer_literal(self) -> None:
node = parse_expr('42')
assert _eval_numeric(node) == 42
- def test_float_literal(self):
+ def test_float_literal(self) -> None:
node = parse_expr('3.14')
assert _eval_numeric(node) == pytest.approx(3.14)
- def test_unary_negative(self):
+ def test_unary_negative(self) -> None:
node = parse_expr('-7')
assert _eval_numeric(node) == -7
- def test_unary_positive(self):
+ def test_unary_positive(self) -> None:
node = parse_expr('+5')
assert _eval_numeric(node) == 5
- def test_binary_addition(self):
+ def test_binary_addition(self) -> None:
node = parse_expr('3 + 4')
assert _eval_numeric(node) == 7
- def test_binary_subtraction(self):
+ def test_binary_subtraction(self) -> None:
node = parse_expr('10 - 3')
assert _eval_numeric(node) == 7
- def test_binary_multiplication(self):
+ def test_binary_multiplication(self) -> None:
node = parse_expr('6 * 7')
assert _eval_numeric(node) == 42
- def test_binary_division(self):
+ def test_binary_division(self) -> None:
node = parse_expr('20 / 4')
assert _eval_numeric(node) == 5.0
- def test_division_by_zero_returns_none(self):
+ def test_division_by_zero_returns_none(self) -> None:
node = parse_expr('1 / 0')
assert _eval_numeric(node) is None
- def test_nested_expression(self):
+ def test_nested_expression(self) -> None:
node = parse_expr('(2 + 3) * 4')
assert _eval_numeric(node) == 20
- def test_string_literal_returns_none(self):
+ def test_string_literal_returns_none(self) -> None:
node = parse_expr('"hello"')
assert _eval_numeric(node) is None
- def test_identifier_returns_none(self):
+ def test_identifier_returns_none(self) -> None:
node = parse_expr('x')
assert _eval_numeric(node) is None
- def test_non_dict_returns_none(self):
+ def test_non_dict_returns_none(self) -> None:
assert _eval_numeric(None) is None
assert _eval_numeric(42) is None
- def test_unsupported_operator_returns_none(self):
+ def test_unsupported_operator_returns_none(self) -> None:
node = parse_expr('2 << 3')
assert _eval_numeric(node) is None
@@ -89,34 +91,34 @@ def test_unsupported_operator_returns_none(self):
class TestJsParseInt:
"""Tests for the _js_parse_int helper."""
- def test_pure_integer(self):
+ def test_pure_integer(self) -> None:
assert _js_parse_int('123') == 123
- def test_leading_digits_with_trailing_chars(self):
+ def test_leading_digits_with_trailing_chars(self) -> None:
assert _js_parse_int('12abc') == 12
- def test_no_leading_digits_returns_nan(self):
+ def test_no_leading_digits_returns_nan(self) -> None:
result = _js_parse_int('abc')
assert math.isnan(result)
- def test_negative_number(self):
+ def test_negative_number(self) -> None:
assert _js_parse_int('-42') == -42
- def test_positive_sign(self):
+ def test_positive_sign(self) -> None:
assert _js_parse_int('+99') == 99
- def test_whitespace_stripped(self):
+ def test_whitespace_stripped(self) -> None:
assert _js_parse_int(' 56 ') == 56
- def test_empty_string_returns_nan(self):
+ def test_empty_string_returns_nan(self) -> None:
result = _js_parse_int('')
assert math.isnan(result)
- def test_non_string_returns_nan(self):
+ def test_non_string_returns_nan(self) -> None:
result = _js_parse_int(42)
assert math.isnan(result)
- def test_none_returns_nan(self):
+ def test_none_returns_nan(self) -> None:
result = _js_parse_int(None)
assert math.isnan(result)
@@ -129,39 +131,39 @@ def test_none_returns_nan(self):
class TestWrapperInfo:
"""Tests for WrapperInfo.get_effective_index."""
- def test_basic_offset(self):
+ def test_basic_offset(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=10, func_node={})
assert info.get_effective_index([5]) == 15
- def test_negative_offset(self):
+ def test_negative_offset(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=-3, func_node={})
assert info.get_effective_index([10]) == 7
- def test_zero_offset(self):
+ def test_zero_offset(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={})
assert info.get_effective_index([7]) == 7
- def test_param_index_selects_correct_arg(self):
+ def test_param_index_selects_correct_arg(self) -> None:
info = WrapperInfo('w', param_index=1, wrapper_offset=100, func_node={})
assert info.get_effective_index(['ignored', 5]) == 105
- def test_param_index_out_of_bounds_returns_none(self):
+ def test_param_index_out_of_bounds_returns_none(self) -> None:
info = WrapperInfo('w', param_index=2, wrapper_offset=0, func_node={})
assert info.get_effective_index([1]) is None
- def test_non_numeric_arg_returns_none(self):
+ def test_non_numeric_arg_returns_none(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={})
assert info.get_effective_index(['not_a_number']) is None
- def test_get_key_with_key_param(self):
+ def test_get_key_with_key_param(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}, key_param_index=1)
assert info.get_key([10, 'secret']) == 'secret'
- def test_get_key_without_key_param(self):
+ def test_get_key_without_key_param(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={})
assert info.get_key([10]) is None
- def test_get_key_index_out_of_bounds(self):
+ def test_get_key_index_out_of_bounds(self) -> None:
info = WrapperInfo('w', param_index=0, wrapper_offset=0, func_node={}, key_param_index=5)
assert info.get_key([10]) is None
@@ -174,14 +176,14 @@ def test_get_key_index_out_of_bounds(self):
class TestDirectArrays:
"""Tests for direct array declaration and replacement."""
- def test_basic_direct_array(self):
+ def test_basic_direct_array(self) -> None:
js = 'var arr = ["hello", "world"]; x(arr[0]); y(arr[1]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is True
assert 'x("hello")' in code
assert 'y("world")' in code
- def test_direct_array_multiple_accesses(self):
+ def test_direct_array_multiple_accesses(self) -> None:
js = 'var arr = ["a", "b", "c"]; f(arr[0]); g(arr[1]); h(arr[2]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is True
@@ -189,30 +191,30 @@ def test_direct_array_multiple_accesses(self):
assert '"b"' in code
assert '"c"' in code
- def test_direct_array_out_of_bounds_no_replacement(self):
+ def test_direct_array_out_of_bounds_no_replacement(self) -> None:
js = 'var arr = ["hello"]; x(arr[5]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
assert 'arr[5]' in code
- def test_direct_array_non_numeric_index_no_replacement(self):
+ def test_direct_array_non_numeric_index_no_replacement(self) -> None:
js = 'var arr = ["hello", "world"]; x(arr[y]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
assert 'arr[y]' in code
- def test_direct_array_single_element(self):
+ def test_direct_array_single_element(self) -> None:
js = 'var arr = ["only"]; x(arr[0]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is True
assert '"only"' in code
- def test_direct_array_preserves_non_array_vars(self):
+ def test_direct_array_preserves_non_array_vars(self) -> None:
js = 'var x = 42; console.log(x);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_mixed_element_array_not_replaced(self):
+ def test_mixed_element_array_not_replaced(self) -> None:
js = 'var arr = ["hello", 42]; x(arr[0]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
@@ -227,12 +229,12 @@ def test_mixed_element_array_not_replaced(self):
class TestNoStringArrays:
"""Tests for code with no string array patterns."""
- def test_empty_program_returns_false(self):
+ def test_empty_program_returns_false(self) -> None:
js = ''
_, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_non_string_array_returns_false(self):
+ def test_non_string_array_returns_false(self) -> None:
js = 'var arr = [1, 2, 3]; x(arr[0]);'
_, changed = roundtrip(js, StringRevealer)
assert changed is False
@@ -246,7 +248,7 @@ def test_non_string_array_returns_false(self):
class TestObfuscatorIoShortArray:
"""Short string arrays (< 5 elements) should not trigger obfuscator.io strategy."""
- def test_short_array_function_decoded(self):
+ def test_short_array_function_decoded(self) -> None:
# Arrays with >= 2 elements in obfuscator.io function pattern are decoded.
js = """
function _0xArr() {
@@ -274,28 +276,28 @@ def test_short_array_function_decoded(self):
class TestApplyArith:
"""Tests for the _apply_arith helper."""
- def test_addition(self):
+ def test_addition(self) -> None:
assert _apply_arith('+', 3, 4) == 7
- def test_subtraction(self):
+ def test_subtraction(self) -> None:
assert _apply_arith('-', 10, 3) == 7
- def test_multiplication(self):
+ def test_multiplication(self) -> None:
assert _apply_arith('*', 6, 7) == 42
- def test_division(self):
+ def test_division(self) -> None:
assert _apply_arith('/', 20, 4) == 5.0
- def test_division_by_zero(self):
+ def test_division_by_zero(self) -> None:
assert _apply_arith('/', 1, 0) is None
- def test_modulo(self):
+ def test_modulo(self) -> None:
assert _apply_arith('%', 10, 3) == 1
- def test_modulo_by_zero(self):
+ def test_modulo_by_zero(self) -> None:
assert _apply_arith('%', 10, 0) is None
- def test_unsupported_operator_returns_none(self):
+ def test_unsupported_operator_returns_none(self) -> None:
assert _apply_arith('**', 2, 3) is None
assert _apply_arith('<<', 2, 3) is None
assert _apply_arith('>>', 8, 1) is None
@@ -310,41 +312,41 @@ def test_unsupported_operator_returns_none(self):
class TestCollectObjectLiterals:
"""Tests for the _collect_object_literals helper."""
- def test_numeric_properties(self):
+ def test_numeric_properties(self) -> None:
ast = parse('var obj = {a: 0x1b1, b: 42};')
result = _collect_object_literals(ast)
assert ('obj', 'a') in result
assert result[('obj', 'a')] == 0x1B1
assert result[('obj', 'b')] == 42
- def test_string_properties(self):
+ def test_string_properties(self) -> None:
ast = parse('var obj = {a: "hello", b: "world"};')
result = _collect_object_literals(ast)
assert result[('obj', 'a')] == 'hello'
assert result[('obj', 'b')] == 'world'
- def test_mixed_properties(self):
+ def test_mixed_properties(self) -> None:
ast = parse('var obj = {a: 0x1b1, b: "hello"};')
result = _collect_object_literals(ast)
assert result[('obj', 'a')] == 0x1B1
assert result[('obj', 'b')] == 'hello'
- def test_string_key_properties(self):
+ def test_string_key_properties(self) -> None:
ast = parse('var obj = {"myKey": 42};')
result = _collect_object_literals(ast)
assert result[('obj', 'myKey')] == 42
- def test_empty_object(self):
+ def test_empty_object(self) -> None:
ast = parse('var obj = {};')
result = _collect_object_literals(ast)
assert len(result) == 0
- def test_non_object_init_ignored(self):
+ def test_non_object_init_ignored(self) -> None:
ast = parse('var x = 42;')
result = _collect_object_literals(ast)
assert len(result) == 0
- def test_multiple_objects(self):
+ def test_multiple_objects(self) -> None:
ast = parse('var a = {x: 1}; var b = {y: 2};')
result = _collect_object_literals(ast)
assert result[('a', 'x')] == 1
@@ -359,23 +361,23 @@ def test_multiple_objects(self):
class TestResolveArgValue:
"""Tests for the _resolve_arg_value helper."""
- def test_numeric_literal(self):
+ def test_numeric_literal(self) -> None:
arg = parse_expr('42')
assert _resolve_arg_value(arg, {}) == 42
- def test_string_hex_literal(self):
+ def test_string_hex_literal(self) -> None:
arg = parse_expr('"0x1a"')
assert _resolve_arg_value(arg, {}) == 0x1A
- def test_string_decimal_literal(self):
+ def test_string_decimal_literal(self) -> None:
arg = parse_expr('"10"')
assert _resolve_arg_value(arg, {}) == 10
- def test_string_non_numeric_returns_none(self):
+ def test_string_non_numeric_returns_none(self) -> None:
arg = parse_expr('"hello"')
assert _resolve_arg_value(arg, {}) is None
- def test_member_expression_numeric(self):
+ def test_member_expression_numeric(self) -> None:
obj_literals = {('obj', 'x'): 0x42}
arg = {
'type': 'MemberExpression',
@@ -385,7 +387,7 @@ def test_member_expression_numeric(self):
}
assert _resolve_arg_value(arg, obj_literals) == 0x42
- def test_member_expression_string_numeric(self):
+ def test_member_expression_string_numeric(self) -> None:
obj_literals = {('obj', 'x'): '0x10'}
arg = {
'type': 'MemberExpression',
@@ -395,7 +397,7 @@ def test_member_expression_string_numeric(self):
}
assert _resolve_arg_value(arg, obj_literals) == 0x10
- def test_member_expression_string_non_numeric(self):
+ def test_member_expression_string_non_numeric(self) -> None:
obj_literals = {('obj', 'x'): 'hello'}
arg = {
'type': 'MemberExpression',
@@ -405,7 +407,7 @@ def test_member_expression_string_non_numeric(self):
}
assert _resolve_arg_value(arg, obj_literals) is None
- def test_member_expression_unknown_key(self):
+ def test_member_expression_unknown_key(self) -> None:
obj_literals = {('obj', 'x'): 42}
arg = {
'type': 'MemberExpression',
@@ -415,7 +417,7 @@ def test_member_expression_unknown_key(self):
}
assert _resolve_arg_value(arg, obj_literals) is None
- def test_computed_member_expression_not_resolved(self):
+ def test_computed_member_expression_not_resolved(self) -> None:
obj_literals = {('obj', 'x'): 42}
arg = {
'type': 'MemberExpression',
@@ -425,7 +427,7 @@ def test_computed_member_expression_not_resolved(self):
}
assert _resolve_arg_value(arg, obj_literals) is None
- def test_identifier_returns_none(self):
+ def test_identifier_returns_none(self) -> None:
arg = parse_expr('x')
assert _resolve_arg_value(arg, {}) is None
@@ -438,11 +440,11 @@ def test_identifier_returns_none(self):
class TestResolveStringArg:
"""Tests for the _resolve_string_arg helper."""
- def test_string_literal(self):
+ def test_string_literal(self) -> None:
arg = parse_expr('"hello"')
assert _resolve_string_arg(arg, {}) == 'hello'
- def test_member_expression_string(self):
+ def test_member_expression_string(self) -> None:
obj_literals = {('obj', 'key'): 'secret'}
arg = {
'type': 'MemberExpression',
@@ -452,7 +454,7 @@ def test_member_expression_string(self):
}
assert _resolve_string_arg(arg, obj_literals) == 'secret'
- def test_member_expression_numeric_returns_none(self):
+ def test_member_expression_numeric_returns_none(self) -> None:
obj_literals = {('obj', 'key'): 42}
arg = {
'type': 'MemberExpression',
@@ -462,15 +464,15 @@ def test_member_expression_numeric_returns_none(self):
}
assert _resolve_string_arg(arg, obj_literals) is None
- def test_numeric_literal_returns_none(self):
+ def test_numeric_literal_returns_none(self) -> None:
arg = parse_expr('42')
assert _resolve_string_arg(arg, {}) is None
- def test_identifier_returns_none(self):
+ def test_identifier_returns_none(self) -> None:
arg = parse_expr('x')
assert _resolve_string_arg(arg, {}) is None
- def test_member_expression_unknown_returns_none(self):
+ def test_member_expression_unknown_returns_none(self) -> None:
arg = {
'type': 'MemberExpression',
'computed': False,
@@ -488,7 +490,7 @@ def test_member_expression_unknown_returns_none(self):
class TestVarArrayPattern:
"""Tests for var-based string array with rotation and decoder (Strategy 2b)."""
- def test_var_array_with_rotation_and_decoder(self):
+ def test_var_array_with_rotation_and_decoder(self) -> None:
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux'];
(function(arr, count) {
@@ -514,7 +516,7 @@ def test_var_array_with_rotation_and_decoder(self):
# After rotation 2: ['foo', 'bar', 'baz', 'qux', 'hello', 'world']
assert '"foo"' in code
- def test_var_array_without_rotation(self):
+ def test_var_array_without_rotation(self) -> None:
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
var _0xdec = function(a) {
@@ -530,7 +532,7 @@ def test_var_array_without_rotation(self):
assert '"hello"' in code
assert '"world"' in code
- def test_var_array_with_offset(self):
+ def test_var_array_with_offset(self) -> None:
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
var _0xdec = function(a) {
@@ -546,7 +548,7 @@ def test_var_array_with_offset(self):
assert '"hello"' in code
assert '"world"' in code
- def test_var_array_too_short_ignored(self):
+ def test_var_array_too_short_ignored(self) -> None:
# Arrays with < 3 elements should not match _find_var_string_array
js = """
var _0xarr = ['hello', 'world'];
@@ -571,7 +573,7 @@ def test_var_array_too_short_ignored(self):
class TestObfuscatorIoFullPattern:
"""Tests for the full obfuscator.io string array pattern."""
- def test_basic_obfuscator_io_pattern(self):
+ def test_basic_obfuscator_io_pattern(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -591,7 +593,7 @@ def test_basic_obfuscator_io_pattern(self):
assert '"hello"' in code
assert '"world"' in code
- def test_obfuscator_io_with_offset(self):
+ def test_obfuscator_io_with_offset(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -611,7 +613,7 @@ def test_obfuscator_io_with_offset(self):
assert '"hello"' in code
assert '"world"' in code
- def test_obfuscator_io_multiple_calls(self):
+ def test_obfuscator_io_multiple_calls(self) -> None:
js = """
function _0xArr() {
var a = ['alpha', 'beta', 'gamma', 'delta', 'epsilon'];
@@ -631,7 +633,7 @@ def test_obfuscator_io_multiple_calls(self):
assert '"alpha"' in code
assert '"epsilon"' in code
- def test_obfuscator_io_removes_infrastructure(self):
+ def test_obfuscator_io_removes_infrastructure(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -650,7 +652,7 @@ def test_obfuscator_io_removes_infrastructure(self):
assert '_0xArr' not in code
assert '_0xDec' not in code
- def test_obfuscator_io_with_wrapper_function(self):
+ def test_obfuscator_io_with_wrapper_function(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -671,7 +673,7 @@ def test_obfuscator_io_with_wrapper_function(self):
assert changed is True
assert '"world"' in code
- def test_obfuscator_io_wrapper_with_key_param(self):
+ def test_obfuscator_io_wrapper_with_key_param(self) -> None:
# Wrapper that passes two args to decoder (index + key)
js = """
function _0xArr() {
@@ -695,7 +697,7 @@ def test_obfuscator_io_wrapper_with_key_param(self):
assert '"hello"' in code
assert '"world"' in code
- def test_obfuscator_io_with_decoder_alias(self):
+ def test_obfuscator_io_with_decoder_alias(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -716,7 +718,7 @@ def test_obfuscator_io_with_decoder_alias(self):
assert '"hello"' in code
assert '"world"' in code
- def test_obfuscator_io_with_transitive_alias(self):
+ def test_obfuscator_io_with_transitive_alias(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -745,7 +747,7 @@ def test_obfuscator_io_with_transitive_alias(self):
class TestObfuscatorIoRotation:
"""Tests for the obfuscator.io rotation IIFE pattern."""
- def test_rotation_with_while_loop(self):
+ def test_rotation_with_while_loop(self) -> None:
js = """
function _0xArr() {
var a = ['100', 'hello', 'world', 'foo', 'bar', 'baz'];
@@ -789,7 +791,7 @@ def test_rotation_with_while_loop(self):
class TestWrapperAnalysis:
"""Tests for wrapper function analysis (_analyze_wrapper_expr)."""
- def test_var_function_expression_wrapper(self):
+ def test_var_function_expression_wrapper(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -810,7 +812,7 @@ def test_var_function_expression_wrapper(self):
assert changed is True
assert '"foo"' in code
- def test_arrow_function_wrapper(self):
+ def test_arrow_function_wrapper(self) -> None:
# ArrowFunctionExpression with block body as wrapper
js = """
function _0xArr() {
@@ -841,7 +843,7 @@ def test_arrow_function_wrapper(self):
class TestExtractWrapperOffset:
"""Tests for wrapper offset extraction patterns."""
- def test_wrapper_with_subtraction_offset(self):
+ def test_wrapper_with_subtraction_offset(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -862,7 +864,7 @@ def test_wrapper_with_subtraction_offset(self):
assert changed is True
assert '"hello"' in code
- def test_wrapper_with_second_param_index(self):
+ def test_wrapper_with_second_param_index(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -892,7 +894,7 @@ def test_wrapper_with_second_param_index(self):
class TestObjectLiteralResolution:
"""Tests for resolving member expressions via object literals."""
- def test_decoder_call_with_object_member_arg(self):
+ def test_decoder_call_with_object_member_arg(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -911,7 +913,7 @@ def test_decoder_call_with_object_member_arg(self):
assert changed is True
assert '"hello"' in code
- def test_wrapper_call_with_object_member_arg(self):
+ def test_wrapper_call_with_object_member_arg(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -942,20 +944,20 @@ def test_wrapper_call_with_object_member_arg(self):
class TestEvalNumericModulo:
"""Additional _eval_numeric tests for modulo operator."""
- def test_modulo(self):
+ def test_modulo(self) -> None:
node = parse_expr('10 % 3')
assert _eval_numeric(node) == 1
- def test_modulo_by_zero(self):
+ def test_modulo_by_zero(self) -> None:
node = parse_expr('10 % 0')
assert _eval_numeric(node) is None
- def test_unary_unsupported_operator(self):
+ def test_unary_unsupported_operator(self) -> None:
# The ~ operator is unsupported
node = parse_expr('~5')
assert _eval_numeric(node) is None
- def test_deeply_nested_expression(self):
+ def test_deeply_nested_expression(self) -> None:
node = parse_expr('(1 + 2) * (3 - 1) + 4 / 2')
assert _eval_numeric(node) == 8.0
@@ -968,7 +970,7 @@ def test_deeply_nested_expression(self):
class TestCollectRotationLocals:
"""Tests for _collect_rotation_locals static method."""
- def test_collects_object_from_iife(self):
+ def test_collects_object_from_iife(self) -> None:
ast = parse(
"""
(function(arr, stop) {
@@ -991,7 +993,7 @@ def test_collects_object_from_iife(self):
assert result['J']['S'] == 0xA7
assert result['J']['D'] == 'M8Y&'
- def test_empty_iife_returns_empty(self):
+ def test_empty_iife_returns_empty(self) -> None:
ast = parse(
"""
(function() {
@@ -1012,7 +1014,7 @@ def test_empty_iife_returns_empty(self):
class TestExpressionFromTryBlock:
"""Tests for _expression_from_try_block static method."""
- def test_variable_declaration(self):
+ def test_variable_declaration(self) -> None:
ast = parse('var x = 42;')
stmt = ast['body'][0]
result = StringRevealer._expression_from_try_block(stmt)
@@ -1020,7 +1022,7 @@ def test_variable_declaration(self):
assert result.get('type') == 'Literal'
assert result.get('value') == 42
- def test_assignment_expression(self):
+ def test_assignment_expression(self) -> None:
ast = parse('x = 42;')
stmt = ast['body'][0]
result = StringRevealer._expression_from_try_block(stmt)
@@ -1028,13 +1030,13 @@ def test_assignment_expression(self):
assert result.get('type') == 'Literal'
assert result.get('value') == 42
- def test_non_matching_returns_none(self):
+ def test_non_matching_returns_none(self) -> None:
ast = parse('if (true) {}')
stmt = ast['body'][0]
result = StringRevealer._expression_from_try_block(stmt)
assert result is None
- def test_expression_statement_non_assignment(self):
+ def test_expression_statement_non_assignment(self) -> None:
ast = parse('foo();')
stmt = ast['body'][0]
result = StringRevealer._expression_from_try_block(stmt)
@@ -1049,7 +1051,7 @@ def test_expression_statement_non_assignment(self):
class TestDirectArrayEdgeCases:
"""Additional edge case tests for direct array access replacement."""
- def test_direct_array_in_function_scope(self):
+ def test_direct_array_in_function_scope(self) -> None:
# Direct array strategy only processes the root scope bindings,
# so arrays inside function scopes are not replaced.
js = """
@@ -1061,7 +1063,7 @@ def test_direct_array_in_function_scope(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_direct_array_not_used_as_member(self):
+ def test_direct_array_not_used_as_member(self) -> None:
# Using arr as a standalone identifier (not arr[N]) should not trigger
js = 'var arr = ["hello"]; f(arr);'
code, changed = roundtrip(js, StringRevealer)
@@ -1076,7 +1078,7 @@ def test_direct_array_not_used_as_member(self):
class TestVarPatternWithWrappersAndAliases:
"""Tests for var-based array with wrappers and alias resolution."""
- def test_var_array_with_decoder_alias(self):
+ def test_var_array_with_decoder_alias(self) -> None:
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
var _0xdec = function(a) {
@@ -1091,7 +1093,7 @@ def test_var_array_with_decoder_alias(self):
assert changed is True
assert '"hello"' in code
- def test_var_array_with_wrapper(self):
+ def test_var_array_with_wrapper(self) -> None:
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
var _0xdec = function(a) {
@@ -1117,7 +1119,7 @@ def test_var_array_with_wrapper(self):
class TestHexStringResolution:
"""Tests for hex string resolution in decoder calls."""
- def test_hex_string_arg_to_decoder(self):
+ def test_hex_string_arg_to_decoder(self) -> None:
js = """
function _0xArr() {
var a = ['hello', 'world', 'foo', 'bar', 'baz'];
@@ -1148,7 +1150,7 @@ def test_hex_string_arg_to_decoder(self):
class TestRotationLogicFull:
"""Tests for the full rotation pipeline with parseInt-based expressions."""
- def test_rotation_with_direct_decoder_call_in_parseint(self):
+ def test_rotation_with_direct_decoder_call_in_parseint(self) -> None:
"""Rotation where parseInt calls the decoder directly (via alias in decoder_aliases)."""
js = """
function _0xArr() {
@@ -1180,7 +1182,7 @@ def test_rotation_with_direct_decoder_call_in_parseint(self):
# '200' is already at position 0, so parseInt('200') == 200 == stop_value
assert '"200"' in code
- def test_rotation_with_binary_expression(self):
+ def test_rotation_with_binary_expression(self) -> None:
"""Rotation with binary expression: parseInt(dec(0)) + parseInt(dec(1))."""
js = """
function _0xArr() {
@@ -1212,7 +1214,7 @@ def test_rotation_with_binary_expression(self):
# parseInt('100') + parseInt('50') = 150 = stop, no rotation needed
assert '"100"' in code
- def test_rotation_with_subtraction_expression(self):
+ def test_rotation_with_subtraction_expression(self) -> None:
"""Rotation with subtraction: parseInt(dec(0)) - parseInt(dec(1))."""
js = """
function _0xArr() {
@@ -1243,7 +1245,7 @@ def test_rotation_with_subtraction_expression(self):
assert changed is True
assert '"300"' in code
- def test_rotation_with_multiplication_expression(self):
+ def test_rotation_with_multiplication_expression(self) -> None:
"""Rotation with multiply: parseInt(dec(0)) * parseInt(dec(1))."""
js = """
function _0xArr() {
@@ -1273,7 +1275,7 @@ def test_rotation_with_multiplication_expression(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is True
- def test_rotation_with_wrapper_in_parseint(self):
+ def test_rotation_with_wrapper_in_parseint(self) -> None:
"""Rotation where parseInt calls a wrapper function."""
js = """
function _0xArr() {
@@ -1307,7 +1309,7 @@ def test_rotation_with_wrapper_in_parseint(self):
assert changed is True
assert '"500"' in code
- def test_rotation_needs_one_shift(self):
+ def test_rotation_needs_one_shift(self) -> None:
"""Rotation that needs exactly one shift before parseInt matches.
Uses a wrapper in the rotation expression so _parse_parseInt_call can match.
@@ -1346,7 +1348,7 @@ def test_rotation_needs_one_shift(self):
# dec(0) returns '42'
assert '"42"' in code
- def test_rotation_with_negate_expression(self):
+ def test_rotation_with_negate_expression(self) -> None:
"""Rotation with negation: -parseInt(dec(0)) + literal."""
js = """
function _0xArr() {
@@ -1378,7 +1380,7 @@ def test_rotation_with_negate_expression(self):
# -parseInt('-100') + 200 = -(-100) + 200 = 100 + 200 = 300 = stop_value
assert '"-100"' in code
- def test_rotation_with_literal_node_in_try(self):
+ def test_rotation_with_literal_node_in_try(self) -> None:
"""Rotation expression that is just a literal value (no parseInt call)."""
js = """
function _0xArr() {
@@ -1415,7 +1417,7 @@ def test_rotation_with_literal_node_in_try(self):
class TestRotationArgResolution:
"""Tests for _resolve_rotation_arg with various argument types."""
- def test_rotation_with_member_expression_arg(self):
+ def test_rotation_with_member_expression_arg(self) -> None:
"""Rotation IIFE that has local objects referenced in parseInt args."""
js = """
function _0xArr() {
@@ -1447,7 +1449,7 @@ def test_rotation_with_member_expression_arg(self):
assert changed is True
assert '"300"' in code
- def test_rotation_with_string_hex_arg(self):
+ def test_rotation_with_string_hex_arg(self) -> None:
"""Rotation with hex string literal as argument to decoder."""
js = """
function _0xArr() {
@@ -1487,7 +1489,7 @@ def test_rotation_with_string_hex_arg(self):
class TestSequenceExpressionRotation:
"""Tests for rotation inside a SequenceExpression."""
- def test_rotation_in_sequence_expression_obfuscatorio(self):
+ def test_rotation_in_sequence_expression_obfuscatorio(self) -> None:
"""Rotation IIFE as part of a SequenceExpression (obfuscator.io pattern)."""
js = """
function _0xArr() {
@@ -1518,7 +1520,7 @@ def test_rotation_in_sequence_expression_obfuscatorio(self):
assert changed is True
assert '"777"' in code
- def test_var_rotation_in_sequence_expression(self):
+ def test_var_rotation_in_sequence_expression(self) -> None:
"""Var-based rotation IIFE inside a SequenceExpression."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux'];
@@ -1541,7 +1543,7 @@ def test_var_rotation_in_sequence_expression(self):
class TestWrapperAnalysisEdgeCases:
"""Tests for _analyze_wrapper_expr edge cases."""
- def test_wrapper_with_non_block_body_ignored(self):
+ def test_wrapper_with_non_block_body_ignored(self) -> None:
"""Function expression with expression body (not BlockStatement) is not a wrapper."""
js = """
function _0xArr() {
@@ -1560,7 +1562,7 @@ def test_wrapper_with_non_block_body_ignored(self):
assert changed is True
assert '"hello"' in code
- def test_wrapper_with_multiple_statements_not_wrapper(self):
+ def test_wrapper_with_multiple_statements_not_wrapper(self) -> None:
"""Function with more than one statement is not recognized as a wrapper."""
js = """
function _0xArr() {
@@ -1586,7 +1588,7 @@ def test_wrapper_with_multiple_statements_not_wrapper(self):
# _0xNotWrap(1) should NOT be replaced since it's not a valid wrapper
assert '_0xNotWrap' in code
- def test_wrapper_non_return_statement_not_wrapper(self):
+ def test_wrapper_non_return_statement_not_wrapper(self) -> None:
"""Function with single non-return statement is not a wrapper."""
js = """
function _0xArr() {
@@ -1610,7 +1612,7 @@ def test_wrapper_non_return_statement_not_wrapper(self):
assert '"hello"' in code
assert '_0xNotWrap' in code
- def test_wrapper_return_not_call_not_wrapper(self):
+ def test_wrapper_return_not_call_not_wrapper(self) -> None:
"""Wrapper that returns a non-call expression is not a wrapper."""
js = """
function _0xArr() {
@@ -1634,7 +1636,7 @@ def test_wrapper_return_not_call_not_wrapper(self):
assert '"hello"' in code
assert '_0xNotWrap' in code
- def test_wrapper_calls_wrong_decoder_not_wrapper(self):
+ def test_wrapper_calls_wrong_decoder_not_wrapper(self) -> None:
"""Wrapper that calls a different function is not recognized as decoder wrapper."""
js = """
function _0xArr() {
@@ -1659,7 +1661,7 @@ def test_wrapper_calls_wrong_decoder_not_wrapper(self):
assert '"hello"' in code
assert '_0xNotWrap' in code
- def test_wrapper_no_call_args_not_wrapper(self):
+ def test_wrapper_no_call_args_not_wrapper(self) -> None:
"""Wrapper with no arguments to decoder call is not a wrapper."""
js = """
function _0xArr() {
@@ -1682,7 +1684,7 @@ def test_wrapper_no_call_args_not_wrapper(self):
assert changed is True
assert '"hello"' in code
- def test_wrapper_with_non_identifier_first_arg(self):
+ def test_wrapper_with_non_identifier_first_arg(self) -> None:
"""Wrapper where first arg to decoder is not a param reference."""
js = """
function _0xArr() {
@@ -1706,7 +1708,7 @@ def test_wrapper_with_non_identifier_first_arg(self):
assert '"hello"' in code
assert '_0xNotWrap' in code
- def test_extract_wrapper_offset_non_plus_minus_operator(self):
+ def test_extract_wrapper_offset_non_plus_minus_operator(self) -> None:
"""Wrapper arg expression with unsupported operator (e.g. *)."""
js = """
function _0xArr() {
@@ -1730,7 +1732,7 @@ def test_extract_wrapper_offset_non_plus_minus_operator(self):
assert '"hello"' in code
assert '_0xNotWrap' in code
- def test_extract_wrapper_offset_non_numeric_right(self):
+ def test_extract_wrapper_offset_non_numeric_right(self) -> None:
"""Wrapper arg expression p + x where x is not numeric."""
js = """
function _0xArr() {
@@ -1755,7 +1757,7 @@ def test_extract_wrapper_offset_non_numeric_right(self):
# _0xNotWrap should remain since p + q doesn't have a numeric right side
assert '_0xNotWrap' in code
- def test_extract_wrapper_offset_left_not_param(self):
+ def test_extract_wrapper_offset_left_not_param(self) -> None:
"""Wrapper arg expression where left side of binary is not a param."""
js = """
function _0xArr() {
@@ -1789,7 +1791,7 @@ def test_extract_wrapper_offset_left_not_param(self):
class TestReplacementEdgeCases:
"""Edge cases in _replace_all_wrapper_calls and _replace_direct_decoder_calls."""
- def test_wrapper_call_with_insufficient_args(self):
+ def test_wrapper_call_with_insufficient_args(self) -> None:
"""Wrapper call with fewer args than expected param_index."""
js = """
function _0xArr() {
@@ -1814,7 +1816,7 @@ def test_wrapper_call_with_insufficient_args(self):
# _0xWrap() called with no args should remain unreplaced
assert '_0xWrap()' in code
- def test_decoder_call_with_no_args(self):
+ def test_decoder_call_with_no_args(self) -> None:
"""Direct decoder call with no arguments is not replaced."""
js = """
function _0xArr() {
@@ -1834,7 +1836,7 @@ def test_decoder_call_with_no_args(self):
assert changed is True
assert '"hello"' in code
- def test_decoder_call_with_unresolvable_arg(self):
+ def test_decoder_call_with_unresolvable_arg(self) -> None:
"""Decoder call with variable (not literal) argument is not replaced."""
js = """
function _0xArr() {
@@ -1855,7 +1857,7 @@ def test_decoder_call_with_unresolvable_arg(self):
assert changed is True
assert '"hello"' in code
- def test_decoder_call_with_out_of_bounds_index(self):
+ def test_decoder_call_with_out_of_bounds_index(self) -> None:
"""Decoder call with an index beyond the array doesn't crash."""
js = """
function _0xArr() {
@@ -1877,7 +1879,7 @@ def test_decoder_call_with_out_of_bounds_index(self):
# Out of bounds call should remain
assert '999' in code
- def test_decoder_call_with_string_key_second_arg(self):
+ def test_decoder_call_with_string_key_second_arg(self) -> None:
"""Decoder direct call with a second string argument (key)."""
js = """
function _0xArr() {
@@ -1896,7 +1898,7 @@ def test_decoder_call_with_string_key_second_arg(self):
assert changed is True
assert '"hello"' in code
- def test_wrapper_call_with_object_member_key_arg(self):
+ def test_wrapper_call_with_object_member_key_arg(self) -> None:
"""Wrapper call where the key param is resolved via object literal."""
js = """
function _0xArr() {
@@ -1928,7 +1930,7 @@ def test_wrapper_call_with_object_member_key_arg(self):
class TestVarPatternEdgeCases:
"""Edge cases for _find_var_string_array, _find_simple_rotation, _find_var_decoder."""
- def test_var_array_not_in_first_three_statements(self):
+ def test_var_array_not_in_first_three_statements(self) -> None:
"""Var string array declared after the first 3 statements is not found."""
js = """
var a = 1;
@@ -1942,7 +1944,7 @@ def test_var_array_not_in_first_three_statements(self):
# Array is at index 3, beyond the first 3 statements checked
assert changed is False
- def test_var_array_with_non_string_elements(self):
+ def test_var_array_with_non_string_elements(self) -> None:
"""Var array with mixed types is not recognized as string array."""
js = """
var _0xarr = ['hello', 42, 'foo', 'bar'];
@@ -1952,7 +1954,7 @@ def test_var_array_with_non_string_elements(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_var_array_with_non_identifier_declaration(self):
+ def test_var_array_with_non_identifier_declaration(self) -> None:
"""Var declaration with destructuring pattern is not matched."""
js = """
var [a, b] = ['hello', 'world', 'foo', 'bar'];
@@ -1962,7 +1964,7 @@ def test_var_array_with_non_identifier_declaration(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_find_var_decoder_function_expression(self):
+ def test_find_var_decoder_function_expression(self) -> None:
"""Var decoder as function expression referencing array name."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -1977,7 +1979,7 @@ def test_find_var_decoder_function_expression(self):
assert changed is True
assert '"hello"' in code
- def test_var_decoder_with_no_matching_array_ref(self):
+ def test_var_decoder_with_no_matching_array_ref(self) -> None:
"""Decoder function that doesn't reference the array name is not found."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -1991,7 +1993,7 @@ def test_var_decoder_with_no_matching_array_ref(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_var_decoder_not_function_expression(self):
+ def test_var_decoder_not_function_expression(self) -> None:
"""Var declaration that is not a function expression is not decoder."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -2002,7 +2004,7 @@ def test_var_decoder_not_function_expression(self):
# Direct array strategy should handle this
assert isinstance(code, str)
- def test_simple_rotation_with_for_statement(self):
+ def test_simple_rotation_with_for_statement(self) -> None:
"""Simple rotation pattern matching checks for push/shift in source."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar', 'baz', 'qux'];
@@ -2013,7 +2015,7 @@ def test_simple_rotation_with_for_statement(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is True
- def test_simple_rotation_no_push_shift_not_rotation(self):
+ def test_simple_rotation_no_push_shift_not_rotation(self) -> None:
"""IIFE without push/shift in source is not recognized as rotation."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -2025,7 +2027,7 @@ def test_simple_rotation_no_push_shift_not_rotation(self):
assert changed is True
assert '"hello"' in code # No rotation applied
- def test_simple_rotation_wrong_array_name(self):
+ def test_simple_rotation_wrong_array_name(self) -> None:
"""Rotation IIFE that references different array name is not matched."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -2037,7 +2039,7 @@ def test_simple_rotation_wrong_array_name(self):
assert changed is True
assert '"hello"' in code # No rotation since IIFE references _0xother
- def test_simple_rotation_non_numeric_count(self):
+ def test_simple_rotation_non_numeric_count(self) -> None:
"""Rotation IIFE with non-numeric count is not matched."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -2058,13 +2060,13 @@ def test_simple_rotation_non_numeric_count(self):
class TestDirectArrayAccessEdgeCases:
"""Edge cases for _try_replace_array_access and _process_direct_arrays_in_scope."""
- def test_direct_array_non_computed_member(self):
+ def test_direct_array_non_computed_member(self) -> None:
"""arr.length style access (non-computed) is not replaced."""
js = 'var arr = ["hello", "world"]; f(arr.length);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_direct_array_used_in_child_scope(self):
+ def test_direct_array_used_in_child_scope(self) -> None:
"""Direct array used inside a function (child scope)."""
js = """
var arr = ["hello", "world"];
@@ -2077,7 +2079,7 @@ def test_direct_array_used_in_child_scope(self):
assert changed is True
assert '"hello"' in code
- def test_direct_array_in_child_scope_no_binding(self):
+ def test_direct_array_in_child_scope_no_binding(self) -> None:
"""Child scope that doesn't reference the array leaves it alone."""
js = """
var arr = ["hello", "world"];
@@ -2089,7 +2091,7 @@ def test_direct_array_in_child_scope_no_binding(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_replace_node_in_ast_index_path(self):
+ def test_replace_node_in_ast_index_path(self) -> None:
"""Verify replacement works when target is in an array (index != None)."""
js = 'var arr = ["a", "b"]; f(arr[0], arr[1]);'
code, changed = roundtrip(js, StringRevealer)
@@ -2106,14 +2108,14 @@ def test_replace_node_in_ast_index_path(self):
class TestFindArrayExpressionInStatement:
"""Tests for _find_array_expression_in_statement."""
- def test_array_in_variable_declaration(self):
+ def test_array_in_variable_declaration(self) -> None:
ast = parse("var x = [1, 2, 3];")
stmt = ast['body'][0]
result = StringRevealer._find_array_expression_in_statement(stmt)
assert result is not None
assert result['type'] == 'ArrayExpression'
- def test_array_in_assignment_expression(self):
+ def test_array_in_assignment_expression(self) -> None:
"""Array in ExpressionStatement with AssignmentExpression."""
ast = parse("x = [1, 2, 3];")
stmt = ast['body'][0]
@@ -2121,28 +2123,28 @@ def test_array_in_assignment_expression(self):
assert result is not None
assert result['type'] == 'ArrayExpression'
- def test_assignment_non_array_rhs(self):
+ def test_assignment_non_array_rhs(self) -> None:
"""Assignment with non-array right side returns None."""
ast = parse("x = 42;")
stmt = ast['body'][0]
result = StringRevealer._find_array_expression_in_statement(stmt)
assert result is None
- def test_non_declaration_non_assignment(self):
+ def test_non_declaration_non_assignment(self) -> None:
"""Statement that is neither declaration nor assignment returns None."""
ast = parse("if (true) {}")
stmt = ast['body'][0]
result = StringRevealer._find_array_expression_in_statement(stmt)
assert result is None
- def test_variable_declaration_non_array_init(self):
+ def test_variable_declaration_non_array_init(self) -> None:
"""Variable declaration with non-array init returns None."""
ast = parse("var x = 42;")
stmt = ast['body'][0]
result = StringRevealer._find_array_expression_in_statement(stmt)
assert result is None
- def test_expression_statement_non_assignment(self):
+ def test_expression_statement_non_assignment(self) -> None:
"""ExpressionStatement that is not an assignment returns None."""
ast = parse("foo();")
stmt = ast['body'][0]
@@ -2158,7 +2160,7 @@ def test_expression_statement_non_assignment(self):
class TestExtractArrayFromStatement:
"""Tests for _extract_array_from_statement ExpressionStatement path."""
- def test_array_from_assignment_expression_statement(self):
+ def test_array_from_assignment_expression_statement(self) -> None:
"""String array in an assignment expression (not var declaration)."""
js = """
function _0xArr() {
@@ -2190,11 +2192,11 @@ def test_array_from_assignment_expression_statement(self):
class TestEvalNumericBinaryEdge:
"""Test _eval_numeric with binary expressions producing None children."""
- def test_binary_with_non_numeric_left(self):
+ def test_binary_with_non_numeric_left(self) -> None:
node = parse_expr('"abc" + 1')
assert _eval_numeric(node) is None
- def test_binary_with_non_numeric_right(self):
+ def test_binary_with_non_numeric_right(self) -> None:
node = parse_expr('1 + "abc"')
assert _eval_numeric(node) is None
@@ -2207,7 +2209,7 @@ def test_binary_with_non_numeric_right(self):
class TestCollectObjectLiteralsEdgeCases:
"""Edge cases for _collect_object_literals."""
- def test_property_with_non_literal_key(self):
+ def test_property_with_non_literal_key(self) -> None:
"""Object with computed key -- esprima parses [x] as Identifier key with computed flag."""
# In esprima, {[x]: 42} still produces a Property with key as Identifier
# but the property is marked computed. _collect_object_literals checks
@@ -2218,14 +2220,14 @@ def test_property_with_non_literal_key(self):
# Numeric key is not an identifier or string literal, so it should be skipped
assert ('obj', 0) not in result
- def test_property_with_no_key_or_value(self):
+ def test_property_with_no_key_or_value(self) -> None:
"""Shorthand property patterns."""
ast = parse('var obj = {a: 42, b: "hi"};')
result = _collect_object_literals(ast)
assert result[('obj', 'a')] == 42
assert result[('obj', 'b')] == 'hi'
- def test_property_non_numeric_non_string_value(self):
+ def test_property_non_numeric_non_string_value(self) -> None:
"""Object property with non-literal value is ignored."""
ast = parse('var obj = {a: x};')
result = _collect_object_literals(ast)
@@ -2240,7 +2242,7 @@ def test_property_non_numeric_non_string_value(self):
class TestDecoderTypeDetection:
"""Tests for base64/RC4 decoder type detection."""
- def test_base64_decoder_detected(self):
+ def test_base64_decoder_detected(self) -> None:
"""Decoder function containing base64 alphabet is detected as Base64."""
js = """
function _0xArr() {
@@ -2268,7 +2270,7 @@ def test_base64_decoder_detected(self):
class TestUpdateAstArray:
"""Tests for _update_ast_array via rotation that modifies the AST."""
- def test_rotation_updates_ast_array(self):
+ def test_rotation_updates_ast_array(self) -> None:
"""Rotation execution should update the AST array elements.
Must use a wrapper in the rotation expression so _parse_parseInt_call matches.
@@ -2316,7 +2318,7 @@ def test_rotation_updates_ast_array(self):
class TestExtractRotationExpression:
"""Tests for _extract_rotation_expression with various loop types."""
- def test_rotation_with_for_loop(self):
+ def test_rotation_with_for_loop(self) -> None:
"""Rotation IIFE with for loop instead of while."""
js = """
function _0xArr() {
@@ -2347,7 +2349,7 @@ def test_rotation_with_for_loop(self):
assert changed is True
assert '"500"' in code
- def test_rotation_with_empty_func_body(self):
+ def test_rotation_with_empty_func_body(self) -> None:
"""Rotation IIFE with empty body produces no rotation expression."""
js = """
function _0xArr() {
@@ -2368,7 +2370,7 @@ def test_rotation_with_empty_func_body(self):
assert changed is True
assert '"hello"' in code
- def test_rotation_with_assignment_in_try(self):
+ def test_rotation_with_assignment_in_try(self) -> None:
"""Rotation where try block uses assignment expression instead of var."""
js = """
function _0xArr() {
@@ -2409,7 +2411,7 @@ def test_rotation_with_assignment_in_try(self):
class TestParseRotationOp:
"""Tests for _parse_rotation_op with various expression types."""
- def test_rotation_op_with_modulo(self):
+ def test_rotation_op_with_modulo(self) -> None:
"""Rotation expression with modulo operator."""
js = """
function _0xArr() {
@@ -2440,7 +2442,7 @@ def test_rotation_op_with_modulo(self):
assert changed is True
# 10 % 3 = 1 = stop_value
- def test_rotation_op_with_division(self):
+ def test_rotation_op_with_division(self) -> None:
"""Rotation expression with division operator."""
js = """
function _0xArr() {
@@ -2470,7 +2472,7 @@ def test_rotation_op_with_division(self):
code, changed = roundtrip(js, StringRevealer)
assert changed is True
- def test_rotation_non_parseint_call_ignored(self):
+ def test_rotation_non_parseint_call_ignored(self) -> None:
"""Rotation expression with non-parseInt call returns None from _parse_rotation_op."""
js = """
function _0xArr() {
@@ -2511,7 +2513,7 @@ def test_rotation_non_parseint_call_ignored(self):
class TestTryExecuteRotationCall:
"""Edge cases for _try_execute_rotation_call."""
- def test_rotation_callee_not_function_expression(self):
+ def test_rotation_callee_not_function_expression(self) -> None:
"""Rotation call whose callee is not a FunctionExpression is skipped."""
js = """
function _0xArr() {
@@ -2531,7 +2533,7 @@ def test_rotation_callee_not_function_expression(self):
assert changed is True
assert '"hello"' in code
- def test_rotation_wrong_arg_count(self):
+ def test_rotation_wrong_arg_count(self) -> None:
"""Rotation IIFE with != 2 arguments is skipped."""
js = """
function _0xArr() {
@@ -2552,7 +2554,7 @@ def test_rotation_wrong_arg_count(self):
assert changed is True
assert '"hello"' in code
- def test_rotation_first_arg_not_array_func(self):
+ def test_rotation_first_arg_not_array_func(self) -> None:
"""Rotation IIFE where first arg is not the array function name."""
js = """
function _0xArr() {
@@ -2573,7 +2575,7 @@ def test_rotation_first_arg_not_array_func(self):
assert changed is True
assert '"hello"' in code
- def test_rotation_non_numeric_stop_value(self):
+ def test_rotation_non_numeric_stop_value(self) -> None:
"""Rotation IIFE with non-numeric stop value is skipped."""
js = """
function _0xArr() {
@@ -2610,7 +2612,7 @@ def test_rotation_non_numeric_stop_value(self):
class TestExtractDecoderOffset:
"""Tests for _extract_decoder_offset edge cases."""
- def test_decoder_with_addition_offset(self):
+ def test_decoder_with_addition_offset(self) -> None:
"""Decoder with idx = idx + OFFSET."""
js = """
function _0xArr() {
@@ -2630,7 +2632,7 @@ def test_decoder_with_addition_offset(self):
# offset is +2, so _0xDec(0) -> arr[0+2] = 'foo'
assert '"foo"' in code
- def test_decoder_with_no_offset(self):
+ def test_decoder_with_no_offset(self) -> None:
"""Decoder without any param reassignment has offset 0."""
js = """
function _0xArr() {
@@ -2657,7 +2659,7 @@ def test_decoder_with_no_offset(self):
class TestResolveRotationArgEdgeCases:
"""Edge cases for _resolve_rotation_arg returning string values."""
- def test_rotation_arg_non_hex_string_literal(self):
+ def test_rotation_arg_non_hex_string_literal(self) -> None:
"""Non-hex, non-numeric string literal is returned as-is (for RC4 key)."""
# We test this indirectly through the rotation with wrapper + key param
js = """
@@ -2692,7 +2694,7 @@ def test_rotation_arg_non_hex_string_literal(self):
assert changed is True
assert '"500"' in code
- def test_rotation_arg_member_expression_computed(self):
+ def test_rotation_arg_member_expression_computed(self) -> None:
"""MemberExpression with computed string key in rotation locals."""
js = """
function _0xArr() {
@@ -2733,7 +2735,7 @@ def test_rotation_arg_member_expression_computed(self):
class TestCollectRotationLocalsEdgeCases:
"""Edge cases for _collect_rotation_locals."""
- def test_rotation_locals_with_string_key(self):
+ def test_rotation_locals_with_string_key(self) -> None:
"""Object literal with string keys in rotation IIFE."""
ast = parse(
"""
@@ -2755,7 +2757,7 @@ def test_rotation_locals_with_string_key(self):
assert result['J']['A'] == 0xB9
assert result['J']['B'] == 'key'
- def test_rotation_locals_non_object_var_ignored(self):
+ def test_rotation_locals_non_object_var_ignored(self) -> None:
"""Non-object variable declarations in IIFE are ignored."""
ast = parse(
"""
@@ -2774,7 +2776,7 @@ def test_rotation_locals_non_object_var_ignored(self):
result = StringRevealer._collect_rotation_locals(iife_func)
assert result == {}
- def test_rotation_locals_non_identifier_name_ignored(self):
+ def test_rotation_locals_non_identifier_name_ignored(self) -> None:
"""Var declaration with non-identifier pattern is ignored."""
ast = parse(
"""
@@ -2789,7 +2791,7 @@ def test_rotation_locals_non_identifier_name_ignored(self):
assert 'J' in result
assert result['J']['A'] == 1
- def test_rotation_locals_empty_object(self):
+ def test_rotation_locals_empty_object(self) -> None:
"""Empty object literal in IIFE is not added (no properties)."""
ast = parse(
"""
@@ -2812,7 +2814,7 @@ def test_rotation_locals_empty_object(self):
class TestRc4DecoderCreation:
"""Tests for RC4 decoder type being created via _create_base_decoder."""
- def test_rc4_decoder_detected_and_used(self):
+ def test_rc4_decoder_detected_and_used(self) -> None:
"""Decoder with base64 alphabet AND fromCharCode...^ pattern is detected as RC4."""
js = """
function _0xArr() {
@@ -2847,14 +2849,14 @@ def _make_revealer(self, js):
ast = parse(js)
return StringRevealer(ast), ast
- def test_parse_rotation_op_literal(self):
+ def test_parse_rotation_op_literal(self) -> None:
"""_parse_rotation_op handles a bare numeric literal."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('42')
result = t._parse_rotation_op(node, {}, set())
assert result == {'op': 'literal', 'value': 42}
- def test_parse_rotation_op_negate(self):
+ def test_parse_rotation_op_negate(self) -> None:
"""_parse_rotation_op handles unary negation."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('-42')
@@ -2864,7 +2866,7 @@ def test_parse_rotation_op_negate(self):
assert result['child']['op'] == 'literal'
assert result['child']['value'] == 42
- def test_parse_rotation_op_binary(self):
+ def test_parse_rotation_op_binary(self) -> None:
"""_parse_rotation_op handles binary addition of literals."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('10 + 20')
@@ -2873,7 +2875,7 @@ def test_parse_rotation_op_binary(self):
assert result['op'] == 'binary'
assert result['operator'] == '+'
- def test_parse_rotation_op_call_with_wrapper(self):
+ def test_parse_rotation_op_call_with_wrapper(self) -> None:
"""_parse_rotation_op handles parseInt(wrapper(0))."""
t, ast = self._make_revealer('var x = 1;')
wrapper = WrapperInfo('_0xWrap', param_index=0, wrapper_offset=0, func_node={})
@@ -2884,7 +2886,7 @@ def test_parse_rotation_op_call_with_wrapper(self):
assert result['wrapper_name'] == '_0xWrap'
assert result['args'] == [0]
- def test_parse_rotation_op_call_with_decoder_alias(self):
+ def test_parse_rotation_op_call_with_decoder_alias(self) -> None:
"""_parse_rotation_op handles parseInt(alias(0)) where alias is in decoder_aliases."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('parseInt(_0xAlias(0))')
@@ -2894,20 +2896,20 @@ def test_parse_rotation_op_call_with_decoder_alias(self):
assert result['alias_name'] == '_0xAlias'
assert result['args'] == [0]
- def test_parse_rotation_op_non_dict_returns_none(self):
+ def test_parse_rotation_op_non_dict_returns_none(self) -> None:
"""_parse_rotation_op returns None for non-dict input."""
t, ast = self._make_revealer('var x = 1;')
assert t._parse_rotation_op(None, {}, set()) is None
assert t._parse_rotation_op('string', {}, set()) is None
- def test_parse_rotation_op_unsupported_type_returns_none(self):
+ def test_parse_rotation_op_unsupported_type_returns_none(self) -> None:
"""_parse_rotation_op returns None for unsupported node types."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('x') # Identifier
result = t._parse_rotation_op(node, {}, set())
assert result is None
- def test_parse_rotation_op_negate_with_non_numeric_child(self):
+ def test_parse_rotation_op_negate_with_non_numeric_child(self) -> None:
"""_parse_rotation_op negate with non-parseable child returns None."""
t, ast = self._make_revealer('var x = 1;')
# -x where x is an identifier, not resolvable
@@ -2915,7 +2917,7 @@ def test_parse_rotation_op_negate_with_non_numeric_child(self):
result = t._parse_rotation_op(node, {}, set())
assert result is None
- def test_parse_rotation_op_binary_with_unparseable_child(self):
+ def test_parse_rotation_op_binary_with_unparseable_child(self) -> None:
"""_parse_rotation_op binary with unparseable left or right returns None."""
t, ast = self._make_revealer('var x = 1;')
# x + 1 where x is identifier
@@ -2923,42 +2925,42 @@ def test_parse_rotation_op_binary_with_unparseable_child(self):
result = t._parse_rotation_op(node, {}, set())
assert result is None
- def test_parse_parseInt_call_not_parseint(self):
+ def test_parse_parseInt_call_not_parseint(self) -> None:
"""_parse_parseInt_call returns None when callee is not parseInt."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('Math.floor(1)')
result = t._parse_parseInt_call(node, {}, set())
assert result is None
- def test_parse_parseInt_call_wrong_arg_count(self):
+ def test_parse_parseInt_call_wrong_arg_count(self) -> None:
"""_parse_parseInt_call returns None when parseInt has != 1 arg."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('parseInt(1, 10)')
result = t._parse_parseInt_call(node, {}, set())
assert result is None
- def test_parse_parseInt_call_inner_not_call(self):
+ def test_parse_parseInt_call_inner_not_call(self) -> None:
"""_parse_parseInt_call returns None when inner arg is not a call."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('parseInt(42)')
result = t._parse_parseInt_call(node, {}, set())
assert result is None
- def test_parse_parseInt_call_inner_callee_not_identifier(self):
+ def test_parse_parseInt_call_inner_callee_not_identifier(self) -> None:
"""_parse_parseInt_call returns None when inner callee is not identifier."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('parseInt(a.b(0))')
result = t._parse_parseInt_call(node, {}, set())
assert result is None
- def test_parse_parseInt_call_inner_unknown_function(self):
+ def test_parse_parseInt_call_inner_unknown_function(self) -> None:
"""_parse_parseInt_call returns None when inner function is not in wrappers/aliases."""
t, ast = self._make_revealer('var x = 1;')
node = parse_expr('parseInt(unknownFunc(0))')
result = t._parse_parseInt_call(node, {}, set())
assert result is None
- def test_parse_parseInt_call_unresolvable_arg(self):
+ def test_parse_parseInt_call_unresolvable_arg(self) -> None:
"""_parse_parseInt_call returns None when inner arg can't be resolved."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
@@ -2967,35 +2969,35 @@ def test_parse_parseInt_call_unresolvable_arg(self):
result = t._parse_parseInt_call(node, {'_0xW': wrapper}, set())
assert result is None
- def test_resolve_rotation_arg_numeric(self):
+ def test_resolve_rotation_arg_numeric(self) -> None:
"""_resolve_rotation_arg resolves numeric literal."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
node = parse_expr('42')
assert t._resolve_rotation_arg(node) == 42
- def test_resolve_rotation_arg_string_hex(self):
+ def test_resolve_rotation_arg_string_hex(self) -> None:
"""_resolve_rotation_arg resolves hex string literal."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
node = parse_expr('"0x1b"')
assert t._resolve_rotation_arg(node) == 0x1B
- def test_resolve_rotation_arg_string_decimal(self):
+ def test_resolve_rotation_arg_string_decimal(self) -> None:
"""_resolve_rotation_arg resolves decimal string literal."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
node = parse_expr('"42"')
assert t._resolve_rotation_arg(node) == 42
- def test_resolve_rotation_arg_string_non_numeric(self):
+ def test_resolve_rotation_arg_string_non_numeric(self) -> None:
"""_resolve_rotation_arg resolves non-numeric string as-is."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
node = parse_expr('"myKey"')
assert t._resolve_rotation_arg(node) == 'myKey'
- def test_resolve_rotation_arg_member_identifier_prop(self):
+ def test_resolve_rotation_arg_member_identifier_prop(self) -> None:
"""_resolve_rotation_arg resolves J.A from rotation locals."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {'J': {'A': 42}}
@@ -3007,7 +3009,7 @@ def test_resolve_rotation_arg_member_identifier_prop(self):
}
assert t._resolve_rotation_arg(node) == 42
- def test_resolve_rotation_arg_member_string_prop(self):
+ def test_resolve_rotation_arg_member_string_prop(self) -> None:
"""_resolve_rotation_arg resolves J['A'] from rotation locals."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {'J': {'A': 99}}
@@ -3019,7 +3021,7 @@ def test_resolve_rotation_arg_member_string_prop(self):
}
assert t._resolve_rotation_arg(node) == 99
- def test_resolve_rotation_arg_member_unknown_object(self):
+ def test_resolve_rotation_arg_member_unknown_object(self) -> None:
"""_resolve_rotation_arg returns None for unknown object in member."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
@@ -3031,27 +3033,27 @@ def test_resolve_rotation_arg_member_unknown_object(self):
}
assert t._resolve_rotation_arg(node) is None
- def test_resolve_rotation_arg_identifier_returns_none(self):
+ def test_resolve_rotation_arg_identifier_returns_none(self) -> None:
"""_resolve_rotation_arg returns None for bare identifier."""
t, ast = self._make_revealer('var x = 1;')
t._rotation_locals = {}
node = parse_expr('x')
assert t._resolve_rotation_arg(node) is None
- def test_apply_rotation_op_literal(self):
+ def test_apply_rotation_op_literal(self) -> None:
"""_apply_rotation_op evaluates a literal node."""
t, ast = self._make_revealer('var x = 1;')
result = t._apply_rotation_op({'op': 'literal', 'value': 42}, {}, None)
assert result == 42
- def test_apply_rotation_op_negate(self):
+ def test_apply_rotation_op_negate(self) -> None:
"""_apply_rotation_op evaluates a negate node."""
t, ast = self._make_revealer('var x = 1;')
op = {'op': 'negate', 'child': {'op': 'literal', 'value': 10}}
result = t._apply_rotation_op(op, {}, None)
assert result == -10
- def test_apply_rotation_op_binary(self):
+ def test_apply_rotation_op_binary(self) -> None:
"""_apply_rotation_op evaluates a binary node."""
t, ast = self._make_revealer('var x = 1;')
op = {
@@ -3063,9 +3065,8 @@ def test_apply_rotation_op_binary(self):
result = t._apply_rotation_op(op, {}, None)
assert result == 30
- def test_apply_rotation_op_call_wrapper(self):
+ def test_apply_rotation_op_call_wrapper(self) -> None:
"""_apply_rotation_op evaluates a wrapper call op."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = self._make_revealer('var x = 1;')
decoder = BasicStringDecoder(['100', 'hello', 'world'], 0)
@@ -3074,9 +3075,8 @@ def test_apply_rotation_op_call_wrapper(self):
result = t._apply_rotation_op(op, {'w': wrapper}, decoder)
assert result == 100
- def test_apply_rotation_op_direct_decoder_call(self):
+ def test_apply_rotation_op_direct_decoder_call(self) -> None:
"""_apply_rotation_op evaluates a direct_decoder_call op."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = self._make_revealer('var x = 1;')
decoder = BasicStringDecoder(['200', 'hello'], 0)
@@ -3085,9 +3085,8 @@ def test_apply_rotation_op_direct_decoder_call(self):
result = t._apply_rotation_op(op, {}, decoder, alias_decoder_map=alias_map)
assert result == 200
- def test_apply_rotation_op_direct_decoder_call_with_key(self):
+ def test_apply_rotation_op_direct_decoder_call_with_key(self) -> None:
"""_apply_rotation_op direct_decoder_call with key arg."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = self._make_revealer('var x = 1;')
decoder = BasicStringDecoder(['300', 'hello'], 0)
@@ -3096,13 +3095,13 @@ def test_apply_rotation_op_direct_decoder_call_with_key(self):
# BasicStringDecoder ignores the key, just returns by index
assert result == 300
- def test_apply_rotation_op_unknown_op_raises(self):
+ def test_apply_rotation_op_unknown_op_raises(self) -> None:
"""_apply_rotation_op raises for unknown op."""
t, ast = self._make_revealer('var x = 1;')
with pytest.raises(ValueError, match='Unknown op'):
t._apply_rotation_op({'op': 'unknown_op'}, {}, None)
- def test_apply_rotation_op_call_invalid_wrapper_args_raises(self):
+ def test_apply_rotation_op_call_invalid_wrapper_args_raises(self) -> None:
"""_apply_rotation_op raises when wrapper args are invalid."""
t, ast = self._make_revealer('var x = 1;')
wrapper = WrapperInfo('w', param_index=5, wrapper_offset=0, func_node={})
@@ -3110,16 +3109,15 @@ def test_apply_rotation_op_call_invalid_wrapper_args_raises(self):
with pytest.raises(ValueError, match='Invalid wrapper args'):
t._apply_rotation_op(op, {'w': wrapper}, None)
- def test_apply_rotation_op_direct_decoder_no_args_raises(self):
+ def test_apply_rotation_op_direct_decoder_no_args_raises(self) -> None:
"""_apply_rotation_op raises when direct_decoder_call has no args."""
t, ast = self._make_revealer('var x = 1;')
op = {'op': 'direct_decoder_call', 'alias_name': 'x', 'args': []}
with pytest.raises(ValueError, match='No args'):
t._apply_rotation_op(op, {}, None)
- def test_execute_rotation_basic(self):
+ def test_execute_rotation_basic(self) -> None:
"""_execute_rotation rotates array until expression matches stop."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = self._make_revealer('var x = 1;')
string_array = ['hello', '42', 'world', 'foo', 'bar']
@@ -3131,9 +3129,8 @@ def test_execute_rotation_basic(self):
assert result is True
assert string_array[0] == '42'
- def test_execute_rotation_clears_decoder_cache(self):
+ def test_execute_rotation_clears_decoder_cache(self) -> None:
"""_execute_rotation clears decoder caches on each rotation."""
- from pyjsclear.utils.string_decoders import Base64StringDecoder
t, ast = self._make_revealer('var x = 1;')
# Use Base64StringDecoder which has a _cache attribute
@@ -3147,7 +3144,6 @@ def test_execute_rotation_clears_decoder_cache(self):
# BasicStringDecoder returns string directly, Base64StringDecoder does base64_transform
# But since Base64 won't give us clean ints, use BasicStringDecoder for the actual test
# and just verify that _cache.clear() is called on Base64StringDecoder
- from pyjsclear.utils.string_decoders import BasicStringDecoder
primary = BasicStringDecoder(string_array, 0)
op = {'op': 'call', 'wrapper_name': 'w', 'args': [0]}
@@ -3164,11 +3160,8 @@ def test_execute_rotation_clears_decoder_cache(self):
# Verify the Base64 decoder's cache was cleared during rotation
assert len(decoder._cache) == 0
- def test_execute_rotation_with_alias_decoder_map(self):
+ def test_execute_rotation_with_alias_decoder_map(self) -> None:
"""_execute_rotation uses alias_decoder_map for clearing caches."""
- from pyjsclear.utils.string_decoders import Base64StringDecoder
- from pyjsclear.utils.string_decoders import BasicStringDecoder
-
t, ast = self._make_revealer('var x = 1;')
string_array = ['hello', '42', 'world', 'foo', 'bar']
primary = BasicStringDecoder(string_array, 0)
@@ -3192,7 +3185,7 @@ def test_execute_rotation_with_alias_decoder_map(self):
class TestSequenceExpressionRotationRemoval:
"""Tests for rotation inside SequenceExpression being properly removed."""
- def test_sequence_rotation_removal_with_wrapper(self):
+ def test_sequence_rotation_removal_with_wrapper(self) -> None:
"""Rotation in SequenceExpression is removed while keeping other expressions.
This triggers lines 287-300 by having the rotation succeed inside a sequence.
@@ -3241,7 +3234,7 @@ def test_sequence_rotation_removal_with_wrapper(self):
class TestRotationSequenceExpression:
"""Test that rotation can be found inside a SequenceExpression."""
- def test_rotation_in_sequence_with_decoder_alias(self):
+ def test_rotation_in_sequence_with_decoder_alias(self) -> None:
"""Rotation IIFE inside SequenceExpression using decoder alias."""
js = """
function _0xArr() {
@@ -3285,7 +3278,7 @@ def test_rotation_in_sequence_with_decoder_alias(self):
class TestExtractDecoderOffsetDirect:
"""Direct tests for _extract_decoder_offset edge cases."""
- def test_offset_with_non_identifier_left(self):
+ def test_offset_with_non_identifier_left(self) -> None:
"""Assignment where left is not an identifier is ignored."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
# Parse a function with arr[0] = arr[0] - 5 (member expression on left)
@@ -3301,7 +3294,7 @@ def test_offset_with_non_identifier_left(self):
offset = t._extract_decoder_offset(func_node)
assert offset == 0 # Default when no matching pattern found
- def test_offset_non_binary_right_side(self):
+ def test_offset_non_binary_right_side(self) -> None:
"""Assignment where right side is not BinaryExpression is ignored."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
func_ast = parse(
@@ -3316,7 +3309,7 @@ def test_offset_non_binary_right_side(self):
offset = t._extract_decoder_offset(func_node)
assert offset == 0
- def test_offset_binary_left_not_matching_param(self):
+ def test_offset_binary_left_not_matching_param(self) -> None:
"""Assignment where binary left doesn't match the assigned variable."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
func_ast = parse(
@@ -3331,7 +3324,7 @@ def test_offset_binary_left_not_matching_param(self):
offset = t._extract_decoder_offset(func_node)
assert offset == 0
- def test_offset_unsupported_operator(self):
+ def test_offset_unsupported_operator(self) -> None:
"""Assignment with unsupported operator (*) is ignored."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
func_ast = parse(
@@ -3355,28 +3348,28 @@ def test_offset_unsupported_operator(self):
class TestStringArrayFromExpression:
"""Edge cases for _string_array_from_expression."""
- def test_array_with_non_string_element(self):
+ def test_array_with_non_string_element(self) -> None:
"""Array with a numeric element returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
node = parse_expr('[1, 2, 3]')
result = t._string_array_from_expression(node)
assert result is None
- def test_empty_array_returns_none(self):
+ def test_empty_array_returns_none(self) -> None:
"""Empty array returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
node = parse_expr('[]')
result = t._string_array_from_expression(node)
assert result is None
- def test_non_array_returns_none(self):
+ def test_non_array_returns_none(self) -> None:
"""Non-array node returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
node = parse_expr('42')
result = t._string_array_from_expression(node)
assert result is None
- def test_none_returns_none(self):
+ def test_none_returns_none(self) -> None:
"""None returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
assert t._string_array_from_expression(None) is None
@@ -3390,7 +3383,7 @@ def test_none_returns_none(self):
class TestFindStringArrayFunction:
"""Edge cases for _find_string_array_function."""
- def test_function_with_short_body(self):
+ def test_function_with_short_body(self) -> None:
"""Function with only 1 statement in body is skipped."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3404,7 +3397,7 @@ def test_function_with_short_body(self):
name, arr, idx = t._find_string_array_function(body)
assert name is None
- def test_function_without_name(self):
+ def test_function_without_name(self) -> None:
"""Function without a name is skipped."""
# FunctionDeclarations always have names in valid JS, but we test the guard
js = """
@@ -3430,12 +3423,12 @@ def test_function_without_name(self):
class TestEvalNumericUnaryArgNone:
"""Test _eval_numeric when unary argument evaluates to None."""
- def test_negate_identifier_returns_none(self):
+ def test_negate_identifier_returns_none(self) -> None:
"""Negating an identifier that can't be evaluated returns None."""
node = parse_expr('-x')
assert _eval_numeric(node) is None
- def test_positive_identifier_returns_none(self):
+ def test_positive_identifier_returns_none(self) -> None:
"""Positive sign on identifier returns None."""
node = parse_expr('+x')
assert _eval_numeric(node) is None
@@ -3449,7 +3442,7 @@ def test_positive_identifier_returns_none(self):
class TestCollectObjectLiteralsPropertyType:
"""Test _collect_object_literals with non-Property type entries."""
- def test_spread_element_ignored(self):
+ def test_spread_element_ignored(self) -> None:
"""SpreadElement in object properties is skipped (type != 'Property')."""
# We can't easily create this in valid JS that esprima parses,
# but we can test with a normal object to verify Property type check works
@@ -3458,7 +3451,7 @@ def test_spread_element_ignored(self):
assert result[('obj', 'a')] == 1
assert result[('obj', 'b')] == 'two'
- def test_property_with_no_value(self):
+ def test_property_with_no_value(self) -> None:
"""Property with missing key or value is skipped."""
# Manually construct an AST with a malformed property
ast = {
@@ -3494,7 +3487,7 @@ def test_property_with_no_value(self):
class TestProcessDirectArraysInScope:
"""Test _process_direct_arrays_in_scope when binding is not found."""
- def test_array_in_nested_function_scopes(self):
+ def test_array_in_nested_function_scopes(self) -> None:
"""Array accessed in deeply nested function scopes."""
js = """
var arr = ["hello", "world"];
@@ -3509,7 +3502,7 @@ def test_array_in_nested_function_scopes(self):
assert changed is True
assert '"hello"' in code
- def test_array_not_referenced_in_child_scope(self):
+ def test_array_not_referenced_in_child_scope(self) -> None:
"""Child scope that doesn't reference the array binding."""
js = """
var arr = ["hello", "world"];
@@ -3531,7 +3524,7 @@ def test_array_not_referenced_in_child_scope(self):
class TestFindVarStringArrayEdge:
"""Edge cases for _find_var_string_array."""
- def test_non_array_init_skipped(self):
+ def test_non_array_init_skipped(self) -> None:
"""Var declaration with non-array init is skipped."""
js = """
var _0x = 42;
@@ -3543,7 +3536,7 @@ def test_non_array_init_skipped(self):
assert changed is True
assert '"hello"' in code
- def test_short_array_skipped(self):
+ def test_short_array_skipped(self) -> None:
"""Array with < 3 elements is skipped by _find_var_string_array."""
js = """
var _0xarr = ['a', 'b'];
@@ -3562,7 +3555,7 @@ def test_short_array_skipped(self):
class TestFindSimpleRotationEdge:
"""Edge cases for _find_simple_rotation."""
- def test_non_expression_statement_skipped(self):
+ def test_non_expression_statement_skipped(self) -> None:
"""Non-ExpressionStatement in body is skipped when looking for rotation."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -3574,7 +3567,7 @@ def test_non_expression_statement_skipped(self):
assert changed is True
assert '"hello"' in code # No rotation applied
- def test_expression_statement_without_expression(self):
+ def test_expression_statement_without_expression(self) -> None:
"""Expression statement edge case."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -3595,7 +3588,7 @@ def test_expression_statement_without_expression(self):
class TestFindVarDecoderEdge:
"""Edge cases for _find_var_decoder."""
- def test_var_declaration_non_function_init(self):
+ def test_var_declaration_non_function_init(self) -> None:
"""Var declaration where init is not FunctionExpression is skipped."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -3606,7 +3599,7 @@ def test_var_declaration_non_function_init(self):
# Direct array strategy handles arr[0], but decoder pattern is not found
assert isinstance(code, str)
- def test_var_declaration_non_variable(self):
+ def test_var_declaration_non_variable(self) -> None:
"""Non-variable declaration statement is skipped when looking for decoder."""
js = """
var _0xarr = ['hello', 'world', 'foo', 'bar'];
@@ -3627,20 +3620,20 @@ def test_var_declaration_non_variable(self):
class TestTryReplaceArrayAccessEdge:
"""Edge cases for _try_replace_array_access."""
- def test_non_computed_member_not_replaced(self):
+ def test_non_computed_member_not_replaced(self) -> None:
"""arr.foo style access is not replaced."""
js = 'var arr = ["hello", "world"]; console.log(arr.length);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
assert 'arr.length' in code
- def test_member_property_not_numeric(self):
+ def test_member_property_not_numeric(self) -> None:
"""arr[x] where x is an identifier is not replaced."""
js = 'var arr = ["hello", "world"]; console.log(arr[x]);'
code, changed = roundtrip(js, StringRevealer)
assert changed is False
- def test_ref_key_not_object(self):
+ def test_ref_key_not_object(self) -> None:
"""Reference where key is not 'object' is not replaced."""
# This happens when arr is used as a property value, not as the object of a member
js = 'var arr = ["hello", "world"]; var x = {prop: arr};'
@@ -3656,7 +3649,7 @@ def test_ref_key_not_object(self):
class TestUpdateAstArrayEdge:
"""Edge case for _update_ast_array."""
- def test_update_ast_array_with_assignment_init(self):
+ def test_update_ast_array_with_assignment_init(self) -> None:
"""_update_ast_array when first statement is assignment (not var decl)."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3679,7 +3672,7 @@ def test_update_ast_array_with_assignment_init(self):
assert elements[0]['value'] == 'world'
assert elements[1]['value'] == 'hello'
- def test_update_ast_array_empty_body(self):
+ def test_update_ast_array_empty_body(self) -> None:
"""_update_ast_array with empty function body does nothing."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3701,27 +3694,24 @@ def test_update_ast_array_empty_body(self):
class TestDecodeAndParseInt:
"""Tests for _decode_and_parse_int error paths."""
- def test_decode_and_parse_int_returns_nan(self):
+ def test_decode_and_parse_int_returns_nan(self) -> None:
"""_decode_and_parse_int raises when decoded string is not parseable as int."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
decoder = BasicStringDecoder(['hello'], 0)
with pytest.raises(ValueError, match='NaN'):
t._decode_and_parse_int(decoder, 0)
- def test_decode_and_parse_int_none_returned(self):
+ def test_decode_and_parse_int_none_returned(self) -> None:
"""_decode_and_parse_int raises when decoder returns None (out of bounds)."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
decoder = BasicStringDecoder(['hello'], 0)
with pytest.raises(ValueError, match='Decoder returned None'):
t._decode_and_parse_int(decoder, 999)
- def test_decode_and_parse_int_with_key(self):
+ def test_decode_and_parse_int_with_key(self) -> None:
"""_decode_and_parse_int passes key to decoder.get_string."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
decoder = BasicStringDecoder(['42'], 0)
@@ -3733,9 +3723,8 @@ def test_decode_and_parse_int_with_key(self):
class TestExecuteRotationEdge:
"""Edge cases for _execute_rotation."""
- def test_execute_rotation_returns_false_when_no_match(self):
+ def test_execute_rotation_returns_false_when_no_match(self) -> None:
"""_execute_rotation returns False when no match in 100001 iterations."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = TestRotationInternalsDirect._make_revealer(None, 'var x = 1;')
# Array where no rotation can produce parseInt matching stop_value 999999
@@ -3752,7 +3741,7 @@ def test_execute_rotation_returns_false_when_no_match(self):
class TestWrapperReplacementEdge:
"""Edge cases for _replace_all_wrapper_calls."""
- def test_wrapper_call_unresolvable_index_value(self):
+ def test_wrapper_call_unresolvable_index_value(self) -> None:
"""Wrapper call where the index param can't be resolved."""
js = """
function _0xArr() {
@@ -3777,7 +3766,7 @@ def test_wrapper_call_unresolvable_index_value(self):
# _0xWrap(someVar) should not be replaced
assert '_0xWrap(someVar)' in code
- def test_wrapper_key_param_resolution(self):
+ def test_wrapper_key_param_resolution(self) -> None:
"""Wrapper with key_param_index resolves key from call args."""
js = """
function _0xArr() {
@@ -3801,7 +3790,7 @@ def test_wrapper_key_param_resolution(self):
assert '"hello"' in code
assert '"world"' in code
- def test_decoder_returns_non_string(self):
+ def test_decoder_returns_non_string(self) -> None:
"""When decoder returns None (out of bounds), the call is not replaced."""
js = """
function _0xArr() {
@@ -3827,7 +3816,7 @@ def test_decoder_returns_non_string(self):
class TestAnalyzeWrapperExprEdge:
"""Edge cases for _analyze_wrapper_expr."""
- def test_wrapper_with_key_param_from_second_arg(self):
+ def test_wrapper_with_key_param_from_second_arg(self) -> None:
"""Wrapper passing second param as key to decoder."""
js = """
function _0xArr() {
@@ -3853,9 +3842,8 @@ def test_wrapper_with_key_param_from_second_arg(self):
class TestFindAndExecuteRotationEdge:
"""Edge cases for _find_and_execute_rotation."""
- def test_rotation_not_found_returns_none(self):
+ def test_rotation_not_found_returns_none(self) -> None:
"""When no rotation IIFE exists, returns None."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3878,9 +3866,8 @@ def test_rotation_not_found_returns_none(self):
result = t._find_and_execute_rotation(body, '_0xArr', ['hello'], decoder, {}, set())
assert result is None
- def test_rotation_found_in_sequence_expression(self):
+ def test_rotation_found_in_sequence_expression(self) -> None:
"""Rotation IIFE inside a SequenceExpression is found and executed."""
- from pyjsclear.utils.string_decoders import BasicStringDecoder
js = """
function _0xArr() {
@@ -3924,7 +3911,7 @@ def test_rotation_found_in_sequence_expression(self):
class TestExtractRotationExpressionEdge:
"""Edge cases for _extract_rotation_expression."""
- def test_no_loop_in_iife(self):
+ def test_no_loop_in_iife(self) -> None:
"""IIFE without a while/for loop returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3940,7 +3927,7 @@ def test_no_loop_in_iife(self):
result = t._extract_rotation_expression(iife_func)
assert result is None
- def test_loop_without_try_statement(self):
+ def test_loop_without_try_statement(self) -> None:
"""Loop without a TryStatement returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
@@ -3958,7 +3945,7 @@ def test_loop_without_try_statement(self):
result = t._extract_rotation_expression(iife_func)
assert result is None
- def test_try_with_empty_block(self):
+ def test_try_with_empty_block(self) -> None:
"""Try block with empty body returns None."""
t, ast = TestRotationInternalsDirect._make_revealer(
None,
diff --git a/tests/unit/transforms/unreachable_code_test.py b/tests/unit/transforms/unreachable_code_test.py
index 0d99459..042248a 100644
--- a/tests/unit/transforms/unreachable_code_test.py
+++ b/tests/unit/transforms/unreachable_code_test.py
@@ -7,41 +7,41 @@
class TestUnreachableCodeRemover:
"""Tests for removing unreachable statements after terminators."""
- def test_removes_after_return(self):
+ def test_removes_after_return(self) -> None:
code = 'function f() { return 1; console.log("dead"); }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
assert 'dead' not in result
assert 'return 1' in result
- def test_removes_after_throw(self):
+ def test_removes_after_throw(self) -> None:
code = 'function f() { throw new Error(); var x = 1; }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
assert 'var x' not in result
- def test_removes_after_break(self):
+ def test_removes_after_break(self) -> None:
code = 'for(;;) { break; console.log("dead"); }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
assert 'dead' not in result
- def test_removes_after_continue(self):
+ def test_removes_after_continue(self) -> None:
code = 'for(;;) { continue; console.log("dead"); }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
assert 'dead' not in result
- def test_preserves_reachable_code(self):
+ def test_preserves_reachable_code(self) -> None:
code = 'function f() { var x = 1; return x; }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is False
- def test_no_terminator_returns_false(self):
+ def test_no_terminator_returns_false(self) -> None:
result, changed = roundtrip('var x = 1; var y = 2;', UnreachableCodeRemover)
assert changed is False
- def test_removes_multiple_after_return(self):
+ def test_removes_multiple_after_return(self) -> None:
code = 'function f() { return; var a = 1; var b = 2; var c = 3; }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
@@ -49,14 +49,14 @@ def test_removes_multiple_after_return(self):
assert 'var b' not in result
assert 'var c' not in result
- def test_handles_nested_blocks(self):
+ def test_handles_nested_blocks(self) -> None:
code = 'function f() { if (x) { return; var dead = 1; } var live = 2; }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is True
assert 'dead' not in result
assert 'live' in result
- def test_return_at_end_no_change(self):
+ def test_return_at_end_no_change(self) -> None:
code = 'function f() { var x = 1; return x; }'
result, changed = roundtrip(code, UnreachableCodeRemover)
assert changed is False
diff --git a/tests/unit/transforms/unused_vars_test.py b/tests/unit/transforms/unused_vars_test.py
index 89dcce5..f214cde 100644
--- a/tests/unit/transforms/unused_vars_test.py
+++ b/tests/unit/transforms/unused_vars_test.py
@@ -1,7 +1,6 @@
"""Unit tests for UnusedVariableRemover transform."""
-import pytest
-
+from pyjsclear.parser import parse
from pyjsclear.transforms.unused_vars import UnusedVariableRemover
from tests.unit.conftest import normalize
from tests.unit.conftest import roundtrip
@@ -10,121 +9,121 @@
class TestUnusedVariableRemover:
"""Tests for removing unreferenced variables and functions."""
- def test_unreferenced_var_in_function_removed(self):
+ def test_unreferenced_var_in_function_removed(self) -> None:
code = 'function f() { var x = 1; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'var x' not in result
assert 'function f' in result
- def test_referenced_var_in_function_kept(self):
+ def test_referenced_var_in_function_kept(self) -> None:
code = 'function f() { var x = 1; return x; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x = 1' in result
assert 'return x' in result
- def test_global_0x_prefixed_var_removed(self):
+ def test_global_0x_prefixed_var_removed(self) -> None:
code = 'var _0xabc = 1;'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert '_0xabc' not in result
- def test_global_normal_var_kept(self):
+ def test_global_normal_var_kept(self) -> None:
code = 'var foo = 1;'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var foo = 1' in result
- def test_side_effect_call_expression_kept(self):
+ def test_side_effect_call_expression_kept(self) -> None:
code = 'function f() { var x = foo(); }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x = foo()' in result
- def test_side_effect_new_expression_kept(self):
+ def test_side_effect_new_expression_kept(self) -> None:
code = 'function f() { var x = new Foo(); }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x = new Foo()' in result
- def test_side_effect_assignment_expression_kept(self):
+ def test_side_effect_assignment_expression_kept(self) -> None:
code = 'function f() { var x = (a = 1); }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x' in result
- def test_side_effect_update_expression_kept(self):
+ def test_side_effect_update_expression_kept(self) -> None:
code = 'function f() { var x = a++; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x' in result
- def test_global_0x_function_declaration_removed(self):
+ def test_global_0x_function_declaration_removed(self) -> None:
code = 'function _0xabc() {}'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert '_0xabc' not in result
- def test_unreferenced_nested_function_removed(self):
+ def test_unreferenced_nested_function_removed(self) -> None:
code = 'function f() { function g() {} }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'function g' not in result
assert 'function f' in result
- def test_param_kept_even_if_unreferenced(self):
+ def test_param_kept_even_if_unreferenced(self) -> None:
code = 'function f(x) { return 1; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'function f(x)' in result
- def test_multiple_declarators_remove_only_unreferenced(self):
+ def test_multiple_declarators_remove_only_unreferenced(self) -> None:
code = 'function f() { var x = 1, y = 2; return x; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'x' in result
assert 'y' not in result
- def test_multiple_declarators_all_unreferenced(self):
+ def test_multiple_declarators_all_unreferenced(self) -> None:
code = 'function f() { var x = 1, y = 2; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'var' not in result
- def test_no_unused_returns_false(self):
+ def test_no_unused_returns_false(self) -> None:
code = 'function f(x) { return x; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
- def test_var_with_no_init_removed(self):
+ def test_var_with_no_init_removed(self) -> None:
code = 'function f() { var x; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'var x' not in result
- def test_global_normal_function_kept(self):
+ def test_global_normal_function_kept(self) -> None:
code = 'function foo() {}'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'function foo' in result
- def test_rebuild_scope_flag(self):
+ def test_rebuild_scope_flag(self) -> None:
assert UnusedVariableRemover.rebuild_scope is True
- def test_nested_side_effect_in_binary_kept(self):
+ def test_nested_side_effect_in_binary_kept(self) -> None:
code = 'function f() { var x = 1 + foo(); }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x' in result
- def test_pure_init_object_removed(self):
+ def test_pure_init_object_removed(self) -> None:
code = 'function f() { var x = {a: 1}; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
assert 'var x' not in result
- def test_pure_init_array_removed(self):
+ def test_pure_init_array_removed(self) -> None:
code = 'function f() { var x = [1, 2]; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
@@ -134,7 +133,7 @@ def test_pure_init_array_removed(self):
class TestUnusedVariableRemoverEdgeCases:
"""Tests for uncovered edge cases in unused variable removal."""
- def test_side_effect_in_dict_child(self):
+ def test_side_effect_in_dict_child(self) -> None:
"""Line 111: _has_side_effects with dict child having side effect."""
code = 'function f() { var x = -foo(); }'
result, changed = roundtrip(code, UnusedVariableRemover)
@@ -142,7 +141,7 @@ def test_side_effect_in_dict_child(self):
assert not changed
assert 'var x' in result
- def test_has_side_effects_none_child(self):
+ def test_has_side_effects_none_child(self) -> None:
"""Line 105: _has_side_effects with None child should continue."""
# A conditional expression has test, consequent, alternate — where alternate can be None-ish
code = 'function f() { var x = true ? 1 : 2; }'
@@ -151,7 +150,7 @@ def test_has_side_effects_none_child(self):
assert changed is True
assert 'var x' not in result
- def test_has_side_effects_non_dict_node(self):
+ def test_has_side_effects_non_dict_node(self) -> None:
"""Line 95: _has_side_effects with non-dict node returns False."""
# This is tested internally when init is a literal (non-dict-like in some paths)
code = 'function f() { var x = 42; }'
@@ -159,7 +158,7 @@ def test_has_side_effects_non_dict_node(self):
assert changed is True
assert 'var x' not in result
- def test_empty_declarations_in_batch_remove(self):
+ def test_empty_declarations_in_batch_remove(self) -> None:
"""Line 81: decls is empty/None — should return early."""
# Unusual case where VariableDeclaration has empty declarations
# Just test that normal removal works with a straightforward case
@@ -167,14 +166,14 @@ def test_empty_declarations_in_batch_remove(self):
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is True
- def test_all_declarators_referenced_no_change(self):
+ def test_all_declarators_referenced_no_change(self) -> None:
"""Line 84: new_decls length equals decls length (no change)."""
code = 'function f() { var x = 1, y = 2; return x + y; }'
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
assert 'var x' in result
- def test_side_effect_in_array_element(self):
+ def test_side_effect_in_array_element(self) -> None:
"""Lines 107-108: Array child item with side effect detected."""
# TemplateLiteral expressions list can contain call expressions
code = 'function f() { var x = [foo(), 1]; }'
@@ -183,7 +182,7 @@ def test_side_effect_in_array_element(self):
# But the init check is just on the top-level node type
assert isinstance(changed, bool)
- def test_conditional_with_side_effect(self):
+ def test_conditional_with_side_effect(self) -> None:
"""Side effect deep in conditional expression."""
code = 'function f() { var x = true ? foo() : 1; }'
result, changed = roundtrip(code, UnusedVariableRemover)
@@ -192,7 +191,7 @@ def test_conditional_with_side_effect(self):
assert not changed
assert 'var x' in result
- def test_binding_node_not_dict(self):
+ def test_binding_node_not_dict(self) -> None:
"""Line 56: binding.node that is not a dict should be skipped."""
# This is a defensive check. We test it by directly invoking with a scope
# that has a non-dict binding node. Since we can't easily construct that
@@ -201,14 +200,14 @@ def test_binding_node_not_dict(self):
result, changed = roundtrip(code, UnusedVariableRemover)
assert changed is False
- def test_has_side_effects_non_dict(self):
+ def test_has_side_effects_non_dict(self) -> None:
"""Line 95: _has_side_effects with non-dict returns False."""
remover = UnusedVariableRemover({'type': 'Program', 'body': []})
assert remover._has_side_effects(None) is False
assert remover._has_side_effects(42) is False
assert remover._has_side_effects('string') is False
- def test_has_side_effects_none_child(self):
+ def test_has_side_effects_none_child(self) -> None:
"""Line 105: None child in side effect check should be skipped."""
remover = UnusedVariableRemover({'type': 'Program', 'body': []})
# A ConditionalExpression where alternate is None
@@ -220,7 +219,7 @@ def test_has_side_effects_none_child(self):
}
assert remover._has_side_effects(node) is False
- def test_has_side_effects_list_child_with_side_effect(self):
+ def test_has_side_effects_list_child_with_side_effect(self) -> None:
"""Lines 107-108: list child with side effect item returns True."""
remover = UnusedVariableRemover({'type': 'Program', 'body': []})
# SequenceExpression has 'expressions' which is a list
@@ -233,7 +232,7 @@ def test_has_side_effects_list_child_with_side_effect(self):
}
assert remover._has_side_effects(node) is True
- def test_has_side_effects_list_child_no_side_effect(self):
+ def test_has_side_effects_list_child_no_side_effect(self) -> None:
"""Lines 107-108: list child without side effect items returns False."""
remover = UnusedVariableRemover({'type': 'Program', 'body': []})
node = {
@@ -245,7 +244,7 @@ def test_has_side_effects_list_child_no_side_effect(self):
}
assert remover._has_side_effects(node) is False
- def test_template_literal_recurses(self):
+ def test_template_literal_recurses(self) -> None:
"""TemplateLiteral is not in _PURE_TYPES or _SIDE_EFFECT_TYPES, so it recurses."""
code = 'function f() { var x = `hello`; }'
result, changed = roundtrip(code, UnusedVariableRemover)
@@ -253,11 +252,8 @@ def test_template_literal_recurses(self):
assert changed is True
assert 'var x' not in result
- def test_empty_decls_list(self):
+ def test_empty_decls_list(self) -> None:
"""Line 81: VariableDeclaration with empty declarations list."""
- from pyjsclear.parser import parse
- from pyjsclear.traverser import traverse
-
ast = parse('function f() { var x = 1; }')
# Manually clear declarations to trigger the empty check
for node in ast['body']:
diff --git a/tests/unit/transforms/variable_renamer_test.py b/tests/unit/transforms/variable_renamer_test.py
index 7a45747..8e1e993 100644
--- a/tests/unit/transforms/variable_renamer_test.py
+++ b/tests/unit/transforms/variable_renamer_test.py
@@ -12,31 +12,31 @@
class TestBasicRenaming:
"""Tests for basic _0x identifier renaming."""
- def test_single_var_renamed(self):
+ def test_single_var_renamed(self) -> None:
code = 'function f() { var _0x1234 = 1; return _0x1234; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0x1234' not in result
- def test_multiple_vars_renamed(self):
+ def test_multiple_vars_renamed(self) -> None:
code = 'function f() { var _0x1 = 1; var _0x2 = 2; return _0x1 + _0x2; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0x' not in result
- def test_function_name_renamed(self):
+ def test_function_name_renamed(self) -> None:
code = 'function _0xabc() { return 1; } _0xabc();'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0xabc' not in result
- def test_param_renamed(self):
+ def test_param_renamed(self) -> None:
code = 'function f(_0xdef) { return _0xdef; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0xdef' not in result
- def test_const_let_renamed(self):
+ def test_const_let_renamed(self) -> None:
code = 'function f() { const _0x1 = 1; let _0x2 = 2; return _0x1 + _0x2; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
@@ -46,25 +46,25 @@ def test_const_let_renamed(self):
class TestHeuristicNaming:
"""Tests for heuristic-based name inference."""
- def test_require_fs_named(self):
+ def test_require_fs_named(self) -> None:
code = 'function f() { const _0x1 = require("fs"); _0x1.readFileSync("a"); }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const fs = require("fs")' in result
- def test_require_path_named(self):
+ def test_require_path_named(self) -> None:
code = 'function f() { const _0x1 = require("path"); _0x1.join("a"); }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const path = require("path")' in result
- def test_require_child_process_named(self):
+ def test_require_child_process_named(self) -> None:
code = 'function f() { const _0x1 = require("child_process"); _0x1.spawn("a"); }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const cp = require("child_process")' in result
- def test_require_dedupe(self):
+ def test_require_dedupe(self) -> None:
"""Multiple require("fs") in same scope get fs, fs2, fs3."""
code = '''
function f() {
@@ -79,43 +79,43 @@ def test_require_dedupe(self):
assert 'const fs =' in result
assert 'const fs2 =' in result
- def test_array_literal_named(self):
+ def test_array_literal_named(self) -> None:
code = 'function f() { const _0x1 = []; _0x1.push(1); }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const arr = []' in result
- def test_object_literal_named(self):
+ def test_object_literal_named(self) -> None:
code = 'function f() { const _0x1 = {}; _0x1.foo = 1; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const obj = {}' in result or 'const obj =' in result
- def test_buffer_from_named(self):
+ def test_buffer_from_named(self) -> None:
code = 'function f() { const _0x1 = Buffer.from("abc"); return _0x1; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const buf = Buffer.from' in result
- def test_json_parse_named(self):
+ def test_json_parse_named(self) -> None:
code = 'function f(s) { const _0x1 = JSON.parse(s); return _0x1; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const data = JSON.parse' in result
- def test_new_date_named(self):
+ def test_new_date_named(self) -> None:
code = 'function f() { const _0x1 = new Date(); return _0x1; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'const date = new Date()' in result
- def test_loop_counter_named_i(self):
+ def test_loop_counter_named_i(self) -> None:
code = 'function f() { for (var _0x1 = 0; _0x1 < 10; _0x1++) { console.log(_0x1); } }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'var i = 0' in result or 'var i2 = 0' in result or 'let i = 0' in result
- def test_usage_based_fs_naming(self):
+ def test_usage_based_fs_naming(self) -> None:
"""Variable used with .existsSync should be named fs even without require init."""
code = '''
function f(_0xabc) {
@@ -127,7 +127,7 @@ def test_usage_based_fs_naming(self):
assert changed is True
assert 'function f(fs)' in result
- def test_require_with_path_sanitized(self):
+ def test_require_with_path_sanitized(self) -> None:
r"""require(".\lib\Foo.node") should not produce invalid identifiers."""
code = r'function f() { const _0x1 = require(".\\lib\\Foo.node"); return _0x1; }'
result, changed = roundtrip(code, VariableRenamer)
@@ -140,20 +140,20 @@ def test_require_with_path_sanitized(self):
class TestPreservation:
"""Tests for names that should NOT be renamed."""
- def test_non_0x_preserved(self):
+ def test_non_0x_preserved(self) -> None:
code = 'function f() { var foo = 1; return foo; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is False
assert 'foo' in result
- def test_mixed_names(self):
+ def test_mixed_names(self) -> None:
code = 'function f() { var foo = 1; var _0x1234 = 2; return foo + _0x1234; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert 'foo' in result
assert '_0x1234' not in result
- def test_no_0x_returns_false(self):
+ def test_no_0x_returns_false(self) -> None:
code = 'var x = 1;'
result, changed = roundtrip(code, VariableRenamer)
assert changed is False
@@ -162,14 +162,14 @@ def test_no_0x_returns_false(self):
class TestNoConflict:
"""Tests that renaming doesn't create naming conflicts."""
- def test_no_conflict_with_existing_names(self):
+ def test_no_conflict_with_existing_names(self) -> None:
code = 'function f() { var a = 1; var _0x1234 = 2; return a + _0x1234; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0x1234' not in result
assert 'var a = 1' in result
- def test_require_name_conflict_resolved(self):
+ def test_require_name_conflict_resolved(self) -> None:
"""If 'fs' already exists, require("fs") should get 'fs2'."""
code = '''
function f() {
@@ -186,18 +186,18 @@ def test_require_name_conflict_resolved(self):
class TestNameGenerator:
"""Tests for the name generator function."""
- def test_generates_single_letters(self):
+ def test_generates_single_letters(self) -> None:
gen = _name_generator(set())
names = [next(gen) for _ in range(26)]
assert names == list('abcdefghijklmnopqrstuvwxyz')
- def test_generates_two_letter_after_single(self):
+ def test_generates_two_letter_after_single(self) -> None:
gen = _name_generator(set())
names = [next(gen) for _ in range(28)]
assert names[26] == 'aa'
assert names[27] == 'ab'
- def test_skips_reserved(self):
+ def test_skips_reserved(self) -> None:
gen = _name_generator({'a', 'c'})
assert next(gen) == 'b'
assert next(gen) == 'd'
@@ -206,7 +206,7 @@ def test_skips_reserved(self):
class TestNestedScopes:
"""Tests for renaming across nested scopes."""
- def test_nested_functions(self):
+ def test_nested_functions(self) -> None:
code = '''
function _0xabc1() {
var _0x1 = 1;
@@ -221,31 +221,31 @@ def test_nested_functions(self):
assert changed is True
assert '_0x' not in result
- def test_arrow_function_params(self):
+ def test_arrow_function_params(self) -> None:
code = 'var f = (_0xaaa, _0xbbb) => _0xaaa + _0xbbb;'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0x' not in result
- def test_class_expression_name_renamed(self):
+ def test_class_expression_name_renamed(self) -> None:
code = 'var C = class _0xabc { static m() { return _0xabc; } };'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0xabc' not in result
- def test_rest_param_renamed(self):
+ def test_rest_param_renamed(self) -> None:
code = 'function f(_0xaaa, ..._0xbbb) { return _0xbbb; }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0x' not in result
- def test_catch_param_renamed(self):
+ def test_catch_param_renamed(self) -> None:
code = 'function f() { try { x(); } catch (_0xabc) { console.log(_0xabc); } }'
result, changed = roundtrip(code, VariableRenamer)
assert changed is True
assert '_0xabc' not in result
- def test_destructuring_duplicate_names_fixed(self):
+ def test_destructuring_duplicate_names_fixed(self) -> None:
"""Obfuscators can produce const [a, a, a] = x; — fix duplicates."""
code = 'function f(_0xabc) { const [_0xabc, _0xabc, _0xabc] = _0xabc; }'
result, changed = roundtrip(code, VariableRenamer)
diff --git a/tests/unit/transforms/xor_string_decode_test.py b/tests/unit/transforms/xor_string_decode_test.py
index 14beed7..a870c53 100644
--- a/tests/unit/transforms/xor_string_decode_test.py
+++ b/tests/unit/transforms/xor_string_decode_test.py
@@ -1,7 +1,5 @@
"""Tests for the XorStringDecoder transform."""
-import pytest
-
from pyjsclear.transforms.xor_string_decode import XorStringDecoder
from pyjsclear.transforms.xor_string_decode import _extract_numeric_array
from pyjsclear.transforms.xor_string_decode import _xor_decode
@@ -11,7 +9,7 @@
class TestXorDecode:
"""Tests for the _xor_decode helper."""
- def test_basic_xor(self):
+ def test_basic_xor(self) -> None:
"""XOR decode with known prefix and data."""
# Prefix: [1, 2, 3, 4], data XOR'd with prefix
prefix = [1, 2, 3, 4]
@@ -20,11 +18,11 @@ def test_basic_xor(self):
result = _xor_decode(encoded)
assert result == 'ABCD'
- def test_too_short_returns_none(self):
+ def test_too_short_returns_none(self) -> None:
assert _xor_decode([1, 2, 3]) is None
assert _xor_decode([1, 2, 3, 4]) is None
- def test_invalid_utf8_returns_none(self):
+ def test_invalid_utf8_returns_none(self) -> None:
result = _xor_decode([0, 0, 0, 0, 0xFF, 0xFE])
assert result is None
@@ -32,7 +30,7 @@ def test_invalid_utf8_returns_none(self):
class TestExtractNumericArray:
"""Tests for _extract_numeric_array helper."""
- def test_valid_array(self):
+ def test_valid_array(self) -> None:
node = {
'type': 'ArrayExpression',
'elements': [
@@ -42,18 +40,18 @@ def test_valid_array(self):
}
assert _extract_numeric_array(node) == [10, 20]
- def test_non_array_returns_none(self):
+ def test_non_array_returns_none(self) -> None:
assert _extract_numeric_array({'type': 'Literal', 'value': 1}) is None
assert _extract_numeric_array(None) is None
- def test_out_of_range_returns_none(self):
+ def test_out_of_range_returns_none(self) -> None:
node = {
'type': 'ArrayExpression',
'elements': [{'type': 'Literal', 'value': 300, 'raw': '300'}],
}
assert _extract_numeric_array(node) is None
- def test_non_numeric_element_returns_none(self):
+ def test_non_numeric_element_returns_none(self) -> None:
node = {
'type': 'ArrayExpression',
'elements': [{'type': 'Literal', 'value': 'str', 'raw': '"str"'}],
@@ -64,11 +62,11 @@ def test_non_numeric_element_returns_none(self):
class TestXorStringDecoderTransform:
"""Tests for the full XorStringDecoder transform."""
- def test_no_decoder_returns_false(self):
+ def test_no_decoder_returns_false(self) -> None:
result, changed = roundtrip('var x = 1;', XorStringDecoder)
assert changed is False
- def test_decoder_detected_and_inlined(self):
+ def test_decoder_detected_and_inlined(self) -> None:
"""Integration test: XOR decoder function + call site should be resolved."""
# Build a XOR-encoded byte array for "test"
prefix = [0x10, 0x20, 0x30, 0x40]
@@ -91,7 +89,7 @@ def test_decoder_detected_and_inlined(self):
assert changed is True
assert 'obj.test' in result or 'obj["test"]' in result or '"test"' in result
- def test_standalone_identifier_replaced(self):
+ def test_standalone_identifier_replaced(self) -> None:
"""Standalone use of decoded var (e.g., require(_0xVar)) gets string literal."""
prefix = [0x10, 0x20, 0x30, 0x40]
message = b'fs'
@@ -113,7 +111,7 @@ def test_standalone_identifier_replaced(self):
assert changed is True
assert 'require("fs")' in result
- def test_dead_declaration_removed(self):
+ def test_dead_declaration_removed(self) -> None:
"""After inlining, the var _0xResult = decoder(...) should be removed."""
prefix = [0x10, 0x20, 0x30, 0x40]
message = b'test'
@@ -136,7 +134,7 @@ def test_dead_declaration_removed(self):
# _0xResult should be removed (no remaining references)
assert '_0xResult' not in result
- def test_non_valid_identifier_uses_string_literal(self):
+ def test_non_valid_identifier_uses_string_literal(self) -> None:
"""Decoded string with non-identifier chars stays as string literal."""
prefix = [0x10, 0x20, 0x30, 0x40]
message = b'a-b'
diff --git a/tests/unit/traverser_test.py b/tests/unit/traverser_test.py
index 4660df0..00b0102 100644
--- a/tests/unit/traverser_test.py
+++ b/tests/unit/traverser_test.py
@@ -647,3 +647,71 @@ def test_traverse_unknown_node_type(self):
assert 'Program' in visited
assert 'UnknownCustomNode' in visited
assert 'Identifier' in visited
+
+
+# ===========================================================================
+# Deep AST tests (exercises iterative fallback at depth > 500)
+# ===========================================================================
+
+
+def _make_deep_ast(depth: int) -> dict:
+    """Build a Program AST whose body holds ``depth`` IfStatements nested inside each other.
+
+    Each IfStatement has a literal-true test, no alternate, and a BlockStatement consequent whose
+    single ExpressionStatement wraps the next level; the innermost one wraps a Literal 42 leaf.
+    """
+    leaf = {'type': 'Literal', 'value': 42, 'raw': '42'}
+    node = leaf  # built inside-out: each iteration wraps the previous node
+    for _ in range(depth):
+        node = {
+            'type': 'IfStatement',
+            'test': {'type': 'Literal', 'value': True, 'raw': 'true'},
+            'consequent': {
+                'type': 'BlockStatement',
+                'body': [{'type': 'ExpressionStatement', 'expression': node}],
+            },
+            'alternate': None,  # no else branch, so each level has exactly one nested child
+        }
+    return {'type': 'Program', 'sourceType': 'script', 'body': [node]}
+
+
+class TestDeepASTTraversal:
+    """Verify that traverse() and simple_traverse() handle ASTs deeper than 500."""
+
+    def test_traverse_deep_ast(self) -> None:
+        ast = _make_deep_ast(600)
+        types = []  # node types recorded by the enter callback
+        traverse(ast, {'enter': lambda n, p, k, i: types.append(n['type'])})
+        assert 'Program' in types
+        assert 'Literal' in types
+        # Should have visited all IfStatements
+        assert types.count('IfStatement') == 600
+
+    def test_simple_traverse_deep_ast(self) -> None:
+        ast = _make_deep_ast(600)
+        types = []  # node types recorded by the two-argument (node, parent) callback
+        simple_traverse(ast, lambda n, p: types.append(n['type']))
+        assert 'Program' in types
+        assert 'Literal' in types
+        assert types.count('IfStatement') == 600
+
+    def test_traverse_deep_ast_with_exit(self) -> None:
+        """Exit-callback path uses fully iterative traversal."""
+        ast = _make_deep_ast(600)
+        exit_types = []  # node types recorded by the exit callback
+        traverse(ast, {'exit': lambda n, p, k, i: exit_types.append(n['type'])})
+        assert 'Program' in exit_types
+        assert 'Literal' in exit_types  # NOTE(review): IfStatement count is not checked here
+
+    def test_traverse_deep_ast_removal(self) -> None:
+        """REMOVE works correctly across the recursive/iterative boundary."""
+        ast = _make_deep_ast(600)
+        removed_count = []  # one entry appended per node the callback asks to remove
+
+        def enter(node, parent, key, index):
+            if node['type'] == 'Literal' and node.get('value') == 42:
+                removed_count.append(1)
+                return REMOVE  # only the value-42 leaf matches; the test Literals hold True
+
+        traverse(ast, {'enter': enter})
+        assert len(removed_count) > 0  # NOTE(review): proves the callback fired, not that the node was detached
diff --git a/tests/unit/utils/ast_helpers_test.py b/tests/unit/utils/ast_helpers_test.py
index de169f8..7fa16e3 100644
--- a/tests/unit/utils/ast_helpers_test.py
+++ b/tests/unit/utils/ast_helpers_test.py
@@ -3,8 +3,6 @@
import copy
import math
-import pytest
-
from pyjsclear.utils.ast_helpers import _CHILD_KEYS
from pyjsclear.utils.ast_helpers import deep_copy
from pyjsclear.utils.ast_helpers import get_child_keys
@@ -32,29 +30,29 @@
class TestDeepCopy:
- def test_basic_node(self):
+ def test_basic_node(self) -> None:
node = {'type': 'Literal', 'value': 42, 'raw': '42'}
result = deep_copy(node)
assert result == node
assert result is not node
- def test_nested_node_independence(self):
+ def test_nested_node_independence(self) -> None:
inner = {'type': 'Identifier', 'name': 'x'}
outer = {'type': 'ExpressionStatement', 'expression': inner}
result = deep_copy(outer)
result['expression']['name'] = 'y'
assert inner['name'] == 'x'
- def test_list_children(self):
+ def test_list_children(self) -> None:
node = {'type': 'BlockStatement', 'body': [{'type': 'EmptyStatement'}]}
result = deep_copy(node)
result['body'].append({'type': 'EmptyStatement'})
assert len(node['body']) == 1
- def test_none_passthrough(self):
+ def test_none_passthrough(self) -> None:
assert deep_copy(None) is None
- def test_empty_dict(self):
+ def test_empty_dict(self) -> None:
assert deep_copy({}) == {}
@@ -64,74 +62,74 @@ def test_empty_dict(self):
class TestIsLiteral:
- def test_string_literal(self):
+ def test_string_literal(self) -> None:
assert is_literal({'type': 'Literal', 'value': 'hello', 'raw': '"hello"'})
- def test_numeric_literal(self):
+ def test_numeric_literal(self) -> None:
assert is_literal({'type': 'Literal', 'value': 42, 'raw': '42'})
- def test_not_a_dict(self):
+ def test_not_a_dict(self) -> None:
assert not is_literal('Literal')
assert not is_literal(42)
assert not is_literal(None)
assert not is_literal([])
- def test_wrong_type(self):
+ def test_wrong_type(self) -> None:
assert not is_literal({'type': 'Identifier', 'name': 'x'})
- def test_missing_type(self):
+ def test_missing_type(self) -> None:
assert not is_literal({'value': 42})
- def test_empty_dict(self):
+ def test_empty_dict(self) -> None:
assert not is_literal({})
class TestIsIdentifier:
- def test_basic(self):
+ def test_basic(self) -> None:
assert is_identifier({'type': 'Identifier', 'name': 'foo'})
- def test_not_identifier(self):
+ def test_not_identifier(self) -> None:
assert not is_identifier({'type': 'Literal', 'value': 1})
- def test_non_dict(self):
+ def test_non_dict(self) -> None:
assert not is_identifier('Identifier')
assert not is_identifier(None)
class TestIsStringLiteral:
- def test_string(self):
+ def test_string(self) -> None:
assert is_string_literal({'type': 'Literal', 'value': 'hello'})
- def test_empty_string(self):
+ def test_empty_string(self) -> None:
assert is_string_literal({'type': 'Literal', 'value': ''})
- def test_number_not_string(self):
+ def test_number_not_string(self) -> None:
assert not is_string_literal({'type': 'Literal', 'value': 42})
- def test_bool_not_string(self):
+ def test_bool_not_string(self) -> None:
assert not is_string_literal({'type': 'Literal', 'value': True})
- def test_none_not_string(self):
+ def test_none_not_string(self) -> None:
assert not is_string_literal({'type': 'Literal', 'value': None, 'raw': 'null'})
class TestIsNumericLiteral:
- def test_int(self):
+ def test_int(self) -> None:
assert is_numeric_literal({'type': 'Literal', 'value': 42})
- def test_float(self):
+ def test_float(self) -> None:
assert is_numeric_literal({'type': 'Literal', 'value': 3.14})
- def test_zero(self):
+ def test_zero(self) -> None:
assert is_numeric_literal({'type': 'Literal', 'value': 0})
- def test_negative(self):
+ def test_negative(self) -> None:
assert is_numeric_literal({'type': 'Literal', 'value': -1})
- def test_string_not_numeric(self):
+ def test_string_not_numeric(self) -> None:
assert not is_numeric_literal({'type': 'Literal', 'value': '42'})
- def test_bool_not_numeric(self):
+ def test_bool_not_numeric(self) -> None:
# In Python bool is subclass of int, but isinstance(True, bool) is True
# and the function checks for (int, float). Since bool IS int, this returns True.
# However, the check order matters: is_boolean_literal checks bool first.
@@ -140,45 +138,45 @@ def test_bool_not_numeric(self):
class TestIsBooleanLiteral:
- def test_true(self):
+ def test_true(self) -> None:
assert is_boolean_literal({'type': 'Literal', 'value': True})
- def test_false(self):
+ def test_false(self) -> None:
assert is_boolean_literal({'type': 'Literal', 'value': False})
- def test_int_not_bool(self):
+ def test_int_not_bool(self) -> None:
assert not is_boolean_literal({'type': 'Literal', 'value': 1})
- def test_string_not_bool(self):
+ def test_string_not_bool(self) -> None:
assert not is_boolean_literal({'type': 'Literal', 'value': 'true'})
class TestIsNullLiteral:
- def test_null(self):
+ def test_null(self) -> None:
assert is_null_literal({'type': 'Literal', 'value': None, 'raw': 'null'})
- def test_none_without_raw(self):
+ def test_none_without_raw(self) -> None:
# Must have raw == 'null' to be considered null literal
assert not is_null_literal({'type': 'Literal', 'value': None})
- def test_none_wrong_raw(self):
+ def test_none_wrong_raw(self) -> None:
assert not is_null_literal({'type': 'Literal', 'value': None, 'raw': 'undefined'})
- def test_non_literal(self):
+ def test_non_literal(self) -> None:
assert not is_null_literal({'type': 'Identifier', 'name': 'null'})
class TestIsUndefined:
- def test_undefined_identifier(self):
+ def test_undefined_identifier(self) -> None:
assert is_undefined({'type': 'Identifier', 'name': 'undefined'})
- def test_other_identifier(self):
+ def test_other_identifier(self) -> None:
assert not is_undefined({'type': 'Identifier', 'name': 'null'})
- def test_literal_not_undefined(self):
+ def test_literal_not_undefined(self) -> None:
assert not is_undefined({'type': 'Literal', 'value': None, 'raw': 'null'})
- def test_non_dict(self):
+ def test_non_dict(self) -> None:
assert not is_undefined(None)
@@ -188,33 +186,33 @@ def test_non_dict(self):
class TestGetLiteralValue:
- def test_string_value(self):
+ def test_string_value(self) -> None:
val, ok = get_literal_value({'type': 'Literal', 'value': 'hello'})
assert ok is True
assert val == 'hello'
- def test_numeric_value(self):
+ def test_numeric_value(self) -> None:
val, ok = get_literal_value({'type': 'Literal', 'value': 42})
assert ok is True
assert val == 42
- def test_none_value_null(self):
+ def test_none_value_null(self) -> None:
val, ok = get_literal_value({'type': 'Literal', 'value': None, 'raw': 'null'})
assert ok is True
assert val is None
- def test_non_literal(self):
+ def test_non_literal(self) -> None:
val, ok = get_literal_value({'type': 'Identifier', 'name': 'x'})
assert ok is False
assert val is None
- def test_literal_missing_value_key(self):
+ def test_literal_missing_value_key(self) -> None:
# A Literal node without a 'value' key; .get returns None
val, ok = get_literal_value({'type': 'Literal', 'raw': '42'})
assert ok is True
assert val is None
- def test_bool_value(self):
+ def test_bool_value(self) -> None:
val, ok = get_literal_value({'type': 'Literal', 'value': True})
assert ok is True
assert val is True
@@ -226,43 +224,43 @@ def test_bool_value(self):
class TestMakeLiteral:
- def test_integer(self):
+ def test_integer(self) -> None:
node = make_literal(42)
assert node == {'type': 'Literal', 'value': 42, 'raw': '42'}
- def test_float_whole_number(self):
+ def test_float_whole_number(self) -> None:
# Whole floats are rendered as int strings
node = make_literal(3.0)
assert node['raw'] == '3'
- def test_float_fractional(self):
+ def test_float_fractional(self) -> None:
node = make_literal(3.14)
assert node['raw'] == '3.14'
- def test_negative_zero_float(self):
+ def test_negative_zero_float(self) -> None:
# -0.0 in Python: str(-0.0) is '-0.0', and -0.0 == 0 is True
# but str(-0.0).startswith('-') triggers the else branch
node = make_literal(-0.0)
assert node['raw'] == '-0.0'
- def test_boolean_true(self):
+ def test_boolean_true(self) -> None:
node = make_literal(True)
assert node == {'type': 'Literal', 'value': True, 'raw': 'true'}
- def test_boolean_false(self):
+ def test_boolean_false(self) -> None:
node = make_literal(False)
assert node == {'type': 'Literal', 'value': False, 'raw': 'false'}
- def test_null(self):
+ def test_null(self) -> None:
node = make_literal(None)
assert node == {'type': 'Literal', 'value': None, 'raw': 'null'}
- def test_simple_string(self):
+ def test_simple_string(self) -> None:
node = make_literal('hello')
assert node['value'] == 'hello'
assert node['raw'] == '"hello"'
- def test_string_with_double_quotes(self):
+ def test_string_with_double_quotes(self) -> None:
"""Bug #2: String containing double quotes. repr('say "hi"') gives
The raw value must properly escape inner double quotes."""
node = make_literal('say "hi"')
@@ -271,15 +269,15 @@ def test_string_with_double_quotes(self):
assert raw.endswith('"')
assert raw == '"say \\"hi\\""'
- def test_custom_raw(self):
+ def test_custom_raw(self) -> None:
node = make_literal(42, raw='0x2A')
assert node == {'type': 'Literal', 'value': 42, 'raw': '0x2A'}
- def test_negative_int(self):
+ def test_negative_int(self) -> None:
node = make_literal(-5)
assert node['raw'] == '-5'
- def test_large_float(self):
+ def test_large_float(self) -> None:
node = make_literal(1e10)
# 1e10 == int(1e10) and not negative zero, so raw = str(int(1e10))
assert node['raw'] == '10000000000'
@@ -290,7 +288,7 @@ def test_large_float(self):
# then checks if raw starts with '"' and if not, re-wraps.
# This section documents the current behavior for edge cases.
- def test_bug2_string_with_single_quotes(self):
+ def test_bug2_string_with_single_quotes(self) -> None:
"""Bug #2: Strings containing single quotes. repr() would use double quotes
for the Python repr, so replacing ' with " produces unexpected results.
@@ -305,7 +303,7 @@ def test_bug2_string_with_single_quotes(self):
assert raw.startswith('"')
assert raw.endswith('"')
- def test_bug2_backslash_string(self):
+ def test_bug2_backslash_string(self) -> None:
"""Bug #2: Strings with backslashes. repr() produces escape sequences
which then get the quote replacement applied.
@@ -320,7 +318,7 @@ def test_bug2_backslash_string(self):
# but repr gives us Python escaping, not JS escaping
assert '\\\\' in raw # double-escaped backslash from repr
- def test_bug2_unicode_string(self):
+ def test_bug2_unicode_string(self) -> None:
"""Bug #2: Unicode strings. repr() may produce \\uXXXX or the literal char
depending on the character. For printable unicode, repr includes it literally."""
node = make_literal('\u00e9') # e-acute
@@ -328,7 +326,7 @@ def test_bug2_unicode_string(self):
assert raw.startswith('"')
assert raw.endswith('"')
- def test_bug2_newline_string(self):
+ def test_bug2_newline_string(self) -> None:
"""Bug #2: Strings with newlines. repr() produces \\n which is valid JS too,
so this case happens to work."""
node = make_literal('line1\nline2')
@@ -337,13 +335,13 @@ def test_bug2_newline_string(self):
assert raw.endswith('"')
assert '\\n' in raw
- def test_bug2_tab_string(self):
+ def test_bug2_tab_string(self) -> None:
"""Bug #2: Strings with tabs. repr() produces \\t which is valid JS."""
node = make_literal('a\tb')
raw = node['raw']
assert '\\t' in raw
- def test_bug2_string_with_both_quote_types(self):
+ def test_bug2_string_with_both_quote_types(self) -> None:
"""Bug #2: String containing both single and double quotes.
repr() for a string with both quotes uses single-quote wrapping and
escapes the single quotes. The replace("'", '"') then converts ALL
@@ -353,12 +351,12 @@ def test_bug2_string_with_both_quote_types(self):
assert raw.startswith('"')
assert raw.endswith('"')
- def test_bug2_empty_string(self):
+ def test_bug2_empty_string(self) -> None:
"""Bug #2: Empty string should produce '""'."""
node = make_literal('')
assert node['raw'] == '""'
- def test_bug2_null_byte_string(self):
+ def test_bug2_null_byte_string(self) -> None:
"""Bug #2: String with null byte. repr() produces \\x00 which is NOT
valid JS (JS uses \\0 or \\u0000). This documents the current behavior."""
node = make_literal('a\x00b')
@@ -374,10 +372,10 @@ def test_bug2_null_byte_string(self):
class TestMakeIdentifier:
- def test_basic(self):
+ def test_basic(self) -> None:
assert make_identifier('foo') == {'type': 'Identifier', 'name': 'foo'}
- def test_underscore(self):
+ def test_underscore(self) -> None:
assert make_identifier('_') == {'type': 'Identifier', 'name': '_'}
@@ -387,7 +385,7 @@ def test_underscore(self):
class TestMakeExpressionStatement:
- def test_wraps_expression(self):
+ def test_wraps_expression(self) -> None:
expr = {'type': 'Literal', 'value': 42, 'raw': '42'}
result = make_expression_statement(expr)
assert result['type'] == 'ExpressionStatement'
@@ -400,11 +398,11 @@ def test_wraps_expression(self):
class TestMakeBlockStatement:
- def test_empty_body(self):
+ def test_empty_body(self) -> None:
result = make_block_statement([])
assert result == {'type': 'BlockStatement', 'body': []}
- def test_with_statements(self):
+ def test_with_statements(self) -> None:
stmts = [{'type': 'EmptyStatement'}, {'type': 'EmptyStatement'}]
result = make_block_statement(stmts)
assert result['body'] is stmts
@@ -417,7 +415,7 @@ def test_with_statements(self):
class TestMakeVarDeclaration:
- def test_var_no_init(self):
+ def test_var_no_init(self) -> None:
result = make_var_declaration('x')
assert result['type'] == 'VariableDeclaration'
assert result['kind'] == 'var'
@@ -427,13 +425,13 @@ def test_var_no_init(self):
assert decl['id'] == {'type': 'Identifier', 'name': 'x'}
assert decl['init'] is None
- def test_let_with_init(self):
+ def test_let_with_init(self) -> None:
init = {'type': 'Literal', 'value': 5, 'raw': '5'}
result = make_var_declaration('y', init=init, kind='let')
assert result['kind'] == 'let'
assert result['declarations'][0]['init'] is init
- def test_const(self):
+ def test_const(self) -> None:
result = make_var_declaration('Z', kind='const')
assert result['kind'] == 'const'
@@ -444,40 +442,40 @@ def test_const(self):
class TestIsValidIdentifier:
- def test_simple(self):
+ def test_simple(self) -> None:
assert is_valid_identifier('foo')
- def test_underscore_prefix(self):
+ def test_underscore_prefix(self) -> None:
assert is_valid_identifier('_private')
- def test_dollar_prefix(self):
+ def test_dollar_prefix(self) -> None:
assert is_valid_identifier('$jquery')
- def test_digits_allowed_after_first(self):
+ def test_digits_allowed_after_first(self) -> None:
assert is_valid_identifier('x1')
- def test_starts_with_digit(self):
+ def test_starts_with_digit(self) -> None:
assert not is_valid_identifier('1abc')
- def test_empty_string(self):
+ def test_empty_string(self) -> None:
assert not is_valid_identifier('')
- def test_none(self):
+ def test_none(self) -> None:
assert not is_valid_identifier(None)
- def test_non_string(self):
+ def test_non_string(self) -> None:
assert not is_valid_identifier(42)
- def test_hyphen(self):
+ def test_hyphen(self) -> None:
assert not is_valid_identifier('foo-bar')
- def test_single_dollar(self):
+ def test_single_dollar(self) -> None:
assert is_valid_identifier('$')
- def test_single_underscore(self):
+ def test_single_underscore(self) -> None:
assert is_valid_identifier('_')
- def test_reserved_words_pass_regex(self):
+ def test_reserved_words_pass_regex(self) -> None:
# The function only does regex check, not reserved word check
assert is_valid_identifier('if')
assert is_valid_identifier('return')
@@ -490,25 +488,25 @@ def test_reserved_words_pass_regex(self):
class TestChildKeys:
- def test_known_node_type(self):
+ def test_known_node_type(self) -> None:
node = {'type': 'BinaryExpression', 'left': {}, 'right': {}, 'operator': '+'}
assert get_child_keys(node) == ('left', 'right')
- def test_literal_has_no_children(self):
+ def test_literal_has_no_children(self) -> None:
assert get_child_keys({'type': 'Literal', 'value': 1}) == ()
- def test_identifier_has_no_children(self):
+ def test_identifier_has_no_children(self) -> None:
assert get_child_keys({'type': 'Identifier', 'name': 'x'}) == ()
- def test_non_dict_returns_empty(self):
+ def test_non_dict_returns_empty(self) -> None:
assert get_child_keys('not a node') == ()
assert get_child_keys(None) == ()
assert get_child_keys(42) == ()
- def test_missing_type_returns_empty(self):
+ def test_missing_type_returns_empty(self) -> None:
assert get_child_keys({'value': 42}) == ()
- def test_unknown_node_type_fallback(self):
+ def test_unknown_node_type_fallback(self) -> None:
# Unknown type falls back to heuristic: keys with dict/list values not in _SKIP_KEYS
node = {'type': 'UnknownThing', 'body': [], 'extra': {}, 'name': 'test'}
keys = get_child_keys(node)
@@ -517,13 +515,13 @@ def test_unknown_node_type_fallback(self):
# 'name' is in _SKIP_KEYS so should not appear
assert 'name' not in keys
- def test_fallback_skips_type_key(self):
+ def test_fallback_skips_type_key(self) -> None:
node = {'type': 'CustomNode', 'child': {'type': 'Literal'}}
keys = get_child_keys(node)
assert 'type' not in keys
assert 'child' in keys
- def test_all_known_types_in_child_keys(self):
+ def test_all_known_types_in_child_keys(self) -> None:
# Sanity check: every value in _CHILD_KEYS is a tuple
for node_type, keys in _CHILD_KEYS.items():
assert isinstance(keys, tuple), f'{node_type} keys is not a tuple'
@@ -535,7 +533,7 @@ def test_all_known_types_in_child_keys(self):
class TestReplaceIdentifiers:
- def test_simple_replacement(self):
+ def test_simple_replacement(self) -> None:
node = {
'type': 'BinaryExpression',
'operator': '+',
@@ -547,7 +545,7 @@ def test_simple_replacement(self):
assert node['left'] == {'type': 'Literal', 'value': 1, 'raw': '1'}
assert node['right'] == {'type': 'Identifier', 'name': 'b'}
- def test_replacement_is_deep_copied(self):
+ def test_replacement_is_deep_copied(self) -> None:
replacement = {'type': 'Literal', 'value': 1, 'raw': '1'}
node = {
'type': 'ExpressionStatement',
@@ -558,7 +556,7 @@ def test_replacement_is_deep_copied(self):
replacement['value'] = 999
assert node['expression']['value'] == 1
- def test_skips_non_computed_member_property(self):
+ def test_skips_non_computed_member_property(self) -> None:
# obj.foo -- 'foo' should NOT be replaced even if in param_map
node = {
'type': 'MemberExpression',
@@ -576,7 +574,7 @@ def test_skips_non_computed_member_property(self):
# property should NOT be replaced (non-computed)
assert node['property']['name'] == 'foo'
- def test_replaces_computed_member_property(self):
+ def test_replaces_computed_member_property(self) -> None:
# obj[foo] -- 'foo' SHOULD be replaced
node = {
'type': 'MemberExpression',
@@ -588,7 +586,7 @@ def test_replaces_computed_member_property(self):
replace_identifiers(node, param_map)
assert node['property'] == {'type': 'Literal', 'value': 'bar', 'raw': '"bar"'}
- def test_replaces_in_array_children(self):
+ def test_replaces_in_array_children(self) -> None:
node = {
'type': 'ArrayExpression',
'elements': [
@@ -602,7 +600,7 @@ def test_replaces_in_array_children(self):
assert node['elements'][0] == {'type': 'Literal', 'value': 1, 'raw': '1'}
assert node['elements'][1] == {'type': 'Identifier', 'name': 'b'}
- def test_recursive_replacement(self):
+ def test_recursive_replacement(self) -> None:
node = {
'type': 'ExpressionStatement',
'expression': {
@@ -616,18 +614,18 @@ def test_recursive_replacement(self):
replace_identifiers(node, param_map)
assert node['expression']['left'] == {'type': 'Literal', 'value': 10, 'raw': '10'}
- def test_non_dict_input_noop(self):
+ def test_non_dict_input_noop(self) -> None:
# Should not raise
replace_identifiers(None, {'x': {'type': 'Literal'}})
replace_identifiers('string', {'x': {'type': 'Literal'}})
replace_identifiers([], {'x': {'type': 'Literal'}})
- def test_no_type_key_noop(self):
+ def test_no_type_key_noop(self) -> None:
node = {'value': 42}
replace_identifiers(node, {'value': {'type': 'Literal'}})
assert node == {'value': 42}
- def test_child_is_none_skipped(self):
+ def test_child_is_none_skipped(self) -> None:
node = {
'type': 'ReturnStatement',
'argument': None,
@@ -635,7 +633,7 @@ def test_child_is_none_skipped(self):
# Should not raise
replace_identifiers(node, {'x': {'type': 'Literal'}})
- def test_nested_list_with_non_identifier_dicts(self):
+ def test_nested_list_with_non_identifier_dicts(self) -> None:
# Items in a list that are dicts with 'type' but not Identifier should recurse
node = {
'type': 'ArrayExpression',
@@ -659,69 +657,69 @@ def test_nested_list_with_non_identifier_dicts(self):
class TestNodesEqual:
- def test_identical_literals(self):
+ def test_identical_literals(self) -> None:
a = {'type': 'Literal', 'value': 42, 'raw': '42'}
b = {'type': 'Literal', 'value': 42, 'raw': '42'}
assert nodes_equal(a, b)
- def test_different_values(self):
+ def test_different_values(self) -> None:
a = {'type': 'Literal', 'value': 42, 'raw': '42'}
b = {'type': 'Literal', 'value': 43, 'raw': '43'}
assert not nodes_equal(a, b)
- def test_ignores_position_info(self):
+ def test_ignores_position_info(self) -> None:
a = {'type': 'Literal', 'value': 1, 'raw': '1', 'start': 0, 'end': 1}
b = {'type': 'Literal', 'value': 1, 'raw': '1', 'start': 50, 'end': 51}
assert nodes_equal(a, b)
- def test_ignores_loc(self):
+ def test_ignores_loc(self) -> None:
a = {'type': 'Identifier', 'name': 'x', 'loc': {'start': {'line': 1}}}
b = {'type': 'Identifier', 'name': 'x', 'loc': {'start': {'line': 5}}}
assert nodes_equal(a, b)
- def test_ignores_range(self):
+ def test_ignores_range(self) -> None:
a = {'type': 'Identifier', 'name': 'x', 'range': [0, 1]}
b = {'type': 'Identifier', 'name': 'x', 'range': [10, 11]}
assert nodes_equal(a, b)
- def test_different_types(self):
+ def test_different_types(self) -> None:
a = {'type': 'Literal', 'value': 1}
b = {'type': 'Identifier', 'name': '1'}
assert not nodes_equal(a, b)
- def test_list_equality(self):
+ def test_list_equality(self) -> None:
a = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}]
b = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}]
assert nodes_equal(a, b)
- def test_list_different_length(self):
+ def test_list_different_length(self) -> None:
a = [{'type': 'Literal', 'value': 1}]
b = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}]
assert not nodes_equal(a, b)
- def test_list_different_order(self):
+ def test_list_different_order(self) -> None:
a = [{'type': 'Literal', 'value': 1}, {'type': 'Literal', 'value': 2}]
b = [{'type': 'Literal', 'value': 2}, {'type': 'Literal', 'value': 1}]
assert not nodes_equal(a, b)
- def test_scalar_equality(self):
+ def test_scalar_equality(self) -> None:
assert nodes_equal(42, 42)
assert nodes_equal('hello', 'hello')
assert not nodes_equal(42, 43)
assert not nodes_equal('a', 'b')
- def test_type_mismatch(self):
+ def test_type_mismatch(self) -> None:
# type(a) != type(b) returns False
assert not nodes_equal(42, '42')
assert not nodes_equal({}, [])
assert not nodes_equal(None, {})
- def test_dict_extra_key(self):
+ def test_dict_extra_key(self) -> None:
a = {'type': 'Literal', 'value': 1}
b = {'type': 'Literal', 'value': 1, 'extra': True}
assert not nodes_equal(a, b)
- def test_nested_structures(self):
+ def test_nested_structures(self) -> None:
a = {
'type': 'BinaryExpression',
'operator': '+',
@@ -733,7 +731,7 @@ def test_nested_structures(self):
b['right']['value'] = 3
assert not nodes_equal(a, b)
- def test_nested_with_position_ignored(self):
+ def test_nested_with_position_ignored(self) -> None:
a = {
'type': 'BinaryExpression',
'operator': '+',
@@ -748,15 +746,15 @@ def test_nested_with_position_ignored(self):
}
assert nodes_equal(a, b)
- def test_empty_dicts(self):
+ def test_empty_dicts(self) -> None:
assert nodes_equal({}, {})
- def test_empty_lists(self):
+ def test_empty_lists(self) -> None:
assert nodes_equal([], [])
- def test_none_values(self):
+ def test_none_values(self) -> None:
assert nodes_equal(None, None)
- def test_bool_equality(self):
+ def test_bool_equality(self) -> None:
assert nodes_equal(True, True)
assert not nodes_equal(True, False)
diff --git a/tests/unit/utils/string_decoders_test.py b/tests/unit/utils/string_decoders_test.py
index 50e96f1..1ddc8b9 100644
--- a/tests/unit/utils/string_decoders_test.py
+++ b/tests/unit/utils/string_decoders_test.py
@@ -16,19 +16,19 @@
class TestDecoderType:
- def test_basic_value(self):
+ def test_basic_value(self) -> None:
assert DecoderType.BASIC.value == 'basic'
- def test_base64_value(self):
+ def test_base64_value(self) -> None:
assert DecoderType.BASE_64.value == 'base64'
- def test_rc4_value(self):
+ def test_rc4_value(self) -> None:
assert DecoderType.RC4.value == 'rc4'
- def test_enum_members(self):
+ def test_enum_members(self) -> None:
assert set(DecoderType) == {DecoderType.BASIC, DecoderType.BASE_64, DecoderType.RC4}
- def test_lookup_by_value(self):
+ def test_lookup_by_value(self) -> None:
assert DecoderType('basic') is DecoderType.BASIC
assert DecoderType('base64') is DecoderType.BASE_64
assert DecoderType('rc4') is DecoderType.RC4
@@ -40,44 +40,44 @@ def test_lookup_by_value(self):
class TestBase64Transform:
- def test_empty_string(self):
+ def test_empty_string(self) -> None:
assert base64_transform('') == ''
- def test_non_standard_alphabet(self):
+ def test_non_standard_alphabet(self) -> None:
"""Standard base64 'aGVsbG8=' decodes to 'hello', but the custom
alphabet reverses case mapping so the result must differ."""
result = base64_transform('aGVsbG8=')
assert result != 'hello'
- def test_decode_known_4char_group(self):
+ def test_decode_known_4char_group(self) -> None:
# 'abcd' -> indices 0,1,2,3 in custom alphabet -> bytes [0, 16, 131]
assert [ord(c) for c in base64_transform('abcd')] == [0, 16, 131]
- def test_decode_uppercase_group(self):
+ def test_decode_uppercase_group(self) -> None:
# 'ABCD' -> indices 26,27,28,29 -> bytes [105, 183, 29]
result = base64_transform('ABCD')
assert [ord(c) for c in result] == [105, 183, 29]
- def test_decode_mixed_case(self):
+ def test_decode_mixed_case(self) -> None:
assert [ord(c) for c in base64_transform('aBcD')] == [1, 176, 157]
- def test_decode_two_groups(self):
+ def test_decode_two_groups(self) -> None:
# 8 chars = two 4-char groups = 6 decoded bytes
assert [ord(c) for c in base64_transform('abcdefgh')] == [0, 16, 131, 16, 81, 135]
- def test_padding_double_equals(self):
+ def test_padding_double_equals(self) -> None:
# 'ab==' has one meaningful 6-bit pair
assert [ord(c) for c in base64_transform('ab==')] == [0, 32, 64]
- def test_padding_single_equals(self):
+ def test_padding_single_equals(self) -> None:
assert [ord(c) for c in base64_transform('abc=')] == [0, 16, 192]
- def test_invalid_chars_are_skipped(self):
+ def test_invalid_chars_are_skipped(self) -> None:
"""Characters not in the alphabet should be silently ignored."""
# '$' and '!' are not in the alphabet
assert base64_transform('$ab!cd') == base64_transform('abcd')
- def test_returns_string(self):
+ def test_returns_string(self) -> None:
result = base64_transform('abcd')
assert isinstance(result, str)
@@ -88,24 +88,24 @@ def test_returns_string(self):
class TestStringDecoder:
- def test_get_string_raises_not_implemented(self):
+ def test_get_string_raises_not_implemented(self) -> None:
decoder = StringDecoder(['a', 'b'], 0)
with pytest.raises(NotImplementedError):
decoder.get_string(0)
- def test_get_string_for_rotation_raises_on_first_call(self):
+ def test_get_string_for_rotation_raises_on_first_call(self) -> None:
decoder = BasicStringDecoder(['hello'], 0)
with pytest.raises(RuntimeError, match='First call'):
decoder.get_string_for_rotation(0)
- def test_get_string_for_rotation_works_on_second_call(self):
+ def test_get_string_for_rotation_works_on_second_call(self) -> None:
decoder = BasicStringDecoder(['hello', 'world'], 0)
with pytest.raises(RuntimeError):
decoder.get_string_for_rotation(0)
# Second call should succeed
assert decoder.get_string_for_rotation(0) == 'hello'
- def test_get_string_for_rotation_passes_args(self):
+ def test_get_string_for_rotation_passes_args(self) -> None:
decoder = BasicStringDecoder(['hello', 'world'], 0)
# Exhaust first-call guard
with pytest.raises(RuntimeError):
@@ -113,7 +113,7 @@ def test_get_string_for_rotation_passes_args(self):
# Second call with index 1
assert decoder.get_string_for_rotation(1) == 'world'
- def test_is_first_call_flag(self):
+ def test_is_first_call_flag(self) -> None:
decoder = BasicStringDecoder(['x'], 0)
assert decoder.is_first_call is True
with pytest.raises(RuntimeError):
@@ -127,58 +127,58 @@ def test_is_first_call_flag(self):
class TestBasicStringDecoder:
- def test_type_property(self):
+ def test_type_property(self) -> None:
decoder = BasicStringDecoder(['a'], 0)
assert decoder.type == DecoderType.BASIC
- def test_zero_offset(self):
+ def test_zero_offset(self) -> None:
arr = ['alpha', 'beta', 'gamma']
decoder = BasicStringDecoder(arr, 0)
assert decoder.get_string(0) == 'alpha'
assert decoder.get_string(1) == 'beta'
assert decoder.get_string(2) == 'gamma'
- def test_positive_offset(self):
+ def test_positive_offset(self) -> None:
arr = ['a', 'b', 'c', 'd']
decoder = BasicStringDecoder(arr, 2)
# index 0 -> array[0+2] = 'c'
assert decoder.get_string(0) == 'c'
assert decoder.get_string(1) == 'd'
- def test_negative_offset(self):
+ def test_negative_offset(self) -> None:
arr = ['a', 'b', 'c', 'd']
decoder = BasicStringDecoder(arr, -2)
# index 2 -> array[2-2] = 'a'
assert decoder.get_string(2) == 'a'
assert decoder.get_string(3) == 'b'
- def test_out_of_bounds_positive(self):
+ def test_out_of_bounds_positive(self) -> None:
arr = ['a', 'b']
decoder = BasicStringDecoder(arr, 0)
assert decoder.get_string(5) is None
- def test_out_of_bounds_negative(self):
+ def test_out_of_bounds_negative(self) -> None:
arr = ['a', 'b']
decoder = BasicStringDecoder(arr, -5)
assert decoder.get_string(0) is None
- def test_exact_boundary(self):
+ def test_exact_boundary(self) -> None:
arr = ['a', 'b', 'c']
decoder = BasicStringDecoder(arr, 0)
assert decoder.get_string(2) == 'c' # last valid index
assert decoder.get_string(3) is None # one past end
- def test_negative_index_yields_negative_array_index(self):
+ def test_negative_index_yields_negative_array_index(self) -> None:
arr = ['a', 'b']
decoder = BasicStringDecoder(arr, 0)
# index -1 -> array[-1] which is < 0 -> returns None
assert decoder.get_string(-1) is None
- def test_empty_array(self):
+ def test_empty_array(self) -> None:
decoder = BasicStringDecoder([], 0)
assert decoder.get_string(0) is None
- def test_extra_args_are_ignored(self):
+ def test_extra_args_are_ignored(self) -> None:
decoder = BasicStringDecoder(['x'], 0)
assert decoder.get_string(0, 'extra', 'args') == 'x'
@@ -189,24 +189,24 @@ def test_extra_args_are_ignored(self):
class TestBase64StringDecoder:
- def test_type_property(self):
+ def test_type_property(self) -> None:
decoder = Base64StringDecoder(['x'], 0)
assert decoder.type == DecoderType.BASE_64
- def test_decodes_value(self):
+ def test_decodes_value(self) -> None:
decoder = Base64StringDecoder(['abcd'], 0)
assert decoder.get_string(0) == base64_transform('abcd')
- def test_with_offset(self):
+ def test_with_offset(self) -> None:
decoder = Base64StringDecoder(['SKIP', 'abcd'], 1)
# index 0 -> array[0+1] = 'abcd'
assert decoder.get_string(0) == base64_transform('abcd')
- def test_out_of_bounds_returns_none(self):
+ def test_out_of_bounds_returns_none(self) -> None:
decoder = Base64StringDecoder(['abcd'], 0)
assert decoder.get_string(5) is None
- def test_caching(self):
+ def test_caching(self) -> None:
decoder = Base64StringDecoder(['abcd'], 0)
result1 = decoder.get_string(0)
result2 = decoder.get_string(0)
@@ -215,7 +215,7 @@ def test_caching(self):
assert 0 in decoder._cache
assert decoder._cache[0] == result1
- def test_multiple_indices(self):
+ def test_multiple_indices(self) -> None:
decoder = Base64StringDecoder(['abcd', 'ABCD'], 0)
r0 = decoder.get_string(0)
r1 = decoder.get_string(1)
@@ -223,11 +223,11 @@ def test_multiple_indices(self):
assert r1 == base64_transform('ABCD')
assert r0 != r1
- def test_empty_encoded_string(self):
+ def test_empty_encoded_string(self) -> None:
decoder = Base64StringDecoder([''], 0)
assert decoder.get_string(0) == ''
- def test_get_string_for_rotation(self):
+ def test_get_string_for_rotation(self) -> None:
decoder = Base64StringDecoder(['abcd'], 0)
with pytest.raises(RuntimeError):
decoder.get_string_for_rotation(0)
@@ -241,36 +241,36 @@ def test_get_string_for_rotation(self):
class TestRc4StringDecoder:
- def test_type_property(self):
+ def test_type_property(self) -> None:
decoder = Rc4StringDecoder(['x'], 0)
assert decoder.type == DecoderType.RC4
- def test_key_none_returns_none(self):
+ def test_key_none_returns_none(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
assert decoder.get_string(0) is None
- def test_key_none_explicit(self):
+ def test_key_none_explicit(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
assert decoder.get_string(0, key=None) is None
- def test_out_of_bounds_returns_none(self):
+ def test_out_of_bounds_returns_none(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
assert decoder.get_string(5, key='k') is None
- def test_decodes_with_key(self):
+ def test_decodes_with_key(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
result = decoder.get_string(0, key='k')
assert result is not None
assert isinstance(result, str)
assert len(result) == 3
- def test_different_keys_give_different_results(self):
+ def test_different_keys_give_different_results(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
r1 = decoder.get_string(0, key='testkey')
r2 = decoder.get_string(0, key='otherkey')
assert r1 != r2
- def test_caching_with_same_key(self):
+ def test_caching_with_same_key(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
r1 = decoder.get_string(0, key='mykey')
r2 = decoder.get_string(0, key='mykey')
@@ -278,7 +278,7 @@ def test_caching_with_same_key(self):
assert r1 is r2 # same cached object
assert (0, 'mykey') in decoder._cache
- def test_cache_keyed_by_index_and_key(self):
+ def test_cache_keyed_by_index_and_key(self) -> None:
decoder = Rc4StringDecoder(['abcd', 'ABCD'], 0)
r1 = decoder.get_string(0, key='k')
r2 = decoder.get_string(1, key='k')
@@ -286,25 +286,25 @@ def test_cache_keyed_by_index_and_key(self):
assert (1, 'k') in decoder._cache
assert r1 != r2
- def test_with_offset(self):
+ def test_with_offset(self) -> None:
decoder = Rc4StringDecoder(['SKIP', 'abcd'], 1)
result = decoder.get_string(0, key='k')
assert result is not None
- def test_same_input_same_key_deterministic(self):
+ def test_same_input_same_key_deterministic(self) -> None:
"""Same input and key always produce the same output."""
dec1 = Rc4StringDecoder(['abcd'], 0)
dec2 = Rc4StringDecoder(['abcd'], 0)
assert dec1.get_string(0, key='k') == dec2.get_string(0, key='k')
- def test_get_string_for_rotation(self):
+ def test_get_string_for_rotation(self) -> None:
decoder = Rc4StringDecoder(['abcd'], 0)
with pytest.raises(RuntimeError):
decoder.get_string_for_rotation(0, key='k')
result = decoder.get_string_for_rotation(0, key='k')
assert result is not None
- def test_empty_array(self):
+ def test_empty_array(self) -> None:
decoder = Rc4StringDecoder([], 0)
assert decoder.get_string(0, key='k') is None
@@ -315,7 +315,7 @@ def test_empty_array(self):
class TestCrossDecoder:
- def test_basic_does_not_decode(self):
+ def test_basic_does_not_decode(self) -> None:
"""BasicStringDecoder returns the raw string, not decoded."""
arr = ['abcd']
basic = BasicStringDecoder(arr, 0)