Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 157 additions & 8 deletions pineforge_codegen/analyzer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(self, ast: Program, filename: str = "<stdin>") -> None:
# expression (``=> Sample.new(...)`` or last stmt ``Sample.new(...)``).
# Probe: data/validation/udt-method-probe-20-udt-return-from-func.
self._func_udt_return_types: dict[str, str] = {}
self._func_return_type_specs: dict[str, "TypeSpec"] = {}
# Per-function var_members and series_vars (for call-site cloning)
self._func_var_members: dict[str, list] = {} # func_name -> [(name, PineType, init_str)]
self._func_series_vars: dict[str, set] = {} # func_name -> set[str]
Expand Down Expand Up @@ -232,6 +233,11 @@ def analyze(self) -> AnalyzerContext:
# can call g_csK with isolated state.
self._propagate_call_site_counts()

# Reject (loudly) a request.security whose timeframe is a UDF parameter
# called with multiple distinct literal timeframes — a single evaluator
# cannot serve them and per-callsite specialization is not yet wired.
self._check_mixed_callsite_security_tf()

# Keep only truly pure global expressions for request.security rebinding.
# Globals later reassigned with := become series/stateful variables and
# must not be rebound to their declaration-time initializer.
Expand Down Expand Up @@ -273,6 +279,7 @@ def analyze(self) -> AnalyzerContext:
global_mutable_infos=mutable_global_infos,
func_var_members=self._func_var_members,
func_series_vars=self._func_series_vars,
func_return_type_specs=dict(self._func_return_type_specs),
udt_var_types=dict(self._udt_var_types),
collection_types=dict(self._collection_types),
udt_field_type_specs=dict(self._udt_field_type_specs),
Expand Down Expand Up @@ -418,6 +425,98 @@ def _find_calls(node, known_funcs: set[str]) -> set[str]:
self._func_call_site_count[sub] = count
changed = True

# ------------------------------------------------------------------
# Mixed-callsite UDF timeframe-param security rejection.
#
# A ``request.security`` whose ``timeframe`` is a parameter of its
# containing UDF maps to ONE evaluator regardless of how many times the
# UDF is called. When the UDF is called from >= 2 sites with DISTINCT
# literal timeframes, a single evaluator cannot faithfully serve them
# all and the resolver would silently collapse onto the chart timeframe
# (``input_tf_``). Per-callsite evaluator specialization (cloning the
# evaluator + UDF) is the correct fix but is not wired in this iteration,
# so we reject deterministically instead of emitting wrong semantics.
# ------------------------------------------------------------------
def _check_mixed_callsite_security_tf(self) -> None:
sec_calls = getattr(self, "_security_calls", None)
if not sec_calls:
return
# Build user-function definitions lookup once.
func_defs: dict[str, FuncDef] = {}
for stmt in self._ast.body:
if isinstance(stmt, FuncDef):
func_defs[stmt.name] = stmt

for sec in sec_calls:
containing = getattr(sec, "containing_func", "") or ""
if not containing:
continue
tf_node = getattr(sec, "timeframe", None)
if not isinstance(tf_node, Identifier):
continue
param_name = tf_node.name
fdef = func_defs.get(containing)
if fdef is None or param_name not in fdef.params:
continue
pidx = fdef.params.index(param_name)
literals: set[str] = set()
found_call = False
for call in self._iter_user_func_calls(containing):
found_call = True
arg = call.args[pidx] if pidx < len(call.args) else None
lit = self._callsite_tf_literal_value(arg)
if lit is not None:
literals.add(lit)
if not found_call:
continue # dead code — evaluator result never read
if len(literals) >= 2:
self._error(
"request.security timeframe parameter '"
+ param_name
+ "' of function '"
+ containing
+ "' is called with multiple distinct literal timeframes ("
+ ", ".join(sorted(literals))
+ "). A single request.security evaluator cannot serve "
"them all and would silently collapse onto the chart "
"timeframe. Pass a single timeframe, or inline a separate "
"request.security call at each call site.",
tf_node.loc,
)

def _iter_user_func_calls(self, func_name: str):
"""Yield every ``func_name(...)`` call anywhere in the AST (top-level
and nested inside function bodies)."""
def _walk(node):
if node is None:
return
if (isinstance(node, FuncCall) and isinstance(node.callee, Identifier)
and node.callee.name == func_name):
yield node
for attr_val in vars(node).values():
if isinstance(attr_val, list):
for item in attr_val:
if hasattr(item, "__dict__"):
yield from _walk(item)
elif attr_val is not None and hasattr(attr_val, "__dict__"):
yield from _walk(attr_val)
yield from _walk(self._ast)

def _callsite_tf_literal_value(self, arg) -> str | None:
"""Resolve a UDF call-site timeframe argument to a literal string
value when it is statically known: a string literal, or a known
constant / input-backed variable whose stored value is a string.
Returns None for anything that is not a compile-time string."""
if isinstance(arg, StringLiteral):
return arg.value
if isinstance(arg, Identifier):
sym = self._symbols.resolve(arg.name)
if sym is not None and getattr(sym, "const_value", None) is not None:
val = sym.const_value
if isinstance(val, str):
return val
return None

def _is_static_expression(self, node: ASTNode | None) -> bool:
if node is None:
return True
Expand Down Expand Up @@ -535,19 +634,28 @@ def _visit_ImportStmt(self, node: ImportStmt) -> PineType:
# ------------------------------------------------------------------

def _udt_name_from_ctor(self, value: ASTNode) -> str | None:
"""If value is ``TypeName.new(...)`` for a user-defined type, return TypeName."""
"""If value is ``TypeName.new(...)`` for a user-defined type OR a
drawing handle (``label.new``/``line.new``/``box.new``/``linefill.new``),
return the type name."""
if not isinstance(value, FuncCall):
return None
cal = value.callee
if not isinstance(cal, MemberAccess) or not isinstance(cal.object, Identifier):
return None
owner = cal.object.name
if owner not in self._udt_fields:
return None
m = cal.member
if m == "new" or (isinstance(m, str) and m.startswith("new")):
if not (m == "new" or (isinstance(m, str) and m.startswith("new"))):
return None
# Drawing-objects-as-data: label.new(...)/line.new(...)/... return a
# handle of the self-type. These are not in _udt_fields (they are not
# user UDTs) but must still be recognised so a function whose body ends
# in label.new(...) emits a ``Label`` (not ``double``) return type.
from .types import _DRAWING_TYPE_NAMES
if owner in _DRAWING_TYPE_NAMES:
return owner
return None
if owner not in self._udt_fields:
return None
return owner

def _visit_VarDecl(self, node: VarDecl) -> PineType:
# Infer type from the value expression
Expand Down Expand Up @@ -732,6 +840,20 @@ def _visit_TupleAssign(self, node: TupleAssign) -> PineType:
setattr(sym, "is_static_series", True)
self._symbols.define(sym)

# Track global-scope tuple-assign targets (e.g.
# ``[pdH, pdL] = request.security(...)``) as class members so user
# functions / later references resolve — mirroring _visit_VarDecl.
# Without this the names are never declared and the C++ errors with
# "use of undeclared identifier".
if (self._global_scope
and self._symbols.current_scope.name == "global"
and name not in self._series_vars):
self._global_var_decls.append((name, PineType.FLOAT))
self._global_expr_map[name] = node.value
self._record_global_binding_stmt(
name, PineType.FLOAT, False, decl_node=node,
)

return val_type

# ------------------------------------------------------------------
Expand All @@ -745,18 +867,28 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType:
# Enter function scope
self._symbols.enter_scope(f"func_{node.name}")

# Define parameters (type unknown until called)
# Define parameters. The type is UNKNOWN until inferred from a call
# site, BUT a declared type hint (``string tf``, ``pivot hi``, ``line[] arr``)
# is authoritative — record it as the symbol's ``type_spec`` / ``pine_type``
# so (a) the param emits with the right C++ type and (b) callers passing
# this param into another function can infer that function's param type
# (e.g. ``getLineStyle(styleStr)`` where ``styleStr`` is a ``string`` param).
loc = node.loc or SourceLocation(file=self._filename, line=1, col=1, end_col=1)
for param in node.params:
param_hints = (node.annotations or {}).get("param_type_hints", [])
for i, param in enumerate(node.params):
hint = param_hints[i] if i < len(param_hints) else None
pspec = self._type_spec_from_hint(hint) if hint else None
ptype = self._type_hint_to_pine(hint) if hint else PineType.UNKNOWN
sym = Symbol(
name=param,
pine_type=PineType.UNKNOWN,
pine_type=ptype,
is_series=False,
is_var=False,
is_const=False,
const_value=None,
scope=f"func_{node.name}",
loc=loc,
type_spec=pspec,
)
self._symbols.define(sym)

Expand Down Expand Up @@ -809,6 +941,14 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType:
udt_ret = self._udt_name_from_ctor(ret_expr) if ret_expr is not None else None
if udt_ret is not None:
self._func_udt_return_types[node.name] = udt_ret
# Array-return inference: a function whose body ends in
# ``array.from(...)`` / ``array.new<T>(...)`` / a UDT method
# returning an array returns a ``std::vector<...>``. The coarse
# PineType return can't represent this, so carry the TypeSpec.
if ret_expr is not None:
ret_spec = self._type_spec_from_expr(ret_expr)
if ret_spec is not None and ret_spec.kind == "array":
self._func_return_type_specs[node.name] = ret_spec

# Store return type
self._func_return_types[node.name] = body_type
Expand Down Expand Up @@ -880,12 +1020,14 @@ def _visit_MethodDef(self, node) -> PineType:
loc = node.loc or SourceLocation(file=self._filename, line=1, col=1, end_col=1)
param_hints = (node.annotations or {}).get("param_type_hints", [])
param_types: list[PineType] = []
param_specs: list = []
for i, p in enumerate(node.params):
udt_self = node.type_name if i == 0 else None
hint = param_hints[i] if i < len(param_hints) else None
ptype = self._type_hint_to_pine(hint) if hint else PineType.FLOAT
pspec = self._type_spec_from_hint(hint) if hint else None
param_types.append(ptype)
param_specs.append(pspec)
self._symbols.define(Symbol(
name=p, pine_type=ptype, is_series=False,
is_var=False, is_const=False, const_value=None,
Expand Down Expand Up @@ -939,6 +1081,7 @@ def _visit_MethodDef(self, node) -> PineType:
returns_tuple=returns_tuple,
tuple_element_count=tuple_element_count,
param_defaults=param_defaults,
param_type_specs=param_specs,
)
self._func_infos.append(fi)
return PineType.VOID
Expand Down Expand Up @@ -1151,6 +1294,12 @@ def _visit_FuncCall(self, node: FuncCall) -> PineType:
if isinstance(obj, Identifier) and obj.name == "str":
for arg in node.args:
self._visit(arg)
# Most str.* return a string, but a few don't:
# str.tonumber -> float, str.length -> int
if member == "tonumber":
return PineType.FLOAT
if member == "length":
return PineType.INT
return PineType.STRING

# request.* calls
Expand Down
56 changes: 48 additions & 8 deletions pineforge_codegen/analyzer/call_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from typing import Any

from ..ast_nodes import (
ASTNode, BoolLiteral, FuncCall, Identifier, MemberAccess,
ASTNode, BoolLiteral, ExprStmt, FuncCall, Identifier, MemberAccess,
NumberLiteral, StringLiteral, TupleLiteral,
)
from ..symbols import PineType
Expand Down Expand Up @@ -293,6 +293,10 @@ def _handle_request_call(self, func_name: str, node: FuncCall) -> PineType:
lookahead_node = all_args[4] if len(all_args) > 4 else None

mutable_globals = tuple(sorted(self._collect_security_mutable_globals(expr_node)))
# Capture the user function (if any) whose body contains this call,
# so the codegen can resolve a parameter ``tf`` via the call sites.
scope_name = self._symbols.current_scope.name
containing_func = scope_name[5:] if scope_name.startswith("func_") else ""
self._security_calls.append(SecurityCallInfo(
sec_id=sec_id,
timeframe=tf_node,
Expand All @@ -304,6 +308,7 @@ def _handle_request_call(self, func_name: str, node: FuncCall) -> PineType:
ta_range=security_ta_range,
depends_on_mutable_globals=bool(mutable_globals),
mutable_globals=mutable_globals,
containing_func=containing_func,
))

return PineType.FLOAT
Expand Down Expand Up @@ -783,14 +788,28 @@ def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType:
# For now, use the cached return type from initial analysis
return_type = self._func_return_types.get(func_name, PineType.FLOAT)

# If the return type was UNKNOWN or VOID, infer from param types
# If the return type was UNKNOWN or VOID, infer it ONLY when the body
# is a single bare identifier that returns a parameter directly
# (``f(s) => s``). Inferring from params for arbitrary bodies misfires
# when a function merely HAS a string param but returns something else
# (e.g. ``getLineStyle(s) => switch s ... => line.style_solid`` or a
# body ending in ``label.new(...)``). Other cases rely on the cached
# body type plus udt_return_type / tuple inference.
if return_type in (PineType.UNKNOWN, PineType.VOID):
if any(t == PineType.STRING for t in param_types):
return_type = PineType.STRING
elif any(t == PineType.FLOAT for t in param_types):
return_type = PineType.FLOAT
elif any(t == PineType.INT for t in param_types):
return_type = PineType.INT
if (func_def.is_single_expr and func_def.body
and isinstance(func_def.body[0], ExprStmt)
and isinstance(func_def.body[0].expr, Identifier)):
ret_name = func_def.body[0].expr.name
for idx, pname in enumerate(func_def.params):
if pname == ret_name and idx < len(param_types):
pt = param_types[idx]
if pt == PineType.STRING:
return_type = PineType.STRING
elif pt == PineType.INT:
return_type = PineType.INT
elif pt == PineType.FLOAT:
return_type = PineType.FLOAT
break

# If this function has series params, ensure bar-field arguments
# passed at the call site are registered as series_bar_fields so that
Expand Down Expand Up @@ -869,6 +888,15 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str:
# Forward UDT-return inference (set in _visit_FuncDef) so codegen can
# emit the struct return type. Probe: udt-method-probe-20.
udt_ret = self._func_udt_return_types.get(func_name)
ret_spec = getattr(self, "_func_return_type_specs", {}).get(func_name)
# Per-param TypeSpec: declared hints are authoritative; for untyped
# params, infer from the call-site argument's type_spec (so an untyped
# ``s`` used as a string, or a UDT passed by value, emits correctly).
param_specs = self._param_type_specs_from_def(func_def)
arg_specs = [self._type_spec_from_expr(arg) for arg in node.args]
for i in range(len(param_specs)):
if param_specs[i] is None and i < len(arg_specs):
param_specs[i] = arg_specs[i]
existing = [fi for fi in self._func_infos if fi.name == func_name]
if not existing:
fi = FuncInfo(
Expand All @@ -879,6 +907,8 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str:
returns_tuple=is_tuple,
tuple_element_count=tuple_count,
udt_return_type=udt_ret,
param_type_specs=param_specs,
return_type_spec=ret_spec,
)
self._func_infos.append(fi)
else:
Expand All @@ -889,7 +919,17 @@ def _subst_params(arg: str, pmap: dict[str, str]) -> str:
for i, pt in enumerate(param_types):
if i < len(fi.param_types) and fi.param_types[i] == PineType.UNKNOWN:
fi.param_types[i] = pt
# Merge per-param TypeSpecs: keep declared hints (authoritative),
# fill untyped slots from this call site if still unknown.
if not fi.param_type_specs:
fi.param_type_specs = list(param_specs)
else:
for i in range(len(param_specs)):
if i < len(fi.param_type_specs) and fi.param_type_specs[i] is None:
fi.param_type_specs[i] = param_specs[i]
if fi.udt_return_type is None and udt_ret is not None:
fi.udt_return_type = udt_ret
if fi.return_type_spec is None and ret_spec is not None:
fi.return_type_spec = ret_spec

return return_type
Loading
Loading