diff --git a/pineforge_codegen/analyzer/base.py b/pineforge_codegen/analyzer/base.py index 4fe4e4c..43007e7 100644 --- a/pineforge_codegen/analyzer/base.py +++ b/pineforge_codegen/analyzer/base.py @@ -149,6 +149,19 @@ def __init__(self, ast: Program, filename: str = "") -> None: self._func_ta_ranges: dict[str, tuple[int, int]] = {} # func_name -> (start, end) indices self._func_call_site_count: dict[str, int] = {} # func_name -> count self._func_call_cs_map: dict[int, tuple[str, int]] = {} # call_node_id -> (func_name, cs_idx) + # Authoritative clone-name map: (func_name, cs_idx) -> {orig_member_name: + # cloned_member_name}. The codegen rebuilds its TA remap from the + # ``{orig}_cs{cs_idx}`` formula by default, but a TA site reached through + # MULTIPLE enclosing functions (e.g. a helper cloned both directly and via + # a range-widened outer function) can collide on that formula. When the + # analyzer must disambiguate a clone's member name, it records the actual + # chosen name here so the codegen consumes it verbatim instead of + # re-deriving a colliding name. Empty for the common (no-collision) case, + # keeping generated output byte-identical. + self._func_cs_ta_clone_names: dict[tuple[str, int], dict[str, str]] = {} + # All TA member names minted so far (base + clones), for O(1) collision + # detection when minting a new clone. + self._ta_member_names: set[str] = set() # UDT field definitions: type_name -> {field_name: PineType} self._udt_fields: dict[str, dict[str, PineType]] = {} # var_name -> UDT type for variables holding UDT instances @@ -164,6 +177,14 @@ def __init__(self, ast: Program, filename: str = "") -> None: self._current_top_level_stmt: ASTNode | None = None self._global_scope = True self._static_vars: set[str] = set() + # Stack of enclosing user-function param-name sets, pushed while visiting + # a FuncDef body. Lets a nested user-func call detect when it substitutes + # a TA ctor length with one of the OUTER function's params, so the outer + # call site can re-substitute (e.g. f_bbwp(_bbwLen) -> f_basisMa(_len)). + self._enclosing_func_params: list[set[str]] = [] + # Set of TA-site indices a nested user-func call rewrote in terms of the + # current enclosing function's params (None when not inside a FuncDef body). + self._nested_ta_touched: set | None = None # Pre-populate builtins self._populate_builtins() @@ -272,6 +293,7 @@ def analyze(self) -> AnalyzerContext: func_ta_ranges=self._func_ta_ranges, func_call_cs_map=self._func_call_cs_map, func_call_site_counts=self._func_call_site_count, + func_cs_ta_clone_names=self._func_cs_ta_clone_names, udt_defs=self._udt_fields, enum_defs=self._enum_defs, enum_member_strings=self._enum_member_strings, @@ -899,16 +921,28 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType: body_type = PineType.VOID old_global = self._global_scope self._global_scope = False + self._enclosing_func_params.append(set(node.params)) + self._nested_ta_touched = set() try: for stmt in node.body: body_type = self._visit(stmt) finally: self._global_scope = old_global - - # Record TA range for this function + self._enclosing_func_params.pop() + nested_touched = self._nested_ta_touched + self._nested_ta_touched = None + + # Record TA range for this function. Widen to cover any nested-callee TA + # sites whose ctor args were rewritten in terms of THIS function's params + # (e.g. f_basisMa's sites parameterized by f_bbwp's _bbwLen), so resolving + # this function at its call site re-substitutes those nested sites too. ta_end = len(self._ta_call_sites) - if ta_end > ta_start: - self._func_ta_ranges[node.name] = (ta_start, ta_end) + lo, hi = ta_start, ta_end + if nested_touched: + lo = min(lo, min(nested_touched)) + hi = max(hi, max(nested_touched) + 1) + if hi > lo: + self._func_ta_ranges[node.name] = (lo, hi) self._symbols.exit_scope() diff --git a/pineforge_codegen/analyzer/call_handlers.py b/pineforge_codegen/analyzer/call_handlers.py index 4833836..69629bc 100644 --- a/pineforge_codegen/analyzer/call_handlers.py +++ b/pineforge_codegen/analyzer/call_handlers.py @@ -66,8 +66,8 @@ from typing import Any from ..ast_nodes import ( - ASTNode, BoolLiteral, ExprStmt, FuncCall, Identifier, MemberAccess, - NumberLiteral, StringLiteral, TupleLiteral, + ASTNode, BinOp, BoolLiteral, ExprStmt, FuncCall, Identifier, MemberAccess, + NumberLiteral, StringLiteral, TupleLiteral, UnaryOp, VarDecl, ) from ..symbols import PineType from .. import signatures as sigs @@ -207,6 +207,7 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType: is_static=is_static, ) self._ta_call_sites.append(site) + self._ta_member_names.add(site.member_name) return PineType.FLOAT # Determine constructor args @@ -249,6 +250,7 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType: is_static=is_static, ) self._ta_call_sites.append(site) + self._ta_member_names.add(site.member_name) return PineType.FLOAT @@ -770,6 +772,57 @@ def _handle_fixnan_call(self, node: FuncCall) -> PineType: # User-defined function calls # ------------------------------------------------------------------ + def _func_local_length_defs(self, func_def) -> dict[str, str]: + """Collect a user function's local scalar length-vars to their RHS + expression string, e.g. ``qqeCalc`` with ``wp = sf * 2 - 1`` returns + ``{"wp": "sf * 2 - 1"}``. + + Only plain (non-``var``/``varip``) declarations whose RHS is a pure + arithmetic expression over identifiers/numbers (NumberLiteral, Identifier, + BinOp, UnaryOp, or a math.* FuncCall) qualify — these are the shapes that + can legitimately feed a TA constructor length. Series-valued locals (whose + RHS is a ta.* call, a subscript, a ternary, etc.) are skipped so we never + inline a price series into a ctor-length slot. Names reassigned with ``:=`` + are also skipped (their value is not a stable compile-time length). + """ + def _is_arith(n) -> bool: + if isinstance(n, (NumberLiteral, Identifier)): + return True + if isinstance(n, BinOp): + return _is_arith(n.left) and _is_arith(n.right) + if isinstance(n, UnaryOp): + return _is_arith(n.operand) + if isinstance(n, FuncCall): + # Allow math.* helpers (math.round/sqrt/...) over arith args. + callee = n.callee + if (isinstance(callee, MemberAccess) + and isinstance(callee.object, Identifier) + and callee.object.name == "math"): + return all(_is_arith(a) for a in n.args) + return False + + reassigned: set[str] = set() + def _scan_reassign(stmts): + from ..ast_nodes import Assignment + for s in stmts or []: + if isinstance(s, Assignment) and isinstance(s.target, Identifier): + reassigned.add(s.target.name) + for attr in ("body", "else_body"): + sub = getattr(s, attr, None) + if isinstance(sub, list): + _scan_reassign(sub) + _scan_reassign(func_def.body) + + defs: dict[str, str] = {} + for stmt in func_def.body or []: + if (isinstance(stmt, VarDecl) + and not stmt.is_var and not stmt.is_varip + and stmt.name not in reassigned + and stmt.value is not None + and _is_arith(stmt.value)): + defs[stmt.name] = self._expr_to_str(stmt.value) + return defs + def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType: """Handle calls to user-defined functions.""" func_def = self._func_defs[func_name] @@ -843,6 +896,16 @@ def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType: if has_ta: start, end = self._func_ta_ranges[func_name] + # Map of this function's local (non-param, non-series) derived + # length vars to their raw RHS expression strings, e.g. + # ``qqeCalc`` => ``wp = sf * 2 - 1`` -> {"wp": "sf * 2 - 1"}. + # A TA ctor arg captured as the bare local name ("wp") must be + # expanded to its definition so the subsequent param-substitution + # turns it into a class-scope expression ("rsiSmooth * 2 - 1") + # rather than leaving a dangling local that degenerates to + # period 1 in codegen. + local_defs = self._func_local_length_defs(func_def) + def _subst_params(arg: str, pmap: dict[str, str]) -> str: """Substitute parameter names in an expression string. @@ -856,23 +919,81 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str: result = re.sub(rf'\b{re.escape(param)}\b', value, result) return result + def _expand_locals(arg: str) -> str: + """Recursively expand function-local length vars to their RHS + (parenthesized) so only params / class-scope names remain.""" + import re + if not local_defs: + return arg + for _ in range(32): + def _rep(m: re.Match) -> str: + nm = m.group(0) + if nm in local_defs: + return "(" + local_defs[nm] + ")" + return nm + new = re.sub(r"[A-Za-z_][A-Za-z_0-9]*", _rep, arg) + if new == arg: + break + arg = new + return arg + + # Params of the function we are *currently inside* (if this is a + # nested user-func call). Used to detect when a substituted ctor + # arg becomes parameterized by the OUTER function, so the outer + # call site can resolve it (f_bbwp's _bbwLen -> i_bbwLen reaches + # f_basisMa's sites). + import re as _re + enclosing_params: set[str] = set() + for s in self._enclosing_func_params: + enclosing_params |= s + if cs_idx == 0: # First call site: save original param-based ctor_args for future cloning, # then resolve to actual call-site values for i in range(start, end): site = self._ta_call_sites[i] if not hasattr(site, '_orig_ctor_args'): - site._orig_ctor_args = site.ctor_args[:] + site._orig_ctor_args = [ + _expand_locals(a) for a in site.ctor_args + ] site.ctor_args = [_subst_params(a, param_arg_map) for a in site._orig_ctor_args] + # If a substituted arg is now expressed in terms of an + # enclosing function's params, promote it to the original + # so the enclosing call re-substitutes, and mark the site + # so the enclosing function's TA range widens to cover it. + if enclosing_params and self._nested_ta_touched is not None: + for a in site.ctor_args: + toks = set(_re.findall(r"[A-Za-z_][A-Za-z_0-9]*", a)) + if toks & enclosing_params: + site._orig_ctor_args = list(site.ctor_args) + self._nested_ta_touched.add(i) + break else: # Subsequent call sites: clone using saved original param names, # substituted with this call site's arguments + clone_name_map: dict[str, str] = {} for i in range(start, end): orig = self._ta_call_sites[i] orig_args = getattr(orig, '_orig_ctor_args', orig.ctor_args) resolved_ctor = [_subst_params(a, param_arg_map) for a in orig_args] + # Default name follows the ``{base}_cs{cs_idx}`` formula the + # codegen re-derives. But the SAME base TA site can be reached + # through more than one enclosing function (e.g. a helper cloned + # both via its own call sites AND via a range-widened outer + # function), so two distinct (func, cs_idx) namespaces can mint + # the same name. Detect that collision and fall back to a + # globally-unique name; record the chosen name so the codegen + # consumes it verbatim (see _func_cs_ta_clone_names). + clone_name = f"{orig.member_name}_cs{cs_idx}" + if clone_name in self._ta_member_names: + base = clone_name + n = 2 + while clone_name in self._ta_member_names: + clone_name = f"{base}_u{n}" + n += 1 + clone_name_map[orig.member_name] = clone_name cloned = TACallSite( - member_name=f"{orig.member_name}_cs{cs_idx}", + member_name=clone_name, class_name=orig.class_name, ctor_args=resolved_ctor, compute_args=orig.compute_args[:], @@ -881,6 +1002,9 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str: is_static=orig.is_static, ) self._ta_call_sites.append(cloned) + self._ta_member_names.add(clone_name) + if clone_name_map: + self._func_cs_ta_clone_names[(func_name, cs_idx)] = clone_name_map # Create or update FuncInfo is_tuple = self._func_returns_tuple.get(func_name, False) diff --git a/pineforge_codegen/analyzer/contracts.py b/pineforge_codegen/analyzer/contracts.py index 9752cee..22dfaa9 100644 --- a/pineforge_codegen/analyzer/contracts.py +++ b/pineforge_codegen/analyzer/contracts.py @@ -156,6 +156,11 @@ class AnalyzerContext: func_ta_ranges: dict = field(default_factory=dict) # func_name -> (start_idx, end_idx) func_call_cs_map: dict = field(default_factory=dict) # call_node_id -> (func_name, call_site_index) func_call_site_counts: dict = field(default_factory=dict) # func_name -> int + # (func_name, cs_idx) -> {orig_member_name: cloned_member_name}. Populated by + # the analyzer ONLY for clones whose default ``{base}_cs{cs_idx}`` name would + # collide with a clone minted through another enclosing function; lets codegen + # use the disambiguated name instead of re-deriving a colliding one. + func_cs_ta_clone_names: dict = field(default_factory=dict) # UDT / enum definitions: udt_defs: dict = field(default_factory=dict) # type_name -> {field_name: PineType} enum_defs: dict = field(default_factory=dict) # enum_name -> [member names] diff --git a/pineforge_codegen/codegen/base.py b/pineforge_codegen/codegen/base.py index e2bc854..d5d3538 100644 --- a/pineforge_codegen/codegen/base.py +++ b/pineforge_codegen/codegen/base.py @@ -274,12 +274,21 @@ def __init__(self, ctx: AnalyzerContext) -> None: # Build cloned site remapping for cs > 0 (must happen before _ta_site_map # so cloned names are in _func_ta_members and get filtered out of the map) + # + # Default to the ``{orig}_cs{cs_idx}`` formula (matches the analyzer's clone + # naming), but defer to the analyzer's authoritative clone-name map for any + # site it had to disambiguate (a TA site reached through multiple enclosing + # functions would otherwise collide on the formula). Keeping the formula as + # the default leaves all non-colliding output byte-identical. + clone_names = getattr(ctx, "func_cs_ta_clone_names", {}) for fname, orig_names in func_ta_originals.items(): total_cs = ctx.func_call_site_counts.get(fname, 1) for cs_idx in range(1, total_cs): + overrides = clone_names.get((fname, cs_idx), {}) remap = {} for orig_name in orig_names: - remap[orig_name] = f"{orig_name}_cs{cs_idx}" + remap[orig_name] = overrides.get( + orig_name, f"{orig_name}_cs{cs_idx}") self._func_cs_ta_remap[(fname, cs_idx)] = remap self._func_ta_members.update(remap.values()) @@ -322,6 +331,12 @@ def __init__(self, ctx: AnalyzerContext) -> None: # Map input-backed var name -> its input.*() FuncCall node so we can # later emit a runtime get_input_*() read with the same title/default. self._input_var_to_call: dict[str, FuncCall] = {} + # Class-scope arithmetic-over-input vars (e.g. ``wilderLen = rsiLen*2-1``). + # Maps the derived var name -> its raw RHS expression string. When such a + # var feeds a TA ctor length, the runtime-reset path expands it so input + # overrides propagate (``(get_input_int("RSI Length",14) * 2 - 1)``); the + # ctor-init list still folds to the Pine-default literal via _resolve_known. + self._derived_input_expr: dict[str, str] = {} self._timeframe_period_vars: set[str] = set() self._collect_known_vars() # Track var names @@ -778,6 +793,45 @@ def walk(node): walk(stmt) return reassigned + def _arith_expr_to_str(self, node) -> str | None: + """Render a numeric arithmetic-over-identifiers expression to a string + whose token spelling matches ``_resolve_known`` / ``_runtime_ctor_arg_for_reset`` + (e.g. ``rsiLen * 2 - 1``). Returns None for any node shape we don't fold + (series subscripts, ternaries, etc.) so the caller leaves the var untracked. + """ + if isinstance(node, NumberLiteral): + v = node.value + if isinstance(v, float) and v == int(v): + return str(int(v)) + return str(v) + if isinstance(node, Identifier): + return node.name + if isinstance(node, MemberAccess) and isinstance(node.object, Identifier): + return f"{node.object.name}.{node.member}" + if isinstance(node, BinOp): + l = self._arith_expr_to_str(node.left) + r = self._arith_expr_to_str(node.right) + if l is None or r is None: + return None + return f"{l} {node.op} {r}" + if isinstance(node, UnaryOp): + o = self._arith_expr_to_str(node.operand) + if o is None: + return None + return f"{node.op}{o}" + if isinstance(node, FuncCall): + callee = self._arith_expr_to_str(node.callee) + if callee is None: + return None + parts = [] + for a in node.args: + s = self._arith_expr_to_str(a) + if s is None: + return None + parts.append(s) + return f"{callee}({', '.join(parts)})" + return None + def _collect_known_var(self, node: VarDecl) -> None: """Extract known constant value from a VarDecl.""" # Don't inline series variables — their values change over time @@ -830,6 +884,36 @@ def _collect_known_var(self, node: VarDecl) -> None: if stored: self._input_backed_vars.add(node.name) self._input_var_to_call[node.name] = node.value + # Class-scope arithmetic over known/input-backed vars + # (``wilderLen = rsiLen * 2 - 1``, ``n = math.round(len / 2)``). + # Without this branch the derived name is untracked, the TA ctor arg + # never folds, and the runtime-reset path silently degenerates to a + # period of 1. We (a) fold to a literal for the ctor-init list and + # (b) record the raw expression so the reset path can re-expand any + # input-backed operand to its get_input_*() runtime read. + elif isinstance(node.value, (BinOp, UnaryOp, FuncCall)): + expr_str = self._arith_expr_to_str(node.value) + if expr_str is not None: + import re as _re + tokens = set(_re.findall(r"[A-Za-z_][A-Za-z_0-9]*", expr_str)) + refs_known = any(t in self._known_vars for t in tokens) + refs_input = any(t in self._input_backed_vars for t in tokens) + refs_derived = any(t in self._derived_input_expr for t in tokens) + if refs_known or refs_input or refs_derived: + folded = self._resolve_known(expr_str) + if self._is_compile_time_value(folded): + try: + num = float(folded) + self._known_vars[node.name] = ( + int(num) if num == int(num) else num + ) + except ValueError: + pass + if refs_input or refs_derived: + # Track as input-backed so use-sites are not inlined and + # the runtime-reset path emits the override-aware expr. + self._derived_input_expr[node.name] = expr_str + self._input_backed_vars.add(node.name) # ------------------------------------------------------------------ # Public entry point @@ -1469,6 +1553,24 @@ def _runtime_ctor_arg_for_reset(self, arg_str: str) -> str | None: """ import re ident_re = re.compile(r"[A-Za-z_][A-Za-z_0-9]*") + + # Expand class-scope derived vars (``wilderLen`` -> ``(rsiLen * 2 - 1)``) + # to their raw RHS so the input-backed leaves become get_input_*() reads + # below. Recursive (bounded) to handle chains of derived vars; guards + # against cycles. + def _expand_derived(s: str, seen: frozenset = frozenset(), depth: int = 0) -> str: + if depth > 32: + return s + def _rep(m: re.Match) -> str: + nm = m.group(0) + if nm in self._derived_input_expr and nm not in seen: + inner = self._derived_input_expr[nm] + return "(" + _expand_derived(inner, seen | {nm}, depth + 1) + ")" + return nm + return ident_re.sub(_rep, s) + + arg_str = _expand_derived(arg_str) + tokens = ident_re.findall(arg_str) input_tokens = [t for t in tokens if t in self._input_backed_vars] if not input_tokens: diff --git a/pineforge_codegen/codegen/emit_top.py b/pineforge_codegen/codegen/emit_top.py index 6109d8e..1ae29d2 100644 --- a/pineforge_codegen/codegen/emit_top.py +++ b/pineforge_codegen/codegen/emit_top.py @@ -214,9 +214,29 @@ def _emit_constructor(self, lines: list[str]) -> None: # TA members with ctor args for site in self.ctx.ta_call_sites: if site.ctor_args: + # If a ctor arg is neither a compile-time literal nor expandable + # to an input-backed runtime expression, the old code silently + # emitted period 1 with no overwriting reset — a wrong indicator + # masquerading as a working one. Refuse loudly instead. Args that + # DO expand to a runtime expr (input-backed / arithmetic-over-input, + # incl. function-derived lengths) are safe: the `!_ta_initialized_` + # reset overwrites the placeholder before the first compute. + for a in site.ctor_args: + r = self._resolve_known(a) + if (not self._is_compile_time_value(r) + and self._runtime_ctor_arg_for_reset(a) is None): + self._codegen_error( + getattr(site, "node", None), + f"Unsupported TA constructor length '{a}' for " + f"{site.class_name}: it is neither a compile-time " + f"constant nor derived from an input, so PineForge " + f"cannot size the indicator buffer.", + hint=("Use a literal, an input.*() value, or " + "arithmetic over those for TA lengths."), + ) resolved = [self._resolve_known(a) for a in site.ctor_args] - # If any ctor arg isn't a compile-time value, use default 1 - # (TA in user functions with runtime params) + # Compile-time placeholder for the init list; the runtime reset + # (when the arg is input-derived) overwrites it on the first bar. safe_resolved = [] for r in resolved: if self._is_compile_time_value(r): diff --git a/pineforge_codegen/codegen/security.py b/pineforge_codegen/codegen/security.py index 59ac866..8531a21 100644 --- a/pineforge_codegen/codegen/security.py +++ b/pineforge_codegen/codegen/security.py @@ -98,6 +98,8 @@ def _resolve_security_tf(self, tf_node, containing_func: str): if (name in self._known_vars and name not in self._input_backed_vars and isinstance(self._known_vars[name], str)): return self._known_vars[name], None + if name in self._input_backed_vars and name in self._input_var_to_call: + return None, self._visit_expr(self._input_var_to_call[name]) # class-scope resolvable (global / input member)? if self._ident_is_resolvable(name): try: diff --git a/tests/test_codegen_ta_derived_length.py b/tests/test_codegen_ta_derived_length.py new file mode 100644 index 0000000..01cc848 --- /dev/null +++ b/tests/test_codegen_ta_derived_length.py @@ -0,0 +1,183 @@ +"""Regression tests for TA constructor lengths derived from arithmetic over +inputs and from function-local / parameter expressions. + +Bug history: when a ``ta.*`` length argument was not a literal or a direct +input alias — i.e. arithmetic over an input/const (``wilderLen = rsiLen*2-1``) +or a function-local / parameter-derived length (``wp = sf*2-1``; a helper +``f(src,_len) => ta.sma(src,_len)``) — the transpiler silently emitted a TA +constructor period of 1 with no overwriting runtime reset. The smoother then +degenerated to a no-op, producing a wrong indicator and wrong signals. + +These tests pin the three faithful behaviors: + 1. class-scope arithmetic over an input folds the ctor-init to the literal + AND emits an override-aware runtime reset; + 2. function-local / parameter-derived lengths (including a length threaded + through a nested user-function call) resolve to the real input; + 3. a legitimate input that genuinely defaults to 1 stays period 1; +and the guardrail: a genuinely-unresolvable computed length raises instead of +silently emitting period 1. +""" + +import re + +import pytest + +from pineforge_codegen import transpile +from pineforge_codegen.errors import CompileError + + +def _ctor_period(cpp: str, member: str) -> str: + """Return the ctor-init period for a TA member, e.g. ``_ta_ema_9`` -> '27'.""" + m = re.search(rf"\b{re.escape(member)}\(([^)]*)\)", cpp) + assert m, f"{member} not found in initializer list" + return m.group(1).strip() + + +def _reset_line(cpp: str, member: str) -> str: + """Return the runtime-reset assignment for a TA member (the line that + overwrites the ctor placeholder under ``!_ta_initialized_``).""" + for ln in cpp.splitlines(): + s = ln.strip() + if s.startswith(f"{member} = ") and s.endswith(";"): + return s + return "" + + +# --------------------------------------------------------------------------- +# 1. Class-scope arithmetic over an input (wilderLen = rsiLen * 2 - 1) +# --------------------------------------------------------------------------- + +def test_class_scope_arithmetic_over_input_length(): + src = """//@version=6 +strategy("derived-class-scope") +rsiLen = input.int(14, "RSI Length") +wilderLen = rsiLen * 2 - 1 +src = ta.ema(close, wilderLen) +plot(src) +""" + cpp = transpile(src) + # Find the EMA member sized by wilderLen. + members = re.findall(r"(_ta_ema_\d+)\(", cpp) + assert members, "no EMA member emitted" + # The wilderLen-sized EMA folds to 14*2-1 = 27 in the init list. + sized = [m for m in members if _ctor_period(cpp, m) == "27"] + assert sized, f"expected an EMA with ctor period 27, got {[(_m, _ctor_period(cpp, _m)) for _m in members]}" + member = sized[0] + # And the runtime reset re-derives it from the (possibly overridden) input. + reset = _reset_line(cpp, member) + assert 'get_input_int("RSI Length", 14)' in reset + assert "* 2 - 1" in reset + assert "ta::EMA(1)" not in reset # must NOT be the silent no-op + + +# --------------------------------------------------------------------------- +# 2a. Function-local derived length (qqeCalc: wp = sf * 2 - 1) +# --------------------------------------------------------------------------- + +def test_function_local_derived_length(): + src = """//@version=6 +strategy("derived-func-local") +rsiSmooth = input.int(5, "Smooth EMA Length") +qqeCalc(int sf) => + wp = sf * 2 - 1 + ta.ema(ta.ema(close, wp), wp) +out = qqeCalc(rsiSmooth) +plot(out) +""" + cpp = transpile(src) + members = re.findall(r"(_ta_ema_\d+)\(", cpp) + # sf = 5 -> wp = 9. The wp-sized EMAs must be period 9, not 1. + sized = [m for m in members if _ctor_period(cpp, m) == "9"] + assert len(sized) >= 2, ( + f"expected >=2 EMAs with ctor period 9, got " + f"{[(_m, _ctor_period(cpp, _m)) for _m in members]}" + ) + for m in sized: + reset = _reset_line(cpp, m) + assert 'get_input_int("Smooth EMA Length", 5)' in reset + assert "* 2 - 1" in reset + assert "ta::EMA(1)" not in cpp + + +# --------------------------------------------------------------------------- +# 2b. Length threaded through a NESTED user-function call +# f_basisMa(src, _len) => ta.sma(src, _len) called from f_bbwp(_bbwLen) +# --------------------------------------------------------------------------- + +def test_nested_user_func_param_length(): + src = """//@version=6 +strategy("derived-nested-param") +i_bbwLen = input.int(7, "BBW Basis Length") +f_basisMa(float _src, int _len) => + ta.sma(_src, _len) +f_bbwp(float _price, int _bbwLen) => + f_basisMa(_price, _bbwLen) + ta.stdev(_price, _bbwLen) +out = f_bbwp(close, i_bbwLen) +plot(out) +""" + cpp = transpile(src) + # The SMA inside f_basisMa must be sized by the real input (7), not 1. + sma_members = re.findall(r"(_ta_sma_\d+)\(", cpp) + assert sma_members + sized = [m for m in sma_members if _ctor_period(cpp, m) == "7"] + assert sized, ( + f"nested SMA length not resolved; got " + f"{[(_m, _ctor_period(cpp, _m)) for _m in sma_members]}" + ) + reset = _reset_line(cpp, sized[0]) + assert 'get_input_int("BBW Basis Length", 7)' in reset + assert "ta::SMA(1)" not in cpp + + +# --------------------------------------------------------------------------- +# 3. Legitimate input that genuinely defaults to 1 stays period 1 +# --------------------------------------------------------------------------- + +def test_legit_input_default_one_preserved(): + src = """//@version=6 +strategy("legit-default-one") +atrLen = input.int(1, "UT Bot ATR Period") +a = ta.atr(atrLen) +plot(a) +""" + cpp = transpile(src) + members = re.findall(r"(_ta_atr_\d+)\(", cpp) + assert members + m = members[0] + assert _ctor_period(cpp, m) == "1" + # The reset is still emitted so an override (e.g. set ATR Period = 10) + # re-sizes the buffer — the period 1 here is the genuine Pine default. + reset = _reset_line(cpp, m) + assert 'get_input_int("UT Bot ATR Period", 1)' in reset + + +# --------------------------------------------------------------------------- +# Guardrail: a genuinely-unresolvable computed length raises (no silent 1). +# --------------------------------------------------------------------------- + +def test_unresolvable_length_raises_loudly(): + # ``barsSince(...)`` is a runtime series, not a const or input — there is + # no faithful compile-time buffer size and no input to re-derive from. + src = """//@version=6 +strategy("unresolvable-length") +n = ta.barssince(close > open) +v = ta.ema(close, n) +plot(v) +""" + with pytest.raises(CompileError): + transpile(src) + + +# --------------------------------------------------------------------------- +# Determinism: transpiling the same source twice is byte-identical. +# --------------------------------------------------------------------------- + +def test_derived_length_transpile_is_deterministic(): + src = """//@version=6 +strategy("determinism") +rsiLen = input.int(14, "RSI Length") +wilderLen = rsiLen * 2 - 1 +x = ta.ema(close, wilderLen) +plot(x) +""" + assert transpile(src) == transpile(src)