diff --git a/lib/spitfire.ex b/lib/spitfire.ex index be54d5c..d28c027 100644 --- a/lib/spitfire.ex +++ b/lib/spitfire.ex @@ -1283,52 +1283,149 @@ defmodule Spitfire do meta end - lhs = - case lhs do - {:__block__, _, []} -> - [] + {lhs, parser} = normalize_stab_lhs(lhs, parser) - {:__block__, [{:parens, _} | _], [[{key, _} | _] = kw]} when is_atom(key) -> - [kw] + ast = + {token, meta, [lhs, rhs]} - {:__block__, [{:parens, _} | _], [[{{_, _, _}, _} | _] = kw]} -> - [kw] + parser = Map.put(parser, :nesting, old_nesting) - {:comma, _, lhs} -> - lhs + {ast, eat_eoe(parser)} + end + end + end - {:when, [{:parens, _} | when_meta], when_args} -> - [{:when, when_meta, when_args}] + # Normalize stab clause LHS forms into the shape expected by `->` construction. + # When parsing fn heads, this also performs grammar-level validation for + # disallowed nested parentheses while keeping recoverable AST output. + defp normalize_stab_lhs(lhs, parser) do + in_fn_head_context? = Map.get(parser, :fn_head_context?, false) - {:__block__, [{:parens, _} = paren_meta | _], exprs} -> - case exprs do - [[{key, _} | _] = kw] when is_atom(key) -> - [{:__block__, [paren_meta], [kw]}] + case lhs do + {:__block__, _, []} -> + {[], parser} - [[{{_, _, _}, _} | _] = kw] -> - [{:__block__, [paren_meta], [kw]}] + {:comma, comma_meta, comma_args} -> + parser = + if in_fn_head_context? do + parser + |> maybe_error_nested_top_parens(comma_meta) + |> maybe_error_invalid_fn_head_args(comma_args) + else + parser + end - [expr] -> - [{:__block__, [paren_meta], [expr]}] + {comma_args, parser} - _ -> - lhs - end + {:when, [{:parens, _} | when_meta], when_args} -> + {[{:when, when_meta, when_args}], parser} - lhs -> - [lhs] - end + {:__block__, [{:parens, _} = paren_meta | _], exprs} -> + case {parenthesized_kw_head(lhs), exprs} do + {{:ok, block_meta, _paren_meta, kw}, _} -> + parser = + if in_fn_head_context? do + maybe_error_nested_top_parens(parser, block_meta) + else + parser + end - ast = - {token, meta, [lhs, rhs]} + {[kw], parser} - parser = Map.put(parser, :nesting, old_nesting) + {:error, [expr]} -> + {[{:__block__, [paren_meta], [expr]}], parser} - {ast, eat_eoe(parser)} + {:error, _} -> + {lhs, parser} + end + + lhs -> + {[lhs], parser} + end + end + + # In fn-head context, only certain LHS shapes should lower `when` precedence. + # This allows `<-` and `\\` to be consumed by the guard expression when the + # head is explicitly grouped (empty parens, parenthesized comma args, or + # parenthesized keyword heads). + defp fn_head_simple_for_when_precedence?(lhs) do + match?({:__block__, _, []}, lhs) or + match?({:comma, [{:parens, _} | _], _}, lhs) or + match?({:ok, _block_meta, _paren_meta, _kw}, parenthesized_kw_head(lhs)) + end + + # Validate raw fn-head LHS before `when` normalization can flatten argument + # structure and hide nested-parens evidence. + defp maybe_error_invalid_fn_head_lhs(parser, lhs) do + case lhs do + {:comma, comma_meta, comma_args} -> + parser + |> maybe_error_nested_top_parens(comma_meta) + |> maybe_error_invalid_fn_head_args(comma_args) + + lhs -> + case parenthesized_kw_head(lhs) do + {:ok, block_meta, _paren_meta, _kw} -> + maybe_error_nested_top_parens(parser, block_meta) + + :error -> + parser + end + end + end + + # Reject nested grouped tuple/keyword arguments inside fn heads, for example + # `fn (a, (b, c)) -> ... end` and `fn ((a, b), c) -> ... end`. + defp maybe_error_invalid_fn_head_args(parser, args) when is_list(args) do + Enum.reduce(args, parser, fn arg, parser -> + case arg do + {:comma, [{:parens, parens_meta} | _], _} -> + put_error(parser, {nested_parens_error_meta(parens_meta), "unexpected parentheses"}) + + _ -> + case parenthesized_kw_head(arg) do + {:ok, _block_meta, {:parens, parens_meta}, _kw} -> + put_error(parser, {nested_parens_error_meta(parens_meta), "unexpected parentheses"}) + + :error -> + parser + end end + end) + end + + defp maybe_error_invalid_fn_head_args(parser, _args), do: parser + + # Reject top-level fn-head wrappers that carry more than one parens entry, + # such as `((a, b))` and `((a: 1))`. + defp maybe_error_nested_top_parens(parser, meta) do + case Enum.at(Keyword.get_values(meta, :parens), 1) do + nil -> + parser + + inner_parens -> + put_error(parser, {nested_parens_error_meta(inner_parens), "unexpected parentheses"}) end end + # Prefer the closing paren location when available to match parser diagnostics. + defp nested_parens_error_meta(parens_meta) do + Keyword.take(parens_meta[:closing] || parens_meta, [:line, :column]) + end + + # Extract parenthesized keyword-list heads like `(a: 1)` or `('x': 1)` from + # grouped AST nodes so fn-head normalization can preserve paren metadata. + defp parenthesized_kw_head({:__block__, [{:parens, _} = paren_meta | _] = block_meta, [[{key, _} | _] = kw]}) + when is_atom(key) do + {:ok, block_meta, paren_meta, kw} + end + + defp parenthesized_kw_head({:__block__, [{:parens, _} = paren_meta | _] = block_meta, [[{{_, _, _}, _} | _] = kw]}) do + {:ok, block_meta, paren_meta, kw} + end + + defp parenthesized_kw_head(_), do: :error + # Widen stab_state when outer expression is more complete than when `->` was first detected. defp maybe_widen_stab_state(parser, ast) do case {Map.get(parser, :stab_state), ast} do @@ -1396,11 +1493,17 @@ defmodule Spitfire do # e.g., `() when bar 1, 2, 3 -> foo()` should parse `bar 1, 2, 3` as the guard {rhs, parser} = if token == :when do - # Check if when has simple LHS (empty block or comma args). + # Check if when has simple LHS (empty block or parenthesized comma args). # If so and we're in fn context, use lower precedence to allow <- in guard. - in_fn_context = Map.get(parser, :stop_before_stab_op?, false) - simple_lhs = match?({:__block__, _, []}, lhs) or match?({:comma, _, _}, lhs) - when_precedence = if in_fn_context and simple_lhs, do: @list_comma, else: effective_precedence + # For bare comma args (no parens), keep normal precedence so `<-` and + # `\\` bind to the trailing head argument instead of becoming part of + # the `when` guard. + # Also validate the raw fn-head lhs here before `when` flattening. + in_fn_context = Map.get(parser, :fn_head_context?, false) + parser = if in_fn_context, do: maybe_error_invalid_fn_head_lhs(parser, lhs), else: parser + + when_precedence = + if in_fn_context and fn_head_simple_for_when_precedence?(lhs), do: @list_comma, else: effective_precedence {rhs, parser} = with_context(parser, %{stop_before_stab_op?: true}, fn parser -> @@ -1448,14 +1551,6 @@ defmodule Spitfire do # Empty block without parens {token, newlines ++ meta, [rhs]} - {:__block__, [{:parens, _} = paren_meta | _], [[{key, _} | _] = kw]} when is_atom(key) -> - # (a: 1) when ... - preserve parens meta for stab - {token, [paren_meta | newlines ++ meta], [kw, rhs]} - - {:__block__, [{:parens, _} = paren_meta | _], [[{{_, _, _}, _} | _] = kw]} -> - # Parenthesized kw list with interpolated key - {token, [paren_meta | newlines ++ meta], [kw, rhs]} - {:comma, [{:parens, _} = paren_meta | _], args} -> {token, [paren_meta | newlines ++ meta], args ++ [rhs]} @@ -1463,11 +1558,30 @@ defmodule Spitfire do {token, newlines ++ meta, args ++ [rhs]} _ -> - {token, newlines ++ meta, [lhs, rhs]} + case parenthesized_kw_head(lhs) do + {:ok, _block_meta, paren_meta, kw} -> + # (a: 1) when ... - preserve parens meta for stab + {token, [paren_meta | newlines ++ meta], [kw, rhs]} + + :error -> + {token, newlines ++ meta, [lhs, rhs]} + end end _ -> - {token, newlines ++ meta, [lhs, rhs]} + case lhs do + {:comma, comma_meta, args} when is_list(args) and args != [] -> + {leading, [last]} = Enum.split(args, -1) + {:comma, comma_meta, leading ++ [{token, newlines ++ meta, [last, rhs]}]} + + {:when, when_meta, when_args} when is_list(when_args) and length(when_args) > 2 -> + {leading, [second_last, guard]} = Enum.split(when_args, -2) + when_node = {:when, when_meta, [second_last, guard]} + {:comma, [], leading ++ [{token, newlines ++ meta, [when_node, rhs]}]} + + _ -> + {token, newlines ++ meta, [lhs, rhs]} + end end {ast, parser} @@ -1491,7 +1605,8 @@ defmodule Spitfire do (unmatched_expr?(lhs) and rhs_has_bare_comma?(rhs_parser)) do # When the RHS of `|` has low-precedence operators (::, when, <-, \\) or # the LHS is an unmatched_expr (do-end) and the RHS has no-parens commas, - # treat `|` as a regular pipe operator (matching Elixir's LALR grammar). + # treat `|` as a regular infix operator so RHS parsing completes before + # map-update pair extraction. parse_infix_expression(parser, lhs) else {pairs, pairs_parser} = parse_map_update_pairs(rhs_parser) @@ -2112,11 +2227,11 @@ defmodule Spitfire do newlines = get_newlines(parser) parser = parser |> next_token() |> eat_eoe() - # fn creates its own stab scope + # fn creates its own stab scope and enables fn-head specific validation. parser = Map.delete(parser, :stab_state) {exprs, parser} = - with_context(parser, %{stop_before_stab_op?: true}, fn parser -> + with_context(parser, %{stop_before_stab_op?: true, fn_head_context?: true}, fn parser -> while2 current_token(parser) not in [:end, :eof] <- parser do {ast, parser} = case Map.get(parser, :stab_state) do diff --git a/test/spitfire_test.exs b/test/spitfire_test.exs index 1dff579..68291b2 100644 --- a/test/spitfire_test.exs +++ b/test/spitfire_test.exs @@ -1483,6 +1483,42 @@ defmodule SpitfireTest do end end + test "fn args with parenthesized heads and low-precedence operators" do + codes = [ + "fn a, u\\\\c -> :ok end", + "fn (a, 0<-c) -> :ok end", + "fn (a, b<-c<-d) -> :ok end", + "fn (a, b\\\\c\\\\d) -> :ok end", + "fn (a, b<-c) when is_integer(c) -> :ok end", + "fn (a, b\\\\c) when is_integer(c) -> :ok end", + "fn (a, b, d<-c) when is_integer(c) -> :ok end", + "fn (a, b, d\\\\c) when is_integer(c) -> :ok end", + "fn (a, b<-c) when c<-c -> :ok end", + "fn (a, b\\\\c) when c\\\\c -> :ok end", + "fn (a, b<-c<-d) when c<-c -> :ok end", + "fn (a, b\\\\c\\\\d) when c\\\\c -> :ok end", + "fn a, (b<-c) when c<-c -> :ok end", + "fn a, (b\\\\c) when c\\\\c -> :ok end", + "fn a, b<-c when c<-c -> :ok end", + "fn a, b\\\\c when c\\\\c -> :ok end", + "fn a, b when c<-c -> :ok end", + "fn a, b when c\\\\c -> :ok end", + "fn a, b when c<-c<-d -> :ok end", + "fn (a: 1) when c<-c -> :ok end", + "fn (a: 1) when c<-c<-d -> :ok end", + "fn (a: 1) when c\\\\c -> :ok end", + "fn (a: 1) when c\\\\c\\\\d -> :ok end", + "fn (x: 1, y: 2) when c<-c -> :ok end", + "fn (x: 1, y: 2) when c\\\\c -> :ok end", + "fn ('x': 1, y: 2) when c<-c -> :ok end", + "fn (x: 1, 'y': 2) when c\\\\c -> :ok end" + ] + + for code <- codes do + assert Spitfire.parse(code) == s2q(code) + end + end + test "capture operator" do codes = [ ~s''' @@ -2301,6 +2337,9 @@ defmodule SpitfireTest do assert Spitfire.parse("%e.(){}") == s2q("%e.(){}") assert Spitfire.parse("%e.(1){}") == s2q("%e.(1){}") assert Spitfire.parse("%e.(a, b){}") == s2q("%e.(a, b){}") + + # Regression from absinthe-graphql/absinthe lexer + assert Spitfire.parse("comma = ascii_char([?,])") == s2q("comma = ascii_char([?,])") end end @@ -2662,6 +2701,33 @@ defmodule SpitfireTest do } end + test "rejects nested parenthesized fn args" do + codes = [ + # whole arg list double/triple-wrapped + "fn ((a, b)) -> :ok end", + "fn (((a, b))) -> :ok end", + "fn ((a, b)) when true -> :ok end", + "fn ((a, b<-c)) -> :ok end", + "fn ((a, b\\\\c)) -> :ok end", + "fn (((a, b<-c))) -> :ok end", + # keyword list double-wrapped + "fn ((a: 1)) -> :ok end", + "fn ((a: 1)) when true -> :ok end", + # individual args as parenthesized tuples + "fn ((a, b), c) -> :ok end", + "fn (a, (b, c)) -> :ok end", + "fn ((a, b), (c, d)) -> :ok end", + "fn ((a, b), (c, d)) when true -> :ok end", + "fn ((a, (b<-c))) -> :ok end", + "fn ((a, (b\\\\c))) -> :ok end" + ] + + for code <- codes do + assert {:error, _} = s2q(code) + assert {:error, _, _} = Spitfire.parse(code) + end + end + test "example from github issue" do code = ~S''' defmodule Foo do