diff --git a/src/ast.rs b/src/ast.rs index 6cda4851..3c328668 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -239,6 +239,8 @@ pub enum SingleExpressionInner { Call(Call), /// Match expression. Match(Match), + /// If expression. + If(If), } /// Call of a user-defined or of a builtin function. @@ -403,6 +405,38 @@ impl MatchArm { } } +#[derive(Clone, Debug)] +pub struct If { + scrutinee: Arc, + then_arm: Arc, + else_arm: Arc, + span: Span, +} + +impl If { + /// Access the expression who's output is deconstructed in the `if`. + pub fn scrutinee(&self) -> &Expression { + &self.scrutinee + } + + /// Access the branch that handles the `true` portion of the `if`. + pub fn then_arm(&self) -> &Expression { + &self.then_arm + } + + /// Access the branch that handles the `false` or `else` portion of the `if`. + pub fn else_arm(&self) -> &Expression { + &self.else_arm + } + + /// Access the span of the if statement. + pub fn span(&self) -> &Span { + &self.span + } +} + +impl_eq_hash!(If; scrutinee, then_arm, else_arm); + /// Item when analyzing modules. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum ModuleItem { @@ -462,6 +496,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + If(&'a If), } impl TreeLike for ExprTree<'_> { @@ -502,6 +537,7 @@ impl TreeLike for ExprTree<'_> { } S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::If(if_) => Tree::Unary(Self::If(if_)), }, Self::Call(call) => Tree::Nary(call.args().iter().map(Self::Expression).collect()), Self::Match(match_) => Tree::Nary(Arc::new([ @@ -509,6 +545,11 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::If(if_) => Tree::Nary(Arc::new([ + Self::Expression(if_.scrutinee()), + Self::Expression(if_.then_arm()), + Self::Expression(if_.else_arm()), + ])), } } } @@ -1059,6 +1100,9 @@ impl AbstractSyntaxTree for SingleExpression { parse::SingleExpressionInner::Match(match_) => { Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)? } + parse::SingleExpressionInner::If(if_) => { + If::analyze(if_, ty, scope).map(SingleExpressionInner::If)? + } }; Ok(Self { @@ -1426,6 +1470,28 @@ impl AbstractSyntaxTree for Match { } } +impl AbstractSyntaxTree for If { + type From = parse::If; + + fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + let scrutinee = + Expression::analyze(from.scrutinee(), &ResolvedType::boolean(), scope).map(Arc::new)?; + scope.push_scope(); + let ast_then = Expression::analyze(from.then_arm(), ty, scope).map(Arc::new)?; + scope.pop_scope(); + scope.push_scope(); + let ast_else = Expression::analyze(from.else_arm(), ty, scope).map(Arc::new)?; + scope.pop_scope(); + + Ok(Self { + scrutinee, + then_arm: ast_then, + else_arm: ast_else, + span: *from.as_ref(), + }) + } +} + fn analyze_named_module( name: ModuleName, from: &parse::ModuleProgram, @@ -1559,6 +1625,12 @@ impl AsRef for Match { } } +impl AsRef for If { + fn as_ref(&self) -> &Span { + &self.span + } +} + impl AsRef for Module { fn as_ref(&self) -> &Span { &self.span @@ -1570,3 +1642,158 @@ impl AsRef for ModuleAssignment { &self.span } } + +#[cfg(test)] +mod test { + use super::*; + use crate::parse::{self, ParseFromStr}; + use crate::types::UIntType; + + /// Helper to check if an expression is a constant, unwrapping blocks if needed + fn is_constant_expr(expr: &Expression) -> bool { + match expr.inner() { + ExpressionInner::Single(single) => { + matches!(single.inner(), SingleExpressionInner::Constant(_)) + } + ExpressionInner::Block(_, Some(inner_expr)) => is_constant_expr(inner_expr), + _ => false, + } + } + + /// Helper to check if an expression is a block with statements + fn is_block_with_statements(expr: &Expression) -> bool { + matches!(expr.inner(), ExpressionInner::Block(stmts, Some(_)) if !stmts.is_empty()) + } + + fn parse_if(input: &str) -> parse::If { + // Parse the if expression + let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse"); + + // Extract the parsed If from the expression + let parsed_if = match parsed_expr.inner() { + parse::ExpressionInner::Single(single) => match single.inner() { + parse::SingleExpressionInner::If(if_) => if_.clone(), + _ => panic!("Expected If expression"), + }, + _ => panic!("Expected Single expression"), + }; + parsed_if + } + + #[test] + fn test_if_expression_analyze() { + let input = "if true { 0 } else { 1 }"; + + let parsed_if = &parse_if(input); + + // Analyze the if expression with u8 as the expected type + let expected_type = ResolvedType::from(UIntType::U8); + let mut scope = Scope::default(); + let ast_if = If::analyze(parsed_if, &expected_type, &mut scope) + .expect("Failed to analyze If expression"); + + // Verify the structure + assert_eq!( + ast_if.scrutinee().ty(), + &ResolvedType::boolean(), + "Scrutinee should be boolean type" + ); + assert_eq!( + ast_if.then_arm().ty(), + &expected_type, + "Then arm should have u8 type" + ); + assert_eq!( + ast_if.else_arm().ty(), + &expected_type, + "Else arm should have u8 type" + ); + + // Verify scrutinee is a boolean constant + match ast_if.scrutinee().inner() { + ExpressionInner::Single(single) => match single.inner() { + SingleExpressionInner::Constant(_) => { + // Boolean constant verified + } + _ => panic!("Expected boolean constant for scrutinee"), + }, + _ => panic!("Expected single expression for scrutinee"), + } + + // Verify both arms are constants (may be wrapped in blocks) + assert!( + is_constant_expr(ast_if.then_arm()), + "Then arm should be a constant" + ); + assert!( + is_constant_expr(ast_if.else_arm()), + "Else arm should be a constant" + ); + } + + #[test] + fn test_if_expression_with_complex_arms() { + let input = "if false { let x: u8 = 5; x } else { 10 }"; + + let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse"); + let expected_type = ResolvedType::from(UIntType::U8); + + // Analyze the entire expression (which will handle the if internally) + let ast_expr = Expression::analyze_const(&parsed_expr, &expected_type) + .expect("Failed to analyze expression"); + + // Verify the expression is an If + match ast_expr.inner() { + ExpressionInner::Single(single) => match single.inner() { + SingleExpressionInner::If(ast_if) => { + assert_eq!(ast_if.scrutinee().ty(), &ResolvedType::boolean()); + assert_eq!(ast_if.then_arm().ty(), &expected_type); + assert_eq!(ast_if.else_arm().ty(), &expected_type); + + // Verify then arm is a block with statements and else arm is a constant + assert!( + is_block_with_statements(ast_if.then_arm()), + "Then arm should be a block with statements" + ); + assert!( + is_constant_expr(ast_if.else_arm()), + "Else arm should be a constant" + ); + } + _ => panic!("Expected If expression"), + }, + _ => panic!("Expected Single expression"), + } + } + + #[test] + fn test_if_valid_parse_but_invalid_ast() { + let input = "if false { let x: u8 = 5; } else { 10 }"; + + let parsed_if = &parse_if(input); + let expected_type = ResolvedType::from(UIntType::U8); + let mut scope = Scope::default(); + let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope); + + assert!(ast_if_result + .err() + .map(|e| matches!(e.error(), Error::ExpressionTypeMismatch(..))) + .unwrap()); + } + + #[test] + fn test_if_valid_parse_but_invalid_scrutinee() { + let input = "if (()) { 1 } else { 10 }"; + + let parsed_if = &parse_if(input); + let expected_type = ResolvedType::from(UIntType::U8); + let mut scope = Scope::default(); + let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope); + + // Expected type of scrutinee is `bool` + assert!(ast_if_result + .err() + .map(|e| matches!(e.error(), Error::ExpressionUnexpectedType(..))) + .unwrap()); + } +} diff --git a/src/compile/mod.rs b/src/compile/mod.rs index 2af17e6f..ddeb8268 100644 --- a/src/compile/mod.rs +++ b/src/compile/mod.rs @@ -12,7 +12,7 @@ use simplicity::{types, Cmr, FailEntropy}; use self::builtins::array_fold; use crate::array::{BTreeSlice, Partition}; use crate::ast::{ - Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression, + Call, CallName, Expression, ExpressionInner, If, Match, Program, SingleExpression, SingleExpressionInner, Statement, }; use crate::debug::CallTracker; @@ -355,6 +355,7 @@ impl SingleExpression { } SingleExpressionInner::Call(call) => call.compile(scope)?, SingleExpressionInner::Match(match_) => match_.compile(scope)?, + SingleExpressionInner::If(if_) => if_.compile(scope)?, }; scope @@ -680,3 +681,22 @@ impl Match { input.comp(&output).with_span(self) } } + +impl If { + fn compile<'brand>( + &self, + scope: &mut Scope<'brand>, + ) -> Result>, RichError> { + scope.push_scope(); + let then_arm = self.then_arm().compile(scope)?; + scope.pop_scope(); + scope.push_scope(); + let else_arm = self.else_arm().compile(scope)?; + scope.pop_scope(); + + let scrutinee = self.scrutinee().compile(scope)?; + let input = scrutinee.pair(PairBuilder::iden(scope.ctx())); + let output = ProgNode::case(then_arm.as_ref(), else_arm.as_ref()).with_span(self)?; + input.comp(&output).with_span(self) + } +} diff --git a/src/lexer.rs b/src/lexer.rs index 71c004b6..1f7b2b4a 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -17,21 +17,37 @@ pub enum Token<'src> { Mod, Const, Match, + If, + Else, // Control symbols + /// `->` Arrow, + /// `:` Colon, + /// `;` Semi, + /// `,` Comma, + /// `=` Eq, + /// `=>` FatArrow, + /// `(` LParen, + /// `)` RParen, + /// `[` LBracket, + /// `]` RBracket, + /// `{` LBrace, + /// `}` RBrace, + /// `<` LAngle, + /// `>` RAngle, // Number literals @@ -69,6 +85,8 @@ impl<'src> fmt::Display for Token<'src> { Token::Mod => write!(f, "mod"), Token::Const => write!(f, "const"), Token::Match => write!(f, "match"), + Token::If => write!(f, "if"), + Token::Else => write!(f, "else"), Token::Arrow => write!(f, "->"), Token::Colon => write!(f, ":"), @@ -140,6 +158,8 @@ pub fn lexer<'src>( "mod" => Token::Mod, "const" => Token::Const, "match" => Token::Match, + "if" => Token::If, + "else" => Token::Else, // TODO: Else is never parsed. "true" => Token::Bool(true), "false" => Token::Bool(false), _ => Token::Ident(s), @@ -251,6 +271,88 @@ mod tests { (tokens, errors) } + /// Helper function to get the variant name of a token + fn variant_name(token: &Token) -> &'static str { + match token { + Token::Fn => "Fn", + Token::Let => "Let", + Token::Type => "Type", + Token::Mod => "Mod", + Token::Const => "Const", + Token::Match => "Match", + Token::If => "If", + Token::Else => "Else", + Token::Arrow => "Arrow", + Token::Colon => "Colon", + Token::Semi => "Semi", + Token::Comma => "Comma", + Token::Eq => "Eq", + Token::FatArrow => "FatArrow", + Token::LParen => "LParen", + Token::RParen => "RParen", + Token::LBracket => "LBracket", + Token::RBracket => "RBracket", + Token::LBrace => "LBrace", + Token::RBrace => "RBrace", + Token::LAngle => "LAngle", + Token::RAngle => "RAngle", + Token::DecLiteral(_) => "DecLiteral", + Token::HexLiteral(_) => "HexLiteral", + Token::BinLiteral(_) => "BinLiteral", + Token::Bool(_) => "Bool", + Token::Ident(_) => "Ident", + Token::Jet(_) => "Jet", + Token::Witness(_) => "Witness", + Token::Param(_) => "Param", + Token::Macro(_) => "Macro", + Token::Comment => "Comment", + Token::BlockComment => "BlockComment", + } + } + + /// Macro to assert that a sequence of tokens matches the expected variant types + macro_rules! assert_tokens_match { + ($tokens:expr, $($expected:ident),* $(,)?) => { + { + let tokens = $tokens.as_ref().expect("Expected Some tokens"); + let expected_variants = vec![$( stringify!($expected) ),*]; + + assert_eq!( + tokens.len(), + expected_variants.len(), + "Expected {} tokens, got {}.\nTokens: {:?}", + expected_variants.len(), + tokens.len(), + tokens + ); + + for (idx, (token, expected_variant)) in tokens.iter().zip(expected_variants.iter()).enumerate() { + let actual_variant = variant_name(token); + assert_eq!( + actual_variant, + *expected_variant, + "Token at index {} does not match: expected {}, got {} (token: {:?})", + idx, + expected_variant, + actual_variant, + token + ); + } + } + }; + } + + #[test] + fn test_if_statement() { + let input = "if true {0} else {1};"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + + assert_tokens_match!( + tokens, If, Bool, LBrace, DecLiteral, RBrace, Else, LBrace, DecLiteral, RBrace, Semi, + ); + } + #[test] fn test_block_comment_simple() { let input = "/* hello world */"; diff --git a/src/parse.rs b/src/parse.rs index f47dda5e..1629c266 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -351,6 +351,8 @@ pub enum SingleExpressionInner { Expression(Arc), /// Match expression over a sum type Match(Match), + /// If expression + If(If), /// Tuple wrapper expression Tuple(Arc<[Expression]>), /// Array wrapper expression @@ -465,6 +467,31 @@ impl MatchPattern { } } +/// Match expression. +#[derive(Clone, Debug)] +pub struct If { + scrutinee: Arc, + then_arm: Arc, + else_arm: Arc, + span: Span, +} + +impl If { + pub fn scrutinee(&self) -> &Expression { + &self.scrutinee + } + + pub fn then_arm(&self) -> &Expression { + &self.then_arm + } + + pub fn else_arm(&self) -> &Expression { + &self.else_arm + } +} + +impl_eq_hash!(If; scrutinee, then_arm, else_arm); + /// Program root when parsing modules. #[derive(Clone, Debug)] pub struct ModuleProgram { @@ -602,6 +629,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + If(&'a If), } impl TreeLike for ExprTree<'_> { @@ -642,6 +670,7 @@ impl TreeLike for ExprTree<'_> { | S::Expression(l) => Tree::Unary(Self::Expression(l)), S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::If(if_) => Tree::Unary(Self::If(if_)), S::Tuple(elements) | S::Array(elements) | S::List(elements) => { Tree::Nary(elements.iter().map(Self::Expression).collect()) } @@ -652,6 +681,11 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::If(if_) => Tree::Nary(Arc::new([ + Self::Expression(if_.scrutinee()), + Self::Expression(if_.then_arm()), + Self::Expression(if_.else_arm()), + ])), } } } @@ -715,7 +749,7 @@ impl fmt::Display for ExprTree<'_> { write!(f, ")")?; } }, - S::Call(..) | S::Match(..) => {} + S::Call(..) | S::Match(..) | S::If(..) => {} S::Tuple(tuple) => { if data.n_children_yielded == 0 { write!(f, "(")?; @@ -766,6 +800,15 @@ impl fmt::Display for ExprTree<'_> { write!(f, ",\n}}")?; } }, + Self::If(..) => match data.n_children_yielded { + 0 => write!(f, "if ")?, + 1 => write!(f, "{{ ")?, + 2 => write!(f, "}} else {{")?, + n => { + debug_assert_eq!(n, 3); + write!(f, "}}")?; + } + }, } } @@ -1599,6 +1642,8 @@ impl SingleExpression { let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match); + let if_expr = If::parser(expr.clone()).map(SingleExpressionInner::If); + let variable = Identifier::parser().map(SingleExpressionInner::Variable); // Expression delimeted by parentheses @@ -1608,8 +1653,8 @@ impl SingleExpression { .map(|es| SingleExpressionInner::Expression(Arc::from(es))); choice(( - left, right, some, none, boolean, match_expr, expression, list, array, tuple, call, - literal, variable, + left, right, some, none, boolean, match_expr, if_expr, expression, list, array, tuple, + call, literal, variable, )) .map_with(|inner, e| Self { inner, @@ -1773,6 +1818,35 @@ impl Match { } } +impl If { + fn parser<'tokens, 'src: 'tokens, I, E>( + expr: E, + ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone + where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, + { + let scrutinee = expr.clone().map(Arc::new); + + let true_arm = expr.clone().map(Arc::new); + let false_arm = expr.clone().map(Arc::new); + // let true_arm = delimited_with_recovery(Expression::parser, Token::LBrace, Token::RBrace, |_| None); + // let false_arm = delimited_with_recovery(Expression::parser, Token::LBrace, Token::RBrace, |_| None); + + just(Token::If) + .ignore_then(scrutinee) + .then(true_arm) + .then_ignore(just(Token::Else)) + .then(false_arm) + .map_with(|((s, t), el), extra| Self { + scrutinee: s, + then_arm: t, + else_arm: el, + span: extra.span(), + }) + } +} + impl ChumskyParse for ModuleItem { fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone where @@ -1909,6 +1983,12 @@ impl AsRef for Match { } } +impl AsRef for If { + fn as_ref(&self) -> &Span { + &self.span + } +} + impl AsRef for ModuleProgram { fn as_ref(&self) -> &Span { &self.span @@ -2163,3 +2243,25 @@ impl crate::ArbitraryRec for Match { }) } } + +#[cfg(test)] +mod test { + + use super::*; + #[test] + fn test_if_statement_parse() { + let input = "if true {0} else {1}"; + + let statement = Expression::parse_from_str(input).expect("Error"); + + match &statement.inner() { + ExpressionInner::Single(se) => match se.inner() { + SingleExpressionInner::If(_if_) => { + // pass + } + _ => panic!("Did not find if statement correctly"), + }, + _ => panic!("Did not parse correctly"), + } + } +} diff --git a/src/value.rs b/src/value.rs index 3ca7fdca..27e46517 100644 --- a/src/value.rs +++ b/src/value.rs @@ -671,7 +671,8 @@ impl Value { | ExprTree::Statement(..) | ExprTree::Assignment(..) | ExprTree::Call(..) - | ExprTree::Match(..) => return None, // not const + | ExprTree::Match(..) + | ExprTree::If(..) => return None, // not const }; let size = data.node.n_children(); match single.inner() { @@ -680,7 +681,8 @@ impl Value { | S::Parameter(..) | S::Variable(..) | S::Call(..) - | S::Match(..) => return None, // not const + | S::Match(..) + | S::If(..) => return None, // not const S::Expression(..) => continue, // skip S::Tuple(..) => { let elements = output.split_off(output.len() - size);