Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions pineforge_codegen/analyzer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ def __init__(self, ast: Program, filename: str = "<stdin>") -> None:
self._func_ta_ranges: dict[str, tuple[int, int]] = {} # func_name -> (start, end) indices
self._func_call_site_count: dict[str, int] = {} # func_name -> count
self._func_call_cs_map: dict[int, tuple[str, int]] = {} # call_node_id -> (func_name, cs_idx)
# Authoritative clone-name map: (func_name, cs_idx) -> {orig_member_name:
# cloned_member_name}. The codegen rebuilds its TA remap from the
# ``{orig}_cs{cs_idx}`` formula by default, but a TA site reached through
# MULTIPLE enclosing functions (e.g. a helper cloned both directly and via
# a range-widened outer function) can collide on that formula. When the
# analyzer must disambiguate a clone's member name, it records the actual
# chosen name here so the codegen consumes it verbatim instead of
# re-deriving a colliding name. Empty for the common (no-collision) case,
# keeping generated output byte-identical.
self._func_cs_ta_clone_names: dict[tuple[str, int], dict[str, str]] = {}
# All TA member names minted so far (base + clones), for O(1) collision
# detection when minting a new clone.
self._ta_member_names: set[str] = set()
# UDT field definitions: type_name -> {field_name: PineType}
self._udt_fields: dict[str, dict[str, PineType]] = {}
# var_name -> UDT type for variables holding UDT instances
Expand All @@ -164,6 +177,14 @@ def __init__(self, ast: Program, filename: str = "<stdin>") -> None:
self._current_top_level_stmt: ASTNode | None = None
self._global_scope = True
self._static_vars: set[str] = set()
# Stack of enclosing user-function param-name sets, pushed while visiting
# a FuncDef body. Lets a nested user-func call detect when it substitutes
# a TA ctor length with one of the OUTER function's params, so the outer
# call site can re-substitute (e.g. f_bbwp(_bbwLen) -> f_basisMa(_len)).
self._enclosing_func_params: list[set[str]] = []
# Set of TA-site indices a nested user-func call rewrote in terms of the
# current enclosing function's params (None when not inside a FuncDef body).
self._nested_ta_touched: set | None = None

# Pre-populate builtins
self._populate_builtins()
Expand Down Expand Up @@ -272,6 +293,7 @@ def analyze(self) -> AnalyzerContext:
func_ta_ranges=self._func_ta_ranges,
func_call_cs_map=self._func_call_cs_map,
func_call_site_counts=self._func_call_site_count,
func_cs_ta_clone_names=self._func_cs_ta_clone_names,
udt_defs=self._udt_fields,
enum_defs=self._enum_defs,
enum_member_strings=self._enum_member_strings,
Expand Down Expand Up @@ -899,16 +921,28 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType:
body_type = PineType.VOID
old_global = self._global_scope
self._global_scope = False
self._enclosing_func_params.append(set(node.params))
self._nested_ta_touched = set()
try:
for stmt in node.body:
body_type = self._visit(stmt)
finally:
self._global_scope = old_global

# Record TA range for this function
self._enclosing_func_params.pop()
nested_touched = self._nested_ta_touched
self._nested_ta_touched = None

# Record TA range for this function. Widen to cover any nested-callee TA
# sites whose ctor args were rewritten in terms of THIS function's params
# (e.g. f_basisMa's sites parameterized by f_bbwp's _bbwLen), so resolving
# this function at its call site re-substitutes those nested sites too.
ta_end = len(self._ta_call_sites)
if ta_end > ta_start:
self._func_ta_ranges[node.name] = (ta_start, ta_end)
lo, hi = ta_start, ta_end
if nested_touched:
lo = min(lo, min(nested_touched))
hi = max(hi, max(nested_touched) + 1)
if hi > lo:
self._func_ta_ranges[node.name] = (lo, hi)

self._symbols.exit_scope()

Expand Down
132 changes: 128 additions & 4 deletions pineforge_codegen/analyzer/call_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@
from typing import Any

from ..ast_nodes import (
ASTNode, BoolLiteral, ExprStmt, FuncCall, Identifier, MemberAccess,
NumberLiteral, StringLiteral, TupleLiteral,
ASTNode, BinOp, BoolLiteral, ExprStmt, FuncCall, Identifier, MemberAccess,
NumberLiteral, StringLiteral, TupleLiteral, UnaryOp, VarDecl,
)
from ..symbols import PineType
from .. import signatures as sigs
Expand Down Expand Up @@ -207,6 +207,7 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType:
is_static=is_static,
)
self._ta_call_sites.append(site)
self._ta_member_names.add(site.member_name)
return PineType.FLOAT

# Determine constructor args
Expand Down Expand Up @@ -249,6 +250,7 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType:
is_static=is_static,
)
self._ta_call_sites.append(site)
self._ta_member_names.add(site.member_name)

return PineType.FLOAT

Expand Down Expand Up @@ -770,6 +772,57 @@ def _handle_fixnan_call(self, node: FuncCall) -> PineType:
# User-defined function calls
# ------------------------------------------------------------------

def _func_local_length_defs(self, func_def) -> dict[str, str]:
"""Collect a user function's local scalar length-vars to their RHS
expression string, e.g. ``qqeCalc`` with ``wp = sf * 2 - 1`` returns
``{"wp": "sf * 2 - 1"}``.

Only plain (non-``var``/``varip``) declarations whose RHS is a pure
arithmetic expression over identifiers/numbers (NumberLiteral, Identifier,
BinOp, UnaryOp, or a math.* FuncCall) qualify — these are the shapes that
can legitimately feed a TA constructor length. Series-valued locals (whose
RHS is a ta.* call, a subscript, a ternary, etc.) are skipped so we never
inline a price series into a ctor-length slot. Names reassigned with ``:=``
are also skipped (their value is not a stable compile-time length).
"""
def _is_arith(n) -> bool:
if isinstance(n, (NumberLiteral, Identifier)):
return True
if isinstance(n, BinOp):
return _is_arith(n.left) and _is_arith(n.right)
if isinstance(n, UnaryOp):
return _is_arith(n.operand)
if isinstance(n, FuncCall):
# Allow math.* helpers (math.round/sqrt/...) over arith args.
callee = n.callee
if (isinstance(callee, MemberAccess)
and isinstance(callee.object, Identifier)
and callee.object.name == "math"):
return all(_is_arith(a) for a in n.args)
return False

reassigned: set[str] = set()
def _scan_reassign(stmts):
from ..ast_nodes import Assignment
for s in stmts or []:
if isinstance(s, Assignment) and isinstance(s.target, Identifier):
reassigned.add(s.target.name)
for attr in ("body", "else_body"):
sub = getattr(s, attr, None)
if isinstance(sub, list):
_scan_reassign(sub)
_scan_reassign(func_def.body)

defs: dict[str, str] = {}
for stmt in func_def.body or []:
if (isinstance(stmt, VarDecl)
and not stmt.is_var and not stmt.is_varip
and stmt.name not in reassigned
and stmt.value is not None
and _is_arith(stmt.value)):
defs[stmt.name] = self._expr_to_str(stmt.value)
return defs

def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType:
"""Handle calls to user-defined functions."""
func_def = self._func_defs[func_name]
Expand Down Expand Up @@ -843,6 +896,16 @@ def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType:
if has_ta:
start, end = self._func_ta_ranges[func_name]

# Map of this function's local (non-param, non-series) derived
# length vars to their raw RHS expression strings, e.g.
# ``qqeCalc`` => ``wp = sf * 2 - 1`` -> {"wp": "sf * 2 - 1"}.
# A TA ctor arg captured as the bare local name ("wp") must be
# expanded to its definition so the subsequent param-substitution
# turns it into a class-scope expression ("rsiSmooth * 2 - 1")
# rather than leaving a dangling local that degenerates to
# period 1 in codegen.
local_defs = self._func_local_length_defs(func_def)

def _subst_params(arg: str, pmap: dict[str, str]) -> str:
"""Substitute parameter names in an expression string.

Expand All @@ -856,23 +919,81 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str:
result = re.sub(rf'\b{re.escape(param)}\b', value, result)
return result

def _expand_locals(arg: str) -> str:
"""Recursively expand function-local length vars to their RHS
(parenthesized) so only params / class-scope names remain."""
import re
if not local_defs:
return arg
for _ in range(32):
def _rep(m: re.Match) -> str:
nm = m.group(0)
if nm in local_defs:
return "(" + local_defs[nm] + ")"
return nm
new = re.sub(r"[A-Za-z_][A-Za-z_0-9]*", _rep, arg)
if new == arg:
break
arg = new
return arg

# Params of the function we are *currently inside* (if this is a
# nested user-func call). Used to detect when a substituted ctor
# arg becomes parameterized by the OUTER function, so the outer
# call site can resolve it (f_bbwp's _bbwLen -> i_bbwLen reaches
# f_basisMa's sites).
import re as _re
enclosing_params: set[str] = set()
for s in self._enclosing_func_params:
enclosing_params |= s

if cs_idx == 0:
# First call site: save original param-based ctor_args for future cloning,
# then resolve to actual call-site values
for i in range(start, end):
site = self._ta_call_sites[i]
if not hasattr(site, '_orig_ctor_args'):
site._orig_ctor_args = site.ctor_args[:]
site._orig_ctor_args = [
_expand_locals(a) for a in site.ctor_args
]
site.ctor_args = [_subst_params(a, param_arg_map) for a in site._orig_ctor_args]
# If a substituted arg is now expressed in terms of an
# enclosing function's params, promote it to the original
# so the enclosing call re-substitutes, and mark the site
# so the enclosing function's TA range widens to cover it.
if enclosing_params and self._nested_ta_touched is not None:
for a in site.ctor_args:
toks = set(_re.findall(r"[A-Za-z_][A-Za-z_0-9]*", a))
if toks & enclosing_params:
site._orig_ctor_args = list(site.ctor_args)
self._nested_ta_touched.add(i)
break
else:
# Subsequent call sites: clone using saved original param names,
# substituted with this call site's arguments
clone_name_map: dict[str, str] = {}
for i in range(start, end):
orig = self._ta_call_sites[i]
orig_args = getattr(orig, '_orig_ctor_args', orig.ctor_args)
resolved_ctor = [_subst_params(a, param_arg_map) for a in orig_args]
# Default name follows the ``{base}_cs{cs_idx}`` formula the
# codegen re-derives. But the SAME base TA site can be reached
# through more than one enclosing function (e.g. a helper cloned
# both via its own call sites AND via a range-widened outer
# function), so two distinct (func, cs_idx) namespaces can mint
# the same name. Detect that collision and fall back to a
# globally-unique name; record the chosen name so the codegen
# consumes it verbatim (see _func_cs_ta_clone_names).
clone_name = f"{orig.member_name}_cs{cs_idx}"
if clone_name in self._ta_member_names:
base = clone_name
n = 2
while clone_name in self._ta_member_names:
clone_name = f"{base}_u{n}"
n += 1
clone_name_map[orig.member_name] = clone_name
cloned = TACallSite(
member_name=f"{orig.member_name}_cs{cs_idx}",
member_name=clone_name,
class_name=orig.class_name,
ctor_args=resolved_ctor,
compute_args=orig.compute_args[:],
Expand All @@ -881,6 +1002,9 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str:
is_static=orig.is_static,
)
self._ta_call_sites.append(cloned)
self._ta_member_names.add(clone_name)
if clone_name_map:
self._func_cs_ta_clone_names[(func_name, cs_idx)] = clone_name_map

# Create or update FuncInfo
is_tuple = self._func_returns_tuple.get(func_name, False)
Expand Down
5 changes: 5 additions & 0 deletions pineforge_codegen/analyzer/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class AnalyzerContext:
func_ta_ranges: dict = field(default_factory=dict) # func_name -> (start_idx, end_idx)
func_call_cs_map: dict = field(default_factory=dict) # call_node_id -> (func_name, call_site_index)
func_call_site_counts: dict = field(default_factory=dict) # func_name -> int
# (func_name, cs_idx) -> {orig_member_name: cloned_member_name}. Populated by
# the analyzer ONLY for clones whose default ``{base}_cs{cs_idx}`` name would
# collide with a clone minted through another enclosing function; lets codegen
# use the disambiguated name instead of re-deriving a colliding one.
func_cs_ta_clone_names: dict = field(default_factory=dict)
# UDT / enum definitions:
udt_defs: dict = field(default_factory=dict) # type_name -> {field_name: PineType}
enum_defs: dict = field(default_factory=dict) # enum_name -> [member names]
Expand Down
Loading
Loading