Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 242 additions & 37 deletions pineforge_codegen/codegen/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,34 @@ def __init__(self, ctx: AnalyzerContext) -> None:
for site in ctx.ta_call_sites:
if site.node is not None:
if site.member_name not in self._func_ta_members:
# Top-level (non-function) site: maps to itself.
self._ta_site_map[id(site.node)] = site
elif not any(site.member_name.endswith(f"_cs{i}") for i in range(1, 100)):
# Original (cs0) function-local site — add to map for initial visit
elif id(site.node) not in self._ta_site_map:
# Function-local site. Multiple clones share the SAME AST
# node (clones copy ``node=orig.node``); the FIRST one in
# ``ta_call_sites`` order is the canonical original (cs0)
# whose ``member_name`` the per-call-site remap is keyed on.
# Later clones (``_cs{i}``, ``_cs{i}_cs{j}``, ``_u{n}`` …)
# must NOT overwrite it: doing so poisons the base name so
# the active-remap lookup misses and every clone collapses
# onto one member. Keep the original; the variant member is
# resolved via ``_active_ta_remap`` at emit time.
self._ta_site_map[id(site.node)] = site
self._ta_index_by_site_id: dict[int, int] = {
id(site): i for i, site in enumerate(ctx.ta_call_sites)
}
# Context-sensitive (call-path) instance machinery for nested stateful
# helpers. Built by ``_build_func_instances`` below. ``_current_instance_name``
# names the function clone whose body is currently being emitted (None at
# top level / non-variant bodies). ``_instance_dispatch`` maps
# ``(enclosing_instance_name, call_node_id) -> callee emit-name`` and is the
# authority for nested stateful-helper dispatch (see visit_call).
self._current_instance_name: str | None = None
self._instance_dispatch: dict[tuple[str | None, int], str] = {}
self._fresh_instances: list[dict] = []
self._fresh_var_members: list[tuple[str, str]] = []
# NOTE: _build_func_instances() runs at the top of generate() (it needs
# _all_member_names / _func_safe_name, which are populated later in __init__).
# Build lookup: node id -> FixnanCallSite (counter-based)
self._fixnan_counter = 0
self._switch_counter = 0
Expand Down Expand Up @@ -534,6 +555,194 @@ def __init__(self, ctx: AnalyzerContext) -> None:
# history reads off security-helper series.
self._max_bars_back_cap: int | None = self._compute_max_bars_back_cap()

# ------------------------------------------------------------------
# Context-sensitive (call-path) instance machinery
# ------------------------------------------------------------------
def _iter_func_calls(self, root) -> list:
"""Collect every ``FuncCall`` node reachable from ``root`` (a stmt list
or single AST node). Order-independent; used by the instance pre-pass to
find nested user-function calls inside a function body."""
out: list = []
seen: set[int] = set()
stack: list = list(root) if isinstance(root, (list, tuple)) else [root]
while stack:
node = stack.pop()
if node is None:
continue
if isinstance(node, (list, tuple)):
stack.extend(node)
continue
if isinstance(node, dict):
stack.extend(node.values())
continue
if not hasattr(node, "__dict__"):
continue
nid = id(node)
if nid in seen:
continue
seen.add(nid)
if isinstance(node, FuncCall):
out.append(node)
for v in vars(node).values():
if isinstance(v, (list, tuple, dict)) or hasattr(v, "__dict__"):
stack.append(v)
return out

def _build_func_instances(self) -> None:
"""Context-sensitive cloning of nested stateful helper functions.

A stateful helper ``G`` (carrying TA state and/or ``var`` members) may be
reached through several distinct call paths — e.g. ``leg`` called from
three clones of ``f_get`` (lengths 10/20/30) *and* directly. Each path is
a logically-distinct instance that must drive its OWN TA/var members.

The analyzer already mints the per-path members (via range-widening), but
the flat ``{G}_cs{idx}`` clone namespace conflates a callee's own textual
call sites with the enclosing function's call sites. This pre-pass walks
the call graph from each natural clone and, for every nested stateful
call, composes the enclosing clone's active remap with the callee's
per-call-site remap:

composed_ta[m] = R_enclosing.get(R_callee_cs[m], R_callee_cs[m])

When the composition equals the callee's natural ``cs{j}`` remap the call
dispatches to the existing ``{G}_cs{j}`` clone (output stays byte-identical
for the common single-caller case). Otherwise a fresh instance is minted,
bound to the path-specific members (and FRESH ``var`` members so two paths
never share scalar state). ``_instance_dispatch`` records the resolved
emit-name per ``(enclosing_instance, call_node)``; ``_fresh_instances`` /
``_fresh_var_members`` carry the extra code to emit.
"""
ctx = self.ctx
stateful = (set(ctx.func_ta_ranges.keys())
| set(ctx.func_var_members.keys())
| set(ctx.func_series_vars.keys()))
if not stateful:
return

func_bodies: dict[str, list] = {}
for fi in ctx.func_infos:
node = getattr(fi, "node", None)
if node is not None and getattr(node, "body", None):
func_bodies.setdefault(fi.name, node.body)

def ta_originals(fname: str) -> list[str]:
return list(self._func_cs_ta_remap.get((fname, 0), {}).keys())

def var_originals(fname: str) -> list[str]:
return [self._safe_name(n) for n, _, _ in ctx.func_var_members.get(fname, [])]

def natural_name(fname: str, cs_idx: int) -> str:
return f"{self._func_safe_name(fname)}_cs{cs_idx}"

interned: dict[tuple, dict] = {}
worklist: list[dict] = []
seen_walk: set[str] = set()
fresh_counter = 0

# Seed with the natural clones the flat emission loop produces.
for fname in stateful:
if fname not in func_bodies:
continue
total_cs = ctx.func_call_site_counts.get(fname, 0)
if total_cs > 0:
for k in range(total_cs):
worklist.append({
"fname": fname,
"name": natural_name(fname, k),
"ta_remap": self._func_cs_ta_remap.get((fname, k), {}),
"var_remap": self._func_cs_var_remap.get((fname, k), {}),
})
else:
worklist.append({
"fname": fname,
"name": self._func_safe_name(fname),
"ta_remap": {},
"var_remap": {},
})

while worklist:
inst = worklist.pop()
if inst["name"] in seen_walk:
continue
seen_walk.add(inst["name"])
body = func_bodies.get(inst["fname"])
if not body:
continue
active_ta = inst["ta_remap"]
for callnode in self._iter_func_calls(body):
cs_info = ctx.func_call_cs_map.get(id(callnode))
if cs_info is None:
continue
g_name, j = cs_info
if g_name not in stateful:
continue
natural_ta = self._func_cs_ta_remap.get((g_name, j), {})
composed_ta = {}
for m in ta_originals(g_name):
mid = natural_ta.get(m, m)
composed_ta[m] = active_ta.get(mid, mid)
if composed_ta == natural_ta:
# Path resolves to the callee's own cs{j} clone — reuse it.
self._instance_dispatch[(inst["name"], id(callnode))] = \
natural_name(g_name, j)
continue
key = (g_name, frozenset(composed_ta.items()))
ginst = interned.get(key)
if ginst is None:
fresh_counter += 1
inst_name = f"{self._func_safe_name(g_name)}__ni{fresh_counter}"
fvar_remap: dict[str, str] = {}
for v in var_originals(g_name):
fresh_member = f"{v}__ni{fresh_counter}"
fvar_remap[v] = fresh_member
self._fresh_var_members.append((v, fresh_member))
ginst = {
"fname": g_name,
"name": inst_name,
"ta_remap": composed_ta,
"var_remap": fvar_remap,
}
interned[key] = ginst
self._fresh_instances.append(ginst)
worklist.append(ginst)
self._instance_dispatch[(inst["name"], id(callnode))] = ginst["name"]

def _emit_cloned_var_decl(self, orig_safe: str, cloned_safe: str,
series_suffix: str, lines: list[str]) -> None:
"""Declare a per-clone copy of a function-scoped ``var`` member, matching
the original's C++ type (series / matrix / array / map / drawing-handle /
UDT / scalar). Shared by the per-call-site clone loop and the fresh
context-sensitive instance loop."""
for vname, ptype, _init_str in self.ctx.var_members:
if self._safe_name(vname) == orig_safe:
cpp_type = PINE_TYPE_TO_CPP.get(ptype, "double")
if vname in self.ctx.series_vars:
lines.append(f" Series<{cpp_type}> {cloned_safe}{series_suffix};")
elif vname in self._matrix_specs:
lines.append(f" {self._type_spec_to_cpp(self._matrix_specs[vname])} {cloned_safe};")
elif vname in self._array_vars:
lines.append(f" {self._type_spec_to_cpp(self._array_spec_for_name(vname))} {cloned_safe};")
elif vname in self._map_vars:
lines.append(f" {self._type_spec_to_cpp(self._map_spec_for_name(vname))} {cloned_safe};")
elif vname in self._udt_var_types:
# Drawing handle / UDT var clone must match the original's
# type (Line/Label/Box/<UDT>), not the coarse PineType
# default (double) — otherwise the clone can't hold the
# handle and drawing access on it reads a garbage / na id.
udt_t = self._udt_var_types[vname]
handle_cpp = DRAWING_TYPE_TO_CPP.get(udt_t, udt_t)
lines.append(f" {handle_cpp} {cloned_safe} = {handle_cpp}{{}};")
else:
lines.append(f" {cpp_type} {cloned_safe};")
return
# Non-var series var
if orig_safe in [self._safe_name(n) for n in self.ctx.series_vars]:
cpp_type = self._series_type_for(orig_safe)
lines.append(f" Series<{cpp_type}> {cloned_safe}{series_suffix};")
else:
lines.append(f" double {cloned_safe} = 0.0;")

@staticmethod
def _int_literal_value(node: ASTNode | None) -> int | None:
"""Return the integer value of a (possibly unary-minus) NumberLiteral,
Expand Down Expand Up @@ -979,6 +1188,9 @@ def walk(node):

def generate(self) -> str:
"""Generate C++ source from the AnalyzerContext."""
# Context-sensitive instance pre-pass (needs the naming helpers populated
# in __init__). Computes nested stateful-helper dispatch + fresh instances.
self._build_func_instances()
# Pre-scan for strategy series vars
self._prescan_strategy_series()
self._security_ohlc_hist_fields_by_sec: dict[int, set[str]] = {}
Expand Down Expand Up @@ -1306,41 +1518,16 @@ def generate(self) -> str:
if cloned_safe in emitted_clones:
continue # already declared by another function's clone
emitted_clones.add(cloned_safe)
# Determine the type by finding the original declaration
orig_name = orig_safe # _safe_name was already applied
# Check if it's a var member (Series) or plain series
found = False
for vname, ptype, init_str in self.ctx.var_members:
if self._safe_name(vname) == orig_safe:
cpp_type = PINE_TYPE_TO_CPP.get(ptype, "double")
if vname in self.ctx.series_vars:
lines.append(f" Series<{cpp_type}> {cloned_safe}{_mbb};")
elif vname in self._matrix_specs:
lines.append(f" {self._type_spec_to_cpp(self._matrix_specs[vname])} {cloned_safe};")
elif vname in self._array_vars:
lines.append(f" {self._type_spec_to_cpp(self._array_spec_for_name(vname))} {cloned_safe};")
elif vname in self._map_vars:
lines.append(f" {self._type_spec_to_cpp(self._map_spec_for_name(vname))} {cloned_safe};")
elif vname in self._udt_var_types:
# Drawing handle / UDT var clone must match the
# original's type (Line/Label/Box/<UDT>), not the
# coarse PineType default (double) — otherwise the
# clone can't hold the handle and drawing access on
# it reads a garbage / na id.
udt_t = self._udt_var_types[vname]
handle_cpp = DRAWING_TYPE_TO_CPP.get(udt_t, udt_t)
lines.append(f" {handle_cpp} {cloned_safe} = {handle_cpp}{{}};")
else:
lines.append(f" {cpp_type} {cloned_safe};")
found = True
break
if not found:
# Non-var series var
if orig_safe in [self._safe_name(n) for n in self.ctx.series_vars]:
cpp_type = self._series_type_for(orig_safe)
lines.append(f" Series<{cpp_type}> {cloned_safe}{_mbb};")
else:
lines.append(f" double {cloned_safe} = 0.0;")
self._emit_cloned_var_decl(orig_safe, cloned_safe, _mbb, lines)

# 8c2. Fresh var members for context-sensitive helper instances (nested
# helpers reached through >1 distinct call path). Each fresh instance
# gets its OWN scalar/series state so two paths never collide.
for orig_safe, fresh_safe in self._fresh_var_members:
if fresh_safe in emitted_clones:
continue
emitted_clones.add(fresh_safe)
self._emit_cloned_var_decl(orig_safe, fresh_safe, _mbb, lines)

# 8d. Drawing-objects-as-data arenas (gated on _uses_drawing so
# non-drawing strategies emit byte-identical C++). Each arena is a
Expand Down Expand Up @@ -1375,6 +1562,11 @@ def generate(self) -> str:
else:
lines.append(f" bool _fvinit_{self._func_safe_name(fi.name)} = false;")

# 9a2. ``var`` init flags for fresh context-sensitive helper instances.
for inst in self._fresh_instances:
if inst["fname"] in self.ctx.func_var_members and inst["var_remap"]:
lines.append(f" bool _fvinit_{inst['name']} = false;")

# 9b. _ta_initialized_ flag for runtime TA re-sizing (first on_bar only).
if self.ctx.ta_call_sites:
lines.append(" bool _ta_initialized_ = false;")
Expand Down Expand Up @@ -1403,7 +1595,20 @@ def generate(self) -> str:
self._emit_func_def(fi, lines)
lines.append("")

# 10a. Fresh context-sensitive instances of nested stateful helpers
# (reached through >1 distinct call path). Each is bound to its own
# path-specific TA + var members; see _build_func_instances.
if self._fresh_instances:
fi_by_name = {fi.name: fi for fi in self.ctx.func_infos}
for inst in self._fresh_instances:
fi = fi_by_name.get(inst["fname"])
if fi is None:
continue
self._emit_func_def(fi, lines, instance=inst)
lines.append("")

# 11. on_bar()
self._current_instance_name = None
self._emit_on_bar(lines)
lines.append("")

Expand Down
Loading
Loading