diff --git a/contributing/typechecking.md b/contributing/typechecking.md new file mode 100644 index 0000000..d194ead --- /dev/null +++ b/contributing/typechecking.md @@ -0,0 +1,36 @@ +# Typechecking + +The current typechecker is based on the paper [Complete and Easy Bidirectional Typechecking for Higher-Rank Polymorphism](https://arxiv.org/abs/1306.6032), although we've had to make some modifications because their type system is pretty different from ours. Mainly, we've added gradual typing and subtyping while ignoring the higher-rank polymorphism stuff. + +The following are types used by the typechecker: + +- The top type: the supertype of all types (similar to Java's `Object`, which is the supertype of all reference types) +- The bottom type: the subtype of all types (same as TypeScript's `never`). You can never construct an instance of the bottom type. + - This is useful for indicating that a command will never return (due to throwing an error or looping infinitely). +- The `oneof` type: for creating union types. `oneof` can be used as the type for a value that's either an integer or a string. +- The `allof` type: for creating intersection types. `allof` represents a type that is simultaneously the subtype of `A`, `B`, and `C`. + - This isn't particularly useful in Nushell, so users won't be able to construct `allof` types themselves. + - However, the typechecker does use this internally. + +## How type checking/inference work + +The `typecheck_expr` method takes the expected type of the expression it is currently processing and provides an inferred type for every expression it visits. This way, we both check against the expected type and infer a type for the expression at the same time. Below is an example program using generics. + +```nu +def f [ x: T ] : nothing -> list { + let z = $x + [$z] +} +f 123 # Inferred to be of type list +``` + +For generics to work, the algorithm requires creating and solving type variables. These type variables have a lower bound and an upper bound. As we move through the program, these bounds are tightened further. At the end of the program, the lower bound of each type variable is chosen as its value. + +Every time we come across a call to a custom command with type parameters, we instantiate new type variables corresponding to those type parameters. For example, for the expression `f 123` above, we instantiate a new type variable `'0` with lower bound `bottom` and upper bound `top` (essentially unbounded). Because of the signature of `f`, we know that `123: '0` and `f 123: list<'0>`. + +So we first call `typecheck_expr` to check/infer the type of `123`, providing `'0` as its expected type. Since it's just an integer literal, we infer the type to be `int`. Then, to ensure that this type matches the expected type (`'0`), we call the `constrain_subtype` method to ensure that `int <: '0`. The existing lower bound for `'0` was `bottom`, so we set the new lower bound to `oneof = int`. + +Then we set the type of `f 123` to `list<'0>`. After all expressions have been processed, we replace all occurrences of `'0` with `int`. So this becomes `list`. + +Recursive bounds are not allowed. The bounds for a type variable (say, `'5`) can only refer to type variables created before it (here, `'0` through `'4`). +- If, during typechecking, we find ourselves setting the upper bound of `'5` to `'6`, then we instead set the upper bound of `'5` to whatever the lower bound of `'6` is at that point. This behavior can be improved upon somewhat in the future (in this particular example, we could instead update `'6`'s lower bound to include `'5`), but I think it's good enough for now. diff --git a/src/compiler.rs b/src/compiler.rs index 45f42a2..11136e0 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,9 @@ use crate::errors::SourceError; use crate::parser::{AstNode, Block, NodeId, Pipeline}; use crate::protocol::Command; -use crate::resolver::{DeclId, Frame, NameBindings, ScopeId, VarId, Variable}; +use crate::resolver::{ + DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable, +}; use crate::typechecker::{TypeId, Types}; use std::collections::HashMap; @@ -58,8 +60,14 @@ pub struct Compiler { pub variables: Vec, /// Mapping of variable's name node -> Variable pub var_resolution: HashMap, - /// Declarations (commands, aliases, externs), indexed by VarId + /// Type declarations, indexed by TypeDeclId + pub type_decls: Vec, + /// Mapping of type decl's name node -> TypeDecl + pub type_resolution: HashMap, + /// Declarations (commands, aliases, externs), indexed by DeclId pub decls: Vec>, + /// Declaration NodeIds, indexed by DeclId + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command pub decl_resolution: HashMap, @@ -71,7 +79,6 @@ pub struct Compiler { // Use/def // pub call_resolution: HashMap, - // pub type_resolution: HashMap, pub errors: Vec, } @@ -96,7 +103,10 @@ impl Compiler { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), // variables: vec![], @@ -104,8 +114,6 @@ impl Compiler { // types: vec![], // call_resolution: HashMap::new(), - // var_resolution: HashMap::new(), - // type_resolution: HashMap::new(), errors: vec![], } } @@ -157,7 +165,10 @@ impl Compiler { self.scope_stack.extend(name_bindings.scope_stack); self.variables.extend(name_bindings.variables); self.var_resolution.extend(name_bindings.var_resolution); + self.type_decls.extend(name_bindings.type_decls); + self.type_resolution.extend(name_bindings.type_resolution); self.decls.extend(name_bindings.decls); + self.decl_nodes.extend(name_bindings.decl_nodes); self.decl_resolution.extend(name_bindings.decl_resolution); self.errors.extend(name_bindings.errors); } diff --git a/src/parser.rs b/src/parser.rs index 8349171..7936492 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -179,6 +179,7 @@ pub enum AstNode { // Definitions Def { name: NodeId, + type_params: Option, params: NodeId, in_out_types: Option, block: NodeId, @@ -1018,6 +1019,32 @@ impl Parser { self.create_node(AstNode::Params(param_list), span_start, span_end) } + pub fn type_params(&mut self) -> NodeId { + let _span = span!(); + let span_start = self.position(); + self.less_than(); + + let mut param_list = vec![]; + + while self.has_tokens() { + if self.is_greater_than() { + break; + } + + if self.is_comma() { + self.tokens.advance(); + continue; + } + + param_list.push(self.name()); + } + + let span_end = self.position() + 1; + self.greater_than(); + + self.create_node(AstNode::Params(param_list), span_start, span_end) + } + pub fn type_args(&mut self) -> NodeId { let _span = span!(); let span_start = self.position(); @@ -1159,6 +1186,12 @@ impl Parser { _ => return self.error("expected def name"), }; + let type_params = if self.is_less_than() { + Some(self.type_params()) + } else { + None + }; + let params = self.signature_params(ParamsContext::Squares); let in_out_types = if self.is_colon() { Some(self.in_out_types()) @@ -1172,6 +1205,7 @@ impl Parser { self.create_node( AstNode::Def { name, + type_params, params, in_out_types, block, diff --git a/src/resolver.rs b/src/resolver.rs index d0f5319..1048cf2 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -23,6 +23,7 @@ pub enum FrameType { pub struct Frame { pub frame_type: FrameType, pub variables: HashMap, NodeId>, + pub type_decls: HashMap, NodeId>, pub decls: HashMap, NodeId>, /// Node that defined the scope frame (e.g., a block or overlay) pub node_id: NodeId, @@ -33,6 +34,7 @@ impl Frame { Frame { frame_type: scope_type, variables: HashMap::new(), + type_decls: HashMap::new(), decls: HashMap::new(), node_id, } @@ -47,6 +49,16 @@ pub struct Variable { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct VarId(pub usize); +#[derive(Debug, Clone)] +pub enum TypeDecl { + /// A type parameter. Holds the parameter name node + Param(NodeId), + // In the future, we may have type aliases, user-defined classes, etc. +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct TypeDeclId(pub usize); + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct DeclId(pub usize); @@ -56,7 +68,10 @@ pub struct NameBindings { pub scope_stack: Vec, pub variables: Vec, pub var_resolution: HashMap, + pub type_decls: Vec, + pub type_resolution: HashMap, pub decls: Vec>, + pub decl_nodes: Vec, pub decl_resolution: HashMap, pub errors: Vec, } @@ -68,7 +83,10 @@ impl NameBindings { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), errors: vec![], } @@ -93,8 +111,14 @@ pub struct Resolver<'a> { pub variables: Vec, /// Mapping of variable's name node -> Variable pub var_resolution: HashMap, + /// Type declarations, indexed by TypeDeclId + pub type_decls: Vec, + /// Mapping of type decl's name node -> TypeDecl + pub type_resolution: HashMap, /// Declarations (commands, aliases, etc.), indexed by DeclId pub decls: Vec>, + /// Declaration nodes, indexed by DeclId + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command pub decl_resolution: HashMap, /// Errors encountered during name binding @@ -109,7 +133,10 @@ impl<'a> Resolver<'a> { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), errors: vec![], } @@ -121,7 +148,10 @@ impl<'a> Resolver<'a> { scope_stack: self.scope_stack, variables: self.variables, var_resolution: self.var_resolution, + type_decls: self.type_decls, + type_resolution: self.type_resolution, decls: self.decls, + decl_nodes: self.decl_nodes, decl_resolution: self.decl_resolution, errors: self.errors, } @@ -149,13 +179,19 @@ impl<'a> Resolver<'a> { .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) .collect(); + let mut types: Vec = scope + .type_decls + .iter() + .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) + .collect(); + let mut decls: Vec = scope .decls .iter() .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) .collect(); - if vars.is_empty() && decls.is_empty() { + if vars.is_empty() && types.is_empty() && decls.is_empty() { result.push_str(" (empty)\n"); continue; } @@ -168,6 +204,12 @@ impl<'a> Resolver<'a> { result.push_str(&line_var); } + if !types.is_empty() { + types.sort(); + let line_type = format!(" type decls: [ {0} ]\n", types.join(", ")); + result.push_str(&line_type); + } + if !decls.is_empty() { decls.sort(); let line_decl = format!(" decls: [ {0} ]\n", decls.join(", ")); @@ -220,16 +262,28 @@ impl<'a> Resolver<'a> { } AstNode::Def { name, + type_params, params, - in_out_types: _, + in_out_types, block, } => { // define the command before the block to enable recursive calls - self.define_decl(name); + self.define_decl(name, node_id); // making sure the def parameters and body end up in the same scope frame self.enter_scope(block); + if let Some(type_params) = type_params { + let AstNode::Params(type_params) = self.compiler.get_node(type_params) else { + panic!("Internal error: expected type params") + }; + for type_param_id in type_params { + self.define_type_decl(*type_param_id, TypeDecl::Param(*type_param_id)); + } + } self.resolve_node(params); + if let Some(in_out_types) = in_out_types { + self.resolve_node(in_out_types); + } let def_scope = self.exit_scope(); let AstNode::Block(block_id) = self.compiler.ast_nodes[block.0] else { @@ -242,23 +296,28 @@ impl<'a> Resolver<'a> { new_name, old_name: _, } => { - self.define_decl(new_name); + self.define_decl(new_name, node_id); } AstNode::Params(ref params) => { for param in params { - if let AstNode::Param { name, .. } = self.compiler.ast_nodes[param.0] { - self.define_variable(name, false); - } else { + let AstNode::Param { name, ty } = self.compiler.ast_nodes[param.0] else { panic!("param is not a param"); + }; + self.define_variable(name, false); + if let Some(ty) = ty { + self.resolve_node(ty); } } } AstNode::Let { variable_name, - ty: _, + ty, initializer, is_mutable, } => { + if let Some(ty) = ty { + self.resolve_node(ty); + } self.resolve_node(initializer); self.define_variable(variable_name, is_mutable) } @@ -338,9 +397,38 @@ impl<'a> Resolver<'a> { } } AstNode::Statement(node) => self.resolve_node(node), + AstNode::Type { name, args, .. } => { + self.resolve_type(name); + if let Some(args) = args { + self.resolve_node(args); + } + } + AstNode::RecordType { fields, .. } => { + let AstNode::Params(fields) = self.compiler.get_node(fields) else { + panic!("Internal error: expected params for record field types"); + }; + for field in fields { + if let AstNode::Param { ty: Some(ty), .. } = self.compiler.get_node(*field) { + self.resolve_node(*ty); + } + } + } + AstNode::TypeArgs(ref args) => { + for arg in args { + self.resolve_node(*arg); + } + } + AstNode::InOutTypes(ref in_out_types) => { + for in_out_ty in in_out_types { + self.resolve_node(*in_out_ty); + } + } + AstNode::InOutType(in_ty, out_ty) => { + self.resolve_node(in_ty); + self.resolve_node(out_ty); + } AstNode::Pipeline(pipeline_id) => self.resolve_pipeline(pipeline_id), AstNode::Param { .. } => (/* seems unused for now */), - AstNode::Type { .. } => ( /* probably doesn't make sense to resolve? */ ), AstNode::NamedValue { .. } => (/* seems unused for now */), // All remaining matches do not contain NodeId => there is nothing to resolve _ => (), @@ -374,6 +462,31 @@ impl<'a> Resolver<'a> { } } + pub fn resolve_type(&mut self, unbound_node_id: NodeId) { + let type_name = self.compiler.get_span_contents(unbound_node_id); + + match type_name { + b"any" | b"list" | b"bool" | b"closure" | b"float" | b"int" | b"nothing" + | b"number" | b"string" => return, + _ => {} + } + + if let Some(node_id) = self.find_type(type_name) { + let type_id = self + .type_resolution + .get(&node_id) + .expect("internal error: missing resolved type"); + + self.type_resolution.insert(unbound_node_id, *type_id); + } else { + self.errors.push(SourceError { + message: format!("type `{}` not found", String::from_utf8_lossy(type_name)), + node_id: unbound_node_id, + severity: Severity::Error, + }) + } + } + pub fn resolve_call(&mut self, unbound_node_id: NodeId, parts: &[NodeId]) { // Find out the potentially longest command name let max_name_parts = parts @@ -485,7 +598,26 @@ impl<'a> Resolver<'a> { self.var_resolution.insert(var_name_id, var_id); } - pub fn define_decl(&mut self, decl_name_id: NodeId) { + pub fn define_type_decl(&mut self, type_name_id: NodeId, type_decl: TypeDecl) { + let type_name = self.compiler.get_span_contents(type_name_id).to_vec(); + + let current_scope_id = self + .scope_stack + .last() + .expect("internal error: missing scope frame id"); + + self.scope[current_scope_id.0] + .type_decls + .insert(type_name, type_name_id); + + self.type_decls.push(type_decl); + let type_id = TypeDeclId(self.type_decls.len() - 1); + + // let the definition of a type also count as its use + self.type_resolution.insert(type_name_id, type_id); + } + + pub fn define_decl(&mut self, decl_name_id: NodeId, decl_node_id: NodeId) { // TODO: Deduplicate code with define_variable() let decl_name = self.compiler.get_span_contents(decl_name_id); let decl_name = trim_decl_name(decl_name).to_vec(); @@ -501,8 +633,9 @@ impl<'a> Resolver<'a> { .insert(decl_name, decl_name_id); self.decls.push(Box::new(decl)); - let decl_id = DeclId(self.decls.len() - 1); + self.decl_nodes.push(decl_node_id); + let decl_id = DeclId(self.decls.len() - 1); // let the definition of a decl also count as its use self.decl_resolution.insert(decl_name_id, decl_id); } @@ -517,6 +650,16 @@ impl<'a> Resolver<'a> { None } + pub fn find_type(&self, type_name: &[u8]) -> Option { + for scope_id in self.scope_stack.iter().rev() { + if let Some(id) = self.scope[scope_id.0].type_decls.get(type_name) { + return Some(*id); + } + } + + None + } + pub fn find_decl(&self, var_name: &[u8]) -> Option { // TODO: Deduplicate code with find_variable() for scope_id in self.scope_stack.iter().rev() { diff --git a/src/snapshots/new_nu_parser__test__lexer.snap b/src/snapshots/new_nu_parser__test__lexer.snap deleted file mode 100644 index 87e09ba..0000000 --- a/src/snapshots/new_nu_parser__test__lexer.snap +++ /dev/null @@ -1,21 +0,0 @@ ---- -source: src/test.rs -expression: evaluate_lexer(path) -input_file: tests/lex/int.nu -snapshot_kind: text ---- -Token 0: Number span: 0 .. 1 '0' -Token 1: Newline span: 1 .. 2 '\n' -Token 2: Number span: 2 .. 4 '00' -Token 3: Newline span: 4 .. 5 '\n' -Token 4: Number span: 5 .. 10 '0x123' -Token 5: Newline span: 10 .. 11 '\n' -Token 6: Number span: 11 .. 16 '0b101' -Token 7: Newline span: 16 .. 17 '\n' -Token 8: Name span: 17 .. 19 '_0' -Token 9: Newline span: 19 .. 20 '\n' -Token 10: Number span: 20 .. 21 '0' -Token 11: Name span: 21 .. 22 '_' -Token 12: Newline span: 22 .. 23 '\n' -Token 13: Name span: 23 .. 24 '_' -Token 14: Newline span: 24 .. 25 '\n' diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap index 9233e08..da32b54 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/binary_ops_exact.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (0 to 1) "1" @@ -12,33 +11,34 @@ snapshot_kind: text 4: True (8 to 12) 5: List([NodeId(4)]) (7 to 12) 6: Append (14 to 16) -7: False (17 to 22) -8: BinaryOp { lhs: NodeId(5), op: NodeId(6), rhs: NodeId(7) } (7 to 22) -9: Int (23 to 24) "1" -10: Plus (25 to 26) -11: Int (27 to 28) "1" -12: BinaryOp { lhs: NodeId(9), op: NodeId(10), rhs: NodeId(11) } (23 to 28) -13: Float (29 to 32) "1.0" -14: Plus (33 to 34) -15: Float (35 to 38) "1.0" -16: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(15) } (29 to 38) -17: True (39 to 43) -18: And (44 to 47) -19: False (48 to 53) -20: BinaryOp { lhs: NodeId(17), op: NodeId(18), rhs: NodeId(19) } (39 to 53) -21: String (54 to 59) ""foo"" -22: RegexMatch (60 to 62) -23: String (63 to 68) "".*o"" -24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (54 to 68) -25: Int (69 to 70) "1" -26: In (71 to 73) -27: Int (75 to 76) "1" -28: Int (78 to 79) "2" -29: List([NodeId(27), NodeId(28)]) (74 to 79) -30: BinaryOp { lhs: NodeId(25), op: NodeId(26), rhs: NodeId(29) } (69 to 79) -31: Block(BlockId(0)) (0 to 81) +7: False (18 to 23) +8: List([NodeId(7)]) (17 to 23) +9: BinaryOp { lhs: NodeId(5), op: NodeId(6), rhs: NodeId(8) } (7 to 23) +10: Int (25 to 26) "1" +11: Plus (27 to 28) +12: Int (29 to 30) "1" +13: BinaryOp { lhs: NodeId(10), op: NodeId(11), rhs: NodeId(12) } (25 to 30) +14: Float (31 to 34) "1.0" +15: Plus (35 to 36) +16: Float (37 to 40) "1.0" +17: BinaryOp { lhs: NodeId(14), op: NodeId(15), rhs: NodeId(16) } (31 to 40) +18: True (41 to 45) +19: And (46 to 49) +20: False (50 to 55) +21: BinaryOp { lhs: NodeId(18), op: NodeId(19), rhs: NodeId(20) } (41 to 55) +22: String (56 to 61) ""foo"" +23: RegexMatch (62 to 64) +24: String (65 to 70) "".*o"" +25: BinaryOp { lhs: NodeId(22), op: NodeId(23), rhs: NodeId(24) } (56 to 70) +26: Int (71 to 72) "1" +27: In (73 to 75) +28: Int (77 to 78) "1" +29: Int (80 to 81) "2" +30: List([NodeId(28), NodeId(29)]) (76 to 81) +31: BinaryOp { lhs: NodeId(26), op: NodeId(27), rhs: NodeId(30) } (71 to 81) +32: Block(BlockId(0)) (0 to 83) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(31) (empty) +0: Frame Scope, node_id: NodeId(32) (empty) ==== TYPES ==== 0: int 1: forbidden @@ -49,29 +49,30 @@ snapshot_kind: text 6: forbidden 7: bool 8: list -9: int -10: forbidden -11: int +9: list +10: int +11: forbidden 12: int -13: float -14: forbidden -15: float +13: int +14: float +15: forbidden 16: float -17: bool -18: forbidden -19: bool +17: float +18: bool +19: forbidden 20: bool -21: string -22: forbidden -23: string -24: bool -25: int -26: forbidden -27: int +21: bool +22: string +23: forbidden +24: string +25: bool +26: int +27: forbidden 28: int -29: list -30: bool +29: int +30: list 31: bool +32: bool ==== IR ==== register_count: 2 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap index fa39f5a..42a713d 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap @@ -25,27 +25,29 @@ input_file: tests/binary_ops_mismatch.nu 0: Frame Scope, node_id: NodeId(16) (empty) ==== TYPES ==== 0: string -1: error +1: forbidden 2: float -3: error +3: string 4: string 5: error 6: float 7: error 8: bool -9: error +9: forbidden 10: string -11: error +11: bool 12: bool -13: error +13: forbidden 14: string -15: error -16: error +15: bool +16: bool ==== TYPE ERRORS ==== -Error (NodeId 1): type mismatch: unsupported addition between string and float +Error (NodeId 2): Expected string, got float +Error (NodeId 4): Expected list <: '0 <: list, got string +Error (NodeId 6): Expected list <: '0 <: list, got float Error (NodeId 5): type mismatch: unsupported append between string and float -Error (NodeId 9): type mismatch: unsupported logical operation between bool and string -Error (NodeId 13): type mismatch: unsupported string operation between bool and string +Error (NodeId 10): Expected bool, got string +Error (NodeId 12): Expected string, got bool ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap index 9c76012..bb726ad 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/binary_ops_subtypes.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (0 to 1) "1" @@ -20,70 +19,72 @@ snapshot_kind: text 12: Int (29 to 30) "1" 13: List([NodeId(12)]) (28 to 30) 14: Append (32 to 34) -15: Float (35 to 38) "1.0" -16: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(15) } (28 to 38) -17: Float (40 to 43) "1.0" -18: Int (44 to 45) "1" -19: List([NodeId(17), NodeId(18)]) (39 to 45) -20: Append (47 to 49) -21: String (50 to 53) ""a"" -22: BinaryOp { lhs: NodeId(19), op: NodeId(20), rhs: NodeId(21) } (39 to 53) -23: Int (56 to 57) "1" -24: List([NodeId(23)]) (55 to 57) -25: Int (60 to 61) "2" +15: Float (36 to 39) "1.0" +16: List([NodeId(15)]) (35 to 39) +17: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(16) } (28 to 39) +18: Float (42 to 45) "1.0" +19: Int (46 to 47) "1" +20: List([NodeId(18), NodeId(19)]) (41 to 47) +21: Append (49 to 51) +22: String (53 to 56) ""a"" +23: List([NodeId(22)]) (52 to 56) +24: BinaryOp { lhs: NodeId(20), op: NodeId(21), rhs: NodeId(23) } (41 to 56) +25: Int (60 to 61) "1" 26: List([NodeId(25)]) (59 to 61) -27: List([NodeId(24), NodeId(26)]) (54 to 62) -28: Append (64 to 66) -29: Int (69 to 70) "3" -30: List([NodeId(29)]) (68 to 70) -31: List([NodeId(30)]) (67 to 71) -32: BinaryOp { lhs: NodeId(27), op: NodeId(28), rhs: NodeId(31) } (54 to 71) -33: Int (75 to 76) "1" -34: List([NodeId(33)]) (74 to 76) -35: Int (79 to 80) "2" +27: Int (64 to 65) "2" +28: List([NodeId(27)]) (63 to 65) +29: List([NodeId(26), NodeId(28)]) (58 to 66) +30: Append (68 to 70) +31: Int (73 to 74) "3" +32: List([NodeId(31)]) (72 to 74) +33: List([NodeId(32)]) (71 to 75) +34: BinaryOp { lhs: NodeId(29), op: NodeId(30), rhs: NodeId(33) } (58 to 75) +35: Int (79 to 80) "1" 36: List([NodeId(35)]) (78 to 80) -37: List([NodeId(34), NodeId(36)]) (73 to 81) -38: Append (83 to 85) -39: Float (88 to 91) "3.0" -40: List([NodeId(39)]) (87 to 91) -41: List([NodeId(40)]) (86 to 92) -42: BinaryOp { lhs: NodeId(37), op: NodeId(38), rhs: NodeId(41) } (73 to 92) -43: Int (94 to 95) "1" -44: In (96 to 98) -45: Float (100 to 103) "1.0" -46: Int (105 to 106) "1" -47: List([NodeId(45), NodeId(46)]) (99 to 106) -48: BinaryOp { lhs: NodeId(43), op: NodeId(44), rhs: NodeId(47) } (94 to 106) -49: Float (108 to 111) "2.3" -50: Modulo (112 to 115) -51: Int (116 to 117) "1" -52: BinaryOp { lhs: NodeId(49), op: NodeId(50), rhs: NodeId(51) } (108 to 117) -53: String (120 to 121) "b" -54: Int (123 to 124) "2" -55: String (126 to 127) "c" -56: Int (129 to 130) "3" -57: Record { pairs: [(NodeId(53), NodeId(54)), (NodeId(55), NodeId(56))] } (119 to 131) -58: List([NodeId(57)]) (118 to 131) -59: Append (133 to 135) -60: String (138 to 139) "a" -61: Int (141 to 142) "3" -62: String (144 to 145) "b" -63: Float (147 to 150) "1.5" -64: String (152 to 153) "c" -65: String (155 to 160) ""foo"" -66: Record { pairs: [(NodeId(60), NodeId(61)), (NodeId(62), NodeId(63)), (NodeId(64), NodeId(65))] } (137 to 161) -67: List([NodeId(66)]) (136 to 161) -68: BinaryOp { lhs: NodeId(58), op: NodeId(59), rhs: NodeId(67) } (118 to 161) -69: Block(BlockId(0)) (0 to 163) +37: Int (83 to 84) "2" +38: List([NodeId(37)]) (82 to 84) +39: List([NodeId(36), NodeId(38)]) (77 to 85) +40: Append (87 to 89) +41: Float (92 to 95) "3.0" +42: List([NodeId(41)]) (91 to 95) +43: List([NodeId(42)]) (90 to 96) +44: BinaryOp { lhs: NodeId(39), op: NodeId(40), rhs: NodeId(43) } (77 to 96) +45: Int (98 to 99) "1" +46: In (100 to 102) +47: Float (104 to 107) "1.0" +48: Int (109 to 110) "1" +49: List([NodeId(47), NodeId(48)]) (103 to 110) +50: BinaryOp { lhs: NodeId(45), op: NodeId(46), rhs: NodeId(49) } (98 to 110) +51: Float (112 to 115) "2.3" +52: Modulo (116 to 119) +53: Int (120 to 121) "1" +54: BinaryOp { lhs: NodeId(51), op: NodeId(52), rhs: NodeId(53) } (112 to 121) +55: String (124 to 125) "b" +56: Int (127 to 128) "2" +57: String (130 to 131) "c" +58: Int (133 to 134) "3" +59: Record { pairs: [(NodeId(55), NodeId(56)), (NodeId(57), NodeId(58))] } (123 to 135) +60: List([NodeId(59)]) (122 to 135) +61: Append (137 to 139) +62: String (142 to 143) "a" +63: Int (145 to 146) "3" +64: String (148 to 149) "b" +65: Float (151 to 154) "1.5" +66: String (156 to 157) "c" +67: String (159 to 164) ""foo"" +68: Record { pairs: [(NodeId(62), NodeId(63)), (NodeId(64), NodeId(65)), (NodeId(66), NodeId(67))] } (141 to 165) +69: List([NodeId(68)]) (140 to 165) +70: BinaryOp { lhs: NodeId(60), op: NodeId(61), rhs: NodeId(69) } (122 to 165) +71: Block(BlockId(0)) (0 to 167) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(69) (empty) +0: Frame Scope, node_id: NodeId(71) (empty) ==== TYPES ==== 0: int 1: forbidden 2: float 3: bool 4: string -5: forbidden +5: error 6: float 7: bool 8: int @@ -94,60 +95,64 @@ snapshot_kind: text 13: list 14: forbidden 15: float -16: list -17: float -18: int -19: list -20: forbidden -21: string -22: list -23: int -24: list +16: list +17: list +18: float +19: int +20: list +21: forbidden +22: string +23: list +24: list> 25: int 26: list -27: list> -28: forbidden -29: int -30: list -31: list> -32: list> -33: int -34: list +27: int +28: list +29: list> +30: forbidden +31: int +32: list +33: list> +34: list> 35: int 36: list -37: list> -38: forbidden -39: float -40: list -41: list> -42: list> -43: int -44: forbidden -45: float -46: int -47: list -48: bool -49: float -50: forbidden -51: int -52: float -53: unknown -54: int +37: int +38: list +39: list> +40: forbidden +41: float +42: list +43: list> +44: list> +45: int +46: forbidden +47: float +48: int +49: list +50: bool +51: float +52: forbidden +53: int +54: float 55: unknown 56: int -57: record -58: list> -59: forbidden -60: unknown -61: int +57: unknown +58: int +59: record +60: list> +61: forbidden 62: unknown -63: float +63: int 64: unknown -65: string -66: record -67: list> -68: list> -69: list> +65: float +66: unknown +67: string +68: record +69: list> +70: list>> +71: list>> +==== TYPE ERRORS ==== +Error (NodeId 5): type mismatch: unsupported incompatible types for equal between string and float ==== IR ==== register_count: 1 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap index d450b35..60d535f 100644 --- a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap @@ -31,7 +31,7 @@ input_file: tests/calls.nu 24: Variable (80 to 82) "$c" 25: List([NodeId(22), NodeId(23), NodeId(24)]) (70 to 82) 26: Block(BlockId(0)) (68 to 85) -27: Def { name: NodeId(8), params: NodeId(21), in_out_types: None, block: NodeId(26) } (24 to 85) +27: Def { name: NodeId(8), type_params: None, params: NodeId(21), in_out_types: None, block: NodeId(26) } (24 to 85) 28: Name (86 to 94) "existing" 29: Name (95 to 98) "foo" 30: String (100 to 104) ""ba"" @@ -85,7 +85,7 @@ input_file: tests/calls.nu 32: string 33: string 34: int -35: any +35: list 36: unknown 37: stream 38: stream diff --git a/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap new file mode 100644 index 0000000..8bca912 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap @@ -0,0 +1,52 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/calls_invalid.nu +--- +==== COMPILER ==== +0: Name (4 to 7) "foo" +1: Name (10 to 11) "a" +2: Name (13 to 16) "int" +3: Type { name: NodeId(2), args: None, optional: false } (13 to 16) +4: Param { name: NodeId(1), ty: Some(NodeId(3)) } (10 to 16) +5: Params([NodeId(4)]) (8 to 18) +6: Block(BlockId(0)) (19 to 21) +7: Def { name: NodeId(0), type_params: None, params: NodeId(5), in_out_types: None, block: NodeId(6) } (0 to 21) +8: Name (22 to 25) "foo" +9: Int (26 to 27) "1" +10: Int (28 to 29) "2" +11: Call { parts: [NodeId(8), NodeId(9), NodeId(10)] } (26 to 29) +12: Name (30 to 33) "foo" +13: String (34 to 42) ""string"" +14: Call { parts: [NodeId(12), NodeId(13)] } (34 to 42) +15: Block(BlockId(1)) (0 to 43) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(15) + decls: [ foo: NodeId(0) ] +1: Frame Scope, node_id: NodeId(6) + variables: [ a: NodeId(1) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: int +4: int +5: forbidden +6: () +7: () +8: unknown +9: int +10: int +11: () +12: unknown +13: string +14: () +15: () +==== TYPE ERRORS ==== +Error (NodeId 11): Expected 1 argument(s), got 2 +Error (NodeId 13): Expected int, got string +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 7): node Def { name: NodeId(0), type_params: None, params: NodeId(5), in_out_types: None, block: NodeId(6) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def.nu.snap index 12f3561..0baa77c 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def.nu.snap @@ -39,7 +39,7 @@ input_file: tests/def.nu 32: Variable (77 to 79) "$z" 33: List([NodeId(29), NodeId(30), NodeId(31), NodeId(32)]) (64 to 80) 34: Block(BlockId(0)) (62 to 83) -35: Def { name: NodeId(0), params: NodeId(28), in_out_types: None, block: NodeId(34) } (0 to 83) +35: Def { name: NodeId(0), type_params: None, params: NodeId(28), in_out_types: None, block: NodeId(34) } (0 to 83) 36: Block(BlockId(1)) (0 to 83) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(36) @@ -88,4 +88,4 @@ input_file: tests/def.nu register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 35): node Def { name: NodeId(0), params: NodeId(28), in_out_types: None, block: NodeId(34) } not suported yet +Error (NodeId 35): node Def { name: NodeId(0), type_params: None, params: NodeId(28), in_out_types: None, block: NodeId(34) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap index a73afae..76247ec 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap @@ -17,7 +17,7 @@ input_file: tests/def_return_type.nu 10: InOutTypes([NodeId(9)]) (14 to 35) 11: List([]) (37 to 38) 12: Block(BlockId(0)) (35 to 41) -13: Def { name: NodeId(0), params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } (0 to 41) +13: Def { name: NodeId(0), type_params: None, params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } (0 to 41) 14: Name (46 to 49) "bar" 15: Params([]) (50 to 53) 16: Name (58 to 64) "string" @@ -39,7 +39,7 @@ input_file: tests/def_return_type.nu 32: InOutTypes([NodeId(23), NodeId(31)]) (56 to 101) 33: List([]) (103 to 104) 34: Block(BlockId(1)) (101 to 107) -35: Def { name: NodeId(14), params: NodeId(15), in_out_types: Some(NodeId(32)), block: NodeId(34) } (42 to 107) +35: Def { name: NodeId(14), type_params: None, params: NodeId(15), in_out_types: Some(NodeId(32)), block: NodeId(34) } (42 to 107) 36: Block(BlockId(2)) (0 to 108) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(36) @@ -50,12 +50,12 @@ input_file: tests/def_return_type.nu 0: unknown 1: forbidden 2: unknown -3: unknown +3: nothing 4: unknown 5: unknown 6: any 7: forbidden -8: unknown +8: list 9: unknown 10: unknown 11: list @@ -64,20 +64,20 @@ input_file: tests/def_return_type.nu 14: unknown 15: forbidden 16: unknown -17: unknown +17: string 18: unknown 19: unknown 20: string 21: forbidden -22: unknown +22: list 23: unknown 24: unknown -25: unknown +25: int 26: unknown 27: unknown 28: int 29: forbidden -30: unknown +30: list 31: unknown 32: unknown 33: list @@ -88,4 +88,4 @@ input_file: tests/def_return_type.nu register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 13): node Def { name: NodeId(0), params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } not suported yet +Error (NodeId 13): node Def { name: NodeId(0), type_params: None, params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap b/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap index 86ad7e5..81769c8 100644 --- a/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/for_break_continue.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Variable (4 to 5) "x" @@ -57,16 +56,16 @@ snapshot_kind: text 9: forbidden 10: int 11: bool -12: unknown -13: unknown -14: oneof<(), unknown> +12: () +13: () +14: () 15: int 16: forbidden 17: int 18: bool -19: unknown -20: unknown -21: oneof<(), unknown> +19: () +20: () +21: () 22: int 23: forbidden 24: int @@ -77,9 +76,6 @@ snapshot_kind: text 29: () 30: () 31: () -==== TYPE ERRORS ==== -Error (NodeId 12): unsupported ast node 'Break' in typechecker -Error (NodeId 19): unsupported ast node 'Continue' in typechecker ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap new file mode 100644 index 0000000..b91bb9c --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap @@ -0,0 +1,206 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/infer_complex.nu +--- +==== COMPILER ==== +0: Name (4 to 5) "f" +1: Name (6 to 7) "A" +2: Name (9 to 10) "B" +3: Params([NodeId(1), NodeId(2)]) (5 to 11) +4: Name (14 to 15) "x" +5: Name (17 to 23) "record" +6: Name (24 to 25) "a" +7: Name (27 to 28) "A" +8: Type { name: NodeId(7), args: None, optional: false } (27 to 28) +9: Param { name: NodeId(6), ty: Some(NodeId(8)) } (24 to 28) +10: Name (30 to 31) "b" +11: Name (33 to 34) "B" +12: Type { name: NodeId(11), args: None, optional: false } (33 to 34) +13: Param { name: NodeId(10), ty: Some(NodeId(12)) } (30 to 34) +14: Params([NodeId(9), NodeId(13)]) (23 to 35) +15: RecordType { fields: NodeId(14), optional: false } (17 to 35) +16: Param { name: NodeId(4), ty: Some(NodeId(15)) } (14 to 35) +17: Name (37 to 38) "y" +18: Name (40 to 46) "record" +19: Name (47 to 48) "a" +20: Name (50 to 51) "A" +21: Type { name: NodeId(20), args: None, optional: false } (50 to 51) +22: Param { name: NodeId(19), ty: Some(NodeId(21)) } (47 to 51) +23: Name (53 to 54) "b" +24: Name (56 to 57) "B" +25: Type { name: NodeId(24), args: None, optional: false } (56 to 57) +26: Param { name: NodeId(23), ty: Some(NodeId(25)) } (53 to 57) +27: Params([NodeId(22), NodeId(26)]) (46 to 58) +28: RecordType { fields: NodeId(27), optional: false } (40 to 59) +29: Param { name: NodeId(17), ty: Some(NodeId(28)) } (37 to 59) +30: Params([NodeId(16), NodeId(29)]) (12 to 60) +31: Name (63 to 70) "nothing" +32: Type { name: NodeId(31), args: None, optional: false } (63 to 70) +33: Name (74 to 80) "record" +34: Name (81 to 82) "a" +35: Name (84 to 85) "A" +36: Type { name: NodeId(35), args: None, optional: false } (84 to 85) +37: Param { name: NodeId(34), ty: Some(NodeId(36)) } (81 to 85) +38: Name (87 to 88) "b" +39: Name (90 to 91) "B" +40: Type { name: NodeId(39), args: None, optional: false } (90 to 91) +41: Param { name: NodeId(38), ty: Some(NodeId(40)) } (87 to 91) +42: Params([NodeId(37), NodeId(41)]) (80 to 92) +43: RecordType { fields: NodeId(42), optional: false } (74 to 93) +44: InOutType(NodeId(32), NodeId(43)) (63 to 93) +45: InOutTypes([NodeId(44)]) (63 to 93) +46: Variable (97 to 99) "$x" +47: Block(BlockId(0)) (93 to 101) +48: Def { name: NodeId(0), type_params: Some(NodeId(3)), params: NodeId(30), in_out_types: Some(NodeId(45)), block: NodeId(47) } (0 to 101) +49: Name (106 to 116) "mysterious" +50: Name (117 to 118) "T" +51: Params([NodeId(50)]) (116 to 119) +52: Name (122 to 123) "x" +53: Name (125 to 128) "int" +54: Type { name: NodeId(53), args: None, optional: false } (125 to 128) +55: Param { name: NodeId(52), ty: Some(NodeId(54)) } (122 to 128) +56: Params([NodeId(55)]) (120 to 130) +57: Name (133 to 140) "nothing" +58: Type { name: NodeId(57), args: None, optional: false } (133 to 140) +59: Name (144 to 145) "T" +60: Type { name: NodeId(59), args: None, optional: false } (144 to 145) +61: InOutType(NodeId(58), NodeId(60)) (133 to 146) +62: InOutTypes([NodeId(61)]) (133 to 146) +63: Block(BlockId(1)) (146 to 148) +64: Def { name: NodeId(49), type_params: Some(NodeId(51)), params: NodeId(56), in_out_types: Some(NodeId(62)), block: NodeId(63) } (102 to 148) +65: Variable (154 to 155) "m" +66: Name (158 to 168) "mysterious" +67: Int (169 to 170) "0" +68: Call { parts: [NodeId(66), NodeId(67)] } (169 to 170) +69: Let { variable_name: NodeId(65), ty: None, initializer: NodeId(68), is_mutable: false } (150 to 170) +70: Variable (175 to 176) "a" +71: Name (178 to 184) "record" +72: Name (185 to 186) "a" +73: Name (188 to 194) "number" +74: Type { name: NodeId(73), args: None, optional: false } (188 to 194) +75: Param { name: NodeId(72), ty: Some(NodeId(74)) } (185 to 194) +76: Params([NodeId(75)]) (184 to 195) +77: RecordType { fields: NodeId(76), optional: false } (178 to 196) +78: Name (198 to 199) "f" +79: String (202 to 203) "a" +80: Int (205 to 208) "123" +81: String (210 to 211) "b" +82: Variable (213 to 215) "$m" +83: Record { pairs: [(NodeId(79), NodeId(80)), (NodeId(81), NodeId(82))] } (200 to 218) +84: String (220 to 221) "a" +85: Float (223 to 227) "12.3" +86: String (229 to 230) "b" +87: String (232 to 237) ""foo"" +88: Record { pairs: [(NodeId(84), NodeId(85)), (NodeId(86), NodeId(87))] } (218 to 239) +89: Call { parts: [NodeId(78), NodeId(83), NodeId(88)] } (200 to 239) +90: Let { variable_name: NodeId(70), ty: Some(NodeId(77)), initializer: NodeId(89), is_mutable: false } (171 to 239) +91: Block(BlockId(2)) (0 to 240) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(91) + variables: [ a: NodeId(70), m: NodeId(65) ] + decls: [ f: NodeId(0), mysterious: NodeId(49) ] +1: Frame Scope, node_id: NodeId(47) + variables: [ x: NodeId(4), y: NodeId(17) ] + type decls: [ A: NodeId(1), B: NodeId(2) ] +2: Frame Scope, node_id: NodeId(63) + variables: [ x: NodeId(52) ] + type decls: [ T: NodeId(50) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: unknown +6: unknown +7: unknown +8: A +9: unknown +10: unknown +11: unknown +12: B +13: unknown +14: unknown +15: record +16: record +17: unknown +18: unknown +19: unknown +20: unknown +21: A +22: unknown +23: unknown +24: unknown +25: B +26: unknown +27: unknown +28: record +29: record +30: forbidden +31: unknown +32: nothing +33: unknown +34: unknown +35: unknown +36: A +37: unknown +38: unknown +39: unknown +40: B +41: unknown +42: unknown +43: record +44: unknown +45: unknown +46: record +47: record +48: () +49: unknown +50: unknown +51: unknown +52: unknown +53: unknown +54: int +55: int +56: forbidden +57: unknown +58: nothing +59: unknown +60: T +61: unknown +62: unknown +63: () +64: () +65: bottom +66: unknown +67: int +68: bottom +69: () +70: record +71: unknown +72: unknown +73: unknown +74: number +75: unknown +76: unknown +77: record +78: unknown +79: unknown +80: int +81: unknown +82: bottom +83: record +84: unknown +85: float +86: unknown +87: string +88: record +89: record +90: () +91: () +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 48): node Def { name: NodeId(0), type_params: Some(NodeId(3)), params: NodeId(30), in_out_types: Some(NodeId(45)), block: NodeId(47) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap new file mode 100644 index 0000000..3a202d4 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap @@ -0,0 +1,78 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/infer_generics.nu +--- +==== COMPILER ==== +0: Name (4 to 5) "f" +1: Name (6 to 7) "T" +2: Params([NodeId(1)]) (5 to 8) +3: Name (11 to 12) "x" +4: Name (14 to 15) "T" +5: Type { name: NodeId(4), args: None, optional: false } (14 to 15) +6: Param { name: NodeId(3), ty: Some(NodeId(5)) } (11 to 15) +7: Params([NodeId(6)]) (9 to 17) +8: Name (20 to 27) "nothing" +9: Type { name: NodeId(8), args: None, optional: false } (20 to 27) +10: Name (31 to 35) "list" +11: Name (36 to 37) "T" +12: Type { name: NodeId(11), args: None, optional: false } (36 to 37) +13: TypeArgs([NodeId(12)]) (35 to 38) +14: Type { name: NodeId(10), args: Some(NodeId(13)), optional: false } (31 to 35) +15: InOutType(NodeId(9), NodeId(14)) (20 to 39) +16: InOutTypes([NodeId(15)]) (20 to 39) +17: Variable (47 to 48) "z" +18: Name (50 to 51) "T" +19: Type { name: NodeId(18), args: None, optional: false } (50 to 51) +20: Variable (54 to 56) "$x" +21: Let { variable_name: NodeId(17), ty: Some(NodeId(19)), initializer: NodeId(20), is_mutable: false } (43 to 56) +22: Variable (60 to 62) "$z" +23: List([NodeId(22)]) (59 to 62) +24: Block(BlockId(0)) (39 to 65) +25: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } (0 to 65) +26: Name (67 to 68) "f" +27: Int (69 to 70) "1" +28: Call { parts: [NodeId(26), NodeId(27)] } (69 to 70) +29: Block(BlockId(1)) (0 to 71) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(29) + decls: [ f: NodeId(0) ] +1: Frame Scope, node_id: NodeId(24) + variables: [ x: NodeId(3), z: NodeId(17) ] + type decls: [ T: NodeId(1) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: T +6: T +7: forbidden +8: unknown +9: nothing +10: unknown +11: unknown +12: T +13: forbidden +14: list +15: unknown +16: unknown +17: T +18: unknown +19: T +20: T +21: () +22: T +23: list +24: list +25: () +26: unknown +27: int +28: list +29: list +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 25): node Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap new file mode 100644 index 0000000..ff021e2 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap @@ -0,0 +1,81 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/infer_plus.nu +--- +==== COMPILER ==== +0: Name (4 to 14) "mysterious" +1: Name (15 to 16) "T" +2: Params([NodeId(1)]) (14 to 17) +3: Name (20 to 21) "x" +4: Name (23 to 26) "int" +5: Type { name: NodeId(4), args: None, optional: false } (23 to 26) +6: Param { name: NodeId(3), ty: Some(NodeId(5)) } (20 to 26) +7: Params([NodeId(6)]) (18 to 28) +8: Name (31 to 38) "nothing" +9: Type { name: NodeId(8), args: None, optional: false } (31 to 38) +10: Name (42 to 43) "T" +11: Type { name: NodeId(10), args: None, optional: false } (42 to 43) +12: InOutType(NodeId(9), NodeId(11)) (31 to 44) +13: InOutTypes([NodeId(12)]) (31 to 44) +14: Block(BlockId(0)) (44 to 46) +15: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } (0 to 46) +16: Variable (52 to 53) "m" +17: Name (56 to 66) "mysterious" +18: Int (67 to 68) "0" +19: Call { parts: [NodeId(17), NodeId(18)] } (67 to 68) +20: Let { variable_name: NodeId(16), ty: None, initializer: NodeId(19), is_mutable: false } (48 to 68) +21: Variable (70 to 72) "$m" +22: Plus (73 to 74) +23: String (75 to 80) ""foo"" +24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (70 to 80) +25: Variable (81 to 83) "$m" +26: Plus (84 to 85) +27: Int (86 to 89) "123" +28: BinaryOp { lhs: NodeId(25), op: NodeId(26), rhs: NodeId(27) } (81 to 89) +29: Block(BlockId(1)) (0 to 90) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(29) + variables: [ m: NodeId(16) ] + decls: [ mysterious: NodeId(0) ] +1: Frame Scope, node_id: NodeId(14) + variables: [ x: NodeId(3) ] + type decls: [ T: NodeId(1) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: int +6: int +7: forbidden +8: unknown +9: nothing +10: unknown +11: T +12: unknown +13: unknown +14: () +15: () +16: bottom +17: unknown +18: int +19: bottom +20: () +21: bottom +22: forbidden +23: string +24: string +25: bottom +26: forbidden +27: int +28: string +29: string +==== TYPE ERRORS ==== +Error (NodeId 27): Expected string, got int +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 15): node Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap b/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap index e96d1b5..0adf8c9 100644 --- a/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/invalid_if.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (3 to 4) "1" @@ -22,10 +21,10 @@ snapshot_kind: text 2: int 3: int 4: int -5: error -6: error +5: int +6: int ==== TYPE ERRORS ==== -Error (NodeId 0): The condition for if branch is not a boolean +Error (NodeId 0): Expected bool, got int ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap b/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap index 89de622..2791175 100644 --- a/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap @@ -17,7 +17,7 @@ input_file: tests/invalid_types.nu 10: Params([NodeId(9)]) (8 to 30) 11: Variable (33 to 35) "$x" 12: Block(BlockId(0)) (31 to 37) -13: Def { name: NodeId(0), params: NodeId(10), in_out_types: None, block: NodeId(12) } (0 to 37) +13: Def { name: NodeId(0), type_params: None, params: NodeId(10), in_out_types: None, block: NodeId(12) } (0 to 37) 14: Name (42 to 45) "bar" 15: Name (47 to 48) "y" 16: Name (50 to 54) "list" @@ -27,7 +27,7 @@ input_file: tests/invalid_types.nu 20: Params([NodeId(19)]) (46 to 57) 21: Variable (60 to 62) "$y" 22: Block(BlockId(1)) (58 to 64) -23: Def { name: NodeId(14), params: NodeId(20), in_out_types: None, block: NodeId(22) } (38 to 64) +23: Def { name: NodeId(14), type_params: None, params: NodeId(20), in_out_types: None, block: NodeId(22) } (38 to 64) 24: Block(BlockId(2)) (0 to 65) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(24) @@ -69,4 +69,4 @@ Error (NodeId 17): list must have one type argument register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 13): node Def { name: NodeId(0), params: NodeId(10), in_out_types: None, block: NodeId(12) } not suported yet +Error (NodeId 13): node Def { name: NodeId(0), type_params: None, params: NodeId(10), in_out_types: None, block: NodeId(12) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap index 429b553..88f25c1 100644 --- a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap @@ -91,9 +91,11 @@ input_file: tests/let_mismatch.nu 39: () 40: () ==== TYPE ERRORS ==== -Error (NodeId 13): initializer does not match declared type -Error (NodeId 26): initializer does not match declared type -Error (NodeId 38): initializer does not match declared type +Error (NodeId 13): Expected string, got int +Error (NodeId 24): Expected int, got string +Error (NodeId 25): Expected list, got list +Error (NodeId 26): Expected list>, got list> +Error (NodeId 38): Expected record, got record ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap b/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap index 8b4fb48..0e8bd92 100644 --- a/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/loop.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Variable (4 to 5) "x" @@ -31,22 +30,20 @@ snapshot_kind: text 0: int 1: int 2: () -3: unknown -4: unknown -5: unknown -6: unknown -7: unknown -8: unknown -9: unknown -10: unknown -11: unknown -12: unknown -13: unknown -14: unknown -15: unknown -16: unknown -==== TYPE ERRORS ==== -Error (NodeId 15): unsupported ast node 'Loop { block: NodeId(14) }' in typechecker +3: int +4: forbidden +5: int +6: bool +7: () +8: () +9: () +10: int +11: forbidden +12: int +13: () +14: () +15: () +16: () ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@table.nu.snap b/src/snapshots/new_nu_parser__test__node_output@table.nu.snap index 9cc85a8..02ca62a 100644 --- a/src/snapshots/new_nu_parser__test__node_output@table.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@table.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/table.nu -snapshot_kind: text --- ==== COMPILER ==== 0: String (7 to 10) ""a"" @@ -28,10 +27,10 @@ snapshot_kind: text 6: unknown 7: unknown 8: unknown -9: unknown -10: unknown +9: error +10: error ==== TYPE ERRORS ==== -Error (NodeId 9): unsupported ast node 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' in typechecker +Error (NodeId 9): Expected an expression to typecheck, got 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap b/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap index 46ed72c..ad9d253 100644 --- a/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/table2.nu -snapshot_kind: text --- ==== COMPILER ==== 0: String (7 to 8) "a" @@ -28,10 +27,10 @@ snapshot_kind: text 6: unknown 7: unknown 8: unknown -9: unknown -10: unknown +9: error +10: error ==== TYPE ERRORS ==== -Error (NodeId 9): unsupported ast node 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' in typechecker +Error (NodeId 9): Expected an expression to typecheck, got 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/typechecker.rs b/src/typechecker.rs index 5f52cec..10990a1 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,8 +1,12 @@ +//! See typechecking.md in the contributing/ folder for more information on +//! how the typechecker works + use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; use crate::parser::{AstNode, NodeId}; +use crate::resolver::{TypeDecl, TypeDeclId}; use std::cmp::Ordering; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeId(pub usize); @@ -14,12 +18,24 @@ pub struct InOutType { pub out_type: TypeId, } +/// A type variable used for type inference +pub struct TypeVar { + lower_bound: TypeId, + upper_bound: TypeId, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TypeVarId(pub usize); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordTypeId(pub usize); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OneOfId(pub usize); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AllOfId(pub usize); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { /// Any node that hasn't been touched by the typechecker will have this type @@ -27,9 +43,14 @@ pub enum Type { /// Some nodes shouldn't be directly evaluated (like operators). These will have a "forbidden" /// to differentiate them from the "unknown" type. Forbidden, + Error, /// None type means that a node has no type. For example, statements like let x = ... do not /// output anything and thus don't have any type. None, + /// Supertype of all types + Top, + /// Subtype of all types + Bottom, Any, Number, Nothing, @@ -42,8 +63,16 @@ pub enum Type { List(TypeId), Stream(TypeId), Record(RecordTypeId), + /// Union type. OneOf types should not be nested and should have at least two elements. + /// They can contain allof types. OneOf(OneOfId), - Error, + /// Intersection type. AllOf types should not be nested and should have at least two elements. + /// They also cannot contain oneof types. + AllOf(AllOfId), + /// A reference to a type declaration such as a type parameter + Ref(TypeDeclId), + /// A type variable that must be solved + Var(TypeVarId), } pub struct Types { @@ -73,6 +102,8 @@ pub const CLOSURE_TYPE: TypeId = TypeId(11); pub const LIST_ANY_TYPE: TypeId = TypeId(12); pub const BYTE_STREAM_TYPE: TypeId = TypeId(13); pub const ERROR_TYPE: TypeId = TypeId(14); +pub const TOP_TYPE: TypeId = TypeId(15); +pub const BOTTOM_TYPE: TypeId = TypeId(16); pub struct Typechecker<'a> { /// Immutable reference to a compiler after the name binding pass @@ -86,8 +117,12 @@ pub struct Typechecker<'a> { /// Record fields used for `RecordType`. Each value in this vector matches with the index in RecordTypeId. /// The individual field lists are stored sorted by field name. pub record_types: Vec>, - /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId + /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. pub oneof_types: Vec>, + /// Types used for `AllOf`. Each value in this vector matches with the index in AllOfId. + pub allof_types: Vec>, + /// Type variables, indexed by TypeVarId + pub type_vars: Vec, /// Type of each Variable in compiler.variables, indexed by VarId pub variable_types: Vec, /// Input/output type pairs of each declaration in compiler.decls, indexed by DeclId @@ -117,10 +152,14 @@ impl<'a> Typechecker<'a> { Type::List(ANY_TYPE), Type::Stream(BINARY_TYPE), Type::Error, + Type::Top, + Type::Bottom, ], node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], record_types: Vec::new(), oneof_types: Vec::new(), + allof_types: Vec::new(), + type_vars: Vec::new(), variable_types: vec![UNKNOWN_TYPE; compiler.variables.len()], decl_types: vec![ vec![InOutType { @@ -177,7 +216,21 @@ impl<'a> Typechecker<'a> { if !self.compiler.ast_nodes.is_empty() { let last = self.compiler.ast_nodes.len() - 1; let last_node_id = NodeId(last); - self.typecheck_node(last_node_id) + self.typecheck_node(last_node_id); + + for i in 0..self.type_vars.len() { + let var = &self.type_vars[i]; + let bound = var.lower_bound; + let cleaned = self.eliminate_type_vars(bound, TypeVarId(0), true); + self.types[bound.0] = self.types[cleaned.0]; + } + + for i in 0..self.types.len() { + if let Type::Var(var_id) = &self.types[i] { + let bound = self.type_vars[var_id.0].lower_bound; + self.types[i] = self.types[bound.0]; + } + } } } @@ -194,21 +247,6 @@ impl<'a> Typechecker<'a> { fn typecheck_node(&mut self, node_id: NodeId) { match self.compiler.ast_nodes[node_id.0] { - AstNode::Null => { - self.set_node_type_id(node_id, NOTHING_TYPE); - } - AstNode::Int => { - self.set_node_type_id(node_id, INT_TYPE); - } - AstNode::Float => { - self.set_node_type_id(node_id, FLOAT_TYPE); - } - AstNode::True | AstNode::False => { - self.set_node_type_id(node_id, BOOL_TYPE); - } - AstNode::String => { - self.set_node_type_id(node_id, STRING_TYPE); - } AstNode::Params(ref params) => { for param in params { self.typecheck_node(*param); @@ -218,74 +256,160 @@ impl<'a> Typechecker<'a> { } AstNode::Param { name, ty } => { if let Some(ty) = ty { - self.typecheck_node(ty); + let ty_id = self.typecheck_type(ty); let var_id = self .compiler .var_resolution .get(&name) .expect("missing resolved variable"); - self.variable_types[var_id.0] = self.type_id_of(ty); - self.set_node_type_id(node_id, self.type_id_of(ty)); + self.variable_types[var_id.0] = ty_id; + self.set_node_type_id(node_id, ty_id); } else { self.set_node_type_id(node_id, ANY_TYPE); } } - AstNode::Type { - name, - args, - optional, - } => { - let ty_id = self.typecheck_type(name, args, optional); - self.set_node_type_id(node_id, ty_id); - } - AstNode::RecordType { - fields, - optional: _, // TODO handle optional record types - } => { - let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { - panic!("internal error: record fields aren't Params"); - }; - let mut fields = field_nodes - .iter() - .map(|field| { - let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { - panic!("internal error: record field isn't Param"); - }; - let ty_id = match ty { - Some(ty) => { - self.typecheck_node(*ty); - self.type_id_of(*ty) - } - None => ANY_TYPE, - }; - (*name, ty_id) - }) - .collect::>(); - // Store fields sorted by name - fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); - - self.record_types.push(fields); - let ty_id = self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))); - self.set_node_type_id(node_id, ty_id); - } AstNode::TypeArgs(ref args) => { for arg in args { - self.typecheck_node(*arg); + self.typecheck_type(*arg); } // Type argument lists are not supposed to be evaluated self.set_node_type_id(node_id, FORBIDDEN_TYPE); } + AstNode::Block(_) => { + self.typecheck_block(node_id, TOP_TYPE); + } + _ => self.error( + format!( + "unsupported/unexpected ast node '{:?}' in typechecker", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ), + } + } + + fn typecheck_block(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { + let AstNode::Block(block_id) = self.compiler.ast_nodes[node_id.0] else { + panic!( + "Expected block to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ); + }; + let block = &self.compiler.blocks[block_id.0]; + + for (i, inner_node_id) in block.nodes.iter().enumerate() { + if i == block.nodes.len() - 1 && self.is_expr(*inner_node_id) { + self.typecheck_expr(*inner_node_id, expected); + } else { + self.typecheck_stmt(*inner_node_id); + } + } + + // Block type is the type of the last statement, since blocks + // by themselves aren't supposed to be typed + let block_type = block + .nodes + .last() + .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); + self.set_node_type_id(node_id, block_type); + block_type + } + + fn typecheck_stmt(&mut self, node_id: NodeId) { + match self.compiler.ast_nodes[node_id.0] { + AstNode::Let { + variable_name, + ty, + initializer, + is_mutable: _, + } => self.typecheck_let(variable_name, ty, initializer, node_id), + AstNode::Def { + name, + params, + in_out_types, + block, + .. + } => self.typecheck_def(name, params, in_out_types, block, node_id), + AstNode::Alias { new_name, old_name } => { + self.typecheck_alias(new_name, old_name, node_id) + } + AstNode::For { + variable, + range, + block, + } => { + // We don't need to typecheck variable after this + self.typecheck_expr(range, TOP_TYPE); + + let var_id = self + .compiler + .var_resolution + .get(&variable) + .expect("missing resolved variable"); + if let Type::List(type_id) = self.type_of(range) { + self.variable_types[var_id.0] = type_id; + self.set_node_type_id(variable, type_id); + } else { + self.variable_types[var_id.0] = ANY_TYPE; + self.set_node_type_id(variable, ERROR_TYPE); + self.error("For loop range is not a list", range); + } + + self.typecheck_node(block); + if self.type_id_of(block) != NONE_TYPE { + self.error("Blocks in looping constructs cannot return values", block); + } + + if self.type_id_of(node_id) != ERROR_TYPE { + self.set_node_type_id(node_id, NONE_TYPE); + } + } + AstNode::While { condition, block } => { + self.typecheck_expr(condition, BOOL_TYPE); + self.typecheck_node(block); + self.set_node_type_id(node_id, NONE_TYPE); + } + AstNode::Loop { block } => { + self.typecheck_node(block); + self.set_node_type_id(node_id, NONE_TYPE); + } + AstNode::Break | AstNode::Continue => { + // TODO make sure we're in a loop + self.set_node_type_id(node_id, NONE_TYPE); + } + _ if self.is_expr(node_id) => { + self.typecheck_expr(node_id, TOP_TYPE); + } + _ => self.error( + format!( + "Expected statement to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ), + } + } + + fn typecheck_expr(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { + let ty_id = match self.compiler.ast_nodes[node_id.0] { + AstNode::Null => NOTHING_TYPE, + AstNode::Int => INT_TYPE, + AstNode::Float => FLOAT_TYPE, + AstNode::True | AstNode::False => BOOL_TYPE, + AstNode::String => STRING_TYPE, AstNode::List(ref items) => { + // TODO infer a union type instead if let Some(first_id) = items.first() { - self.typecheck_node(*first_id); + let expected_elem = self.extract_elem_type(expected); + self.typecheck_expr(*first_id, expected_elem.unwrap_or(TOP_TYPE)); let first_type = self.type_of(*first_id); let mut all_numbers = self.is_type_compatible(first_type, Type::Number); let mut all_same = true; for item_id in items.iter().skip(1) { - self.typecheck_node(*item_id); + self.typecheck_expr(*item_id, TOP_TYPE); let item_type = self.type_of(*item_id); if all_numbers && !self.is_type_compatible(item_type, Type::Number) { @@ -298,60 +422,39 @@ impl<'a> Typechecker<'a> { } if all_same { - self.set_node_type(node_id, Type::List(self.type_id_of(*first_id))); + self.push_type(Type::List(self.type_id_of(*first_id))) } else if all_numbers { - self.set_node_type(node_id, Type::List(NUMBER_TYPE)); + self.push_type(Type::List(NUMBER_TYPE)) } else { - self.set_node_type_id(node_id, LIST_ANY_TYPE); + LIST_ANY_TYPE } } else { - self.set_node_type_id(node_id, LIST_ANY_TYPE); + LIST_ANY_TYPE } } AstNode::Record { ref pairs } => { + // TODO take expected type into account let mut field_types = pairs .iter() - .map(|(name, value)| { - self.typecheck_node(*value); - (*name, self.type_id_of(*value)) - }) + .map(|(name, value)| (*name, self.typecheck_expr(*value, TOP_TYPE))) .collect::>(); field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); self.record_types.push(field_types); - let ty_id = self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))); - self.set_node_type_id(node_id, ty_id); - } - AstNode::Block(block_id) => { - let block = &self.compiler.blocks[block_id.0]; - - for inner_node_id in &block.nodes { - self.typecheck_node(*inner_node_id); - } - - // Block type is the type of the last statement, since blocks - // by themselves aren't supposed to be typed - let block_type = block - .nodes - .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - - self.set_node_type_id(node_id, block_type); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } AstNode::Pipeline(pipeline_id) => { let pipeline = &self.compiler.pipelines[pipeline_id.0]; let expressions = pipeline.get_expressions(); for inner in expressions { - self.typecheck_node(*inner) + self.typecheck_expr(*inner, TOP_TYPE); } // pipeline type is the type of the last expression, since blocks // by themselves aren't supposed to be typed - let pipeline_type = expressions + expressions .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - - self.set_node_type_id(node_id, pipeline_type); + .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)) } AstNode::Closure { params, block } => { // TODO: input/output types @@ -360,15 +463,9 @@ impl<'a> Typechecker<'a> { } self.typecheck_node(block); - self.set_node_type_id(node_id, CLOSURE_TYPE); + CLOSURE_TYPE } - AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs, node_id), - AstNode::Let { - variable_name, - ty, - initializer, - is_mutable: _, - } => self.typecheck_let(variable_name, ty, initializer, node_id), + AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs), AstNode::Variable => { let var_id = self .compiler @@ -376,154 +473,109 @@ impl<'a> Typechecker<'a> { .get(&node_id) .expect("missing resolved variable"); - self.set_node_type_id(node_id, self.variable_types[var_id.0]); + self.variable_types[var_id.0] } AstNode::If { condition, then_block, else_block, } => { - self.typecheck_node(condition); - self.typecheck_node(then_block); + self.typecheck_expr(condition, BOOL_TYPE); - let then_type_id = self.type_id_of(then_block); - let mut else_type = None; + let then_type_id = self.typecheck_block(then_block, expected); if let Some(else_blk) = else_block { - self.typecheck_node(else_blk); - else_type = Some(self.type_of(else_blk)); - } - - let mut types = HashSet::new(); - self.add_resolved_types(&mut types, &then_type_id); - - if let Some(Type::OneOf(id)) = else_type { - types.extend(self.oneof_types[id.0].iter()); - } else if else_type.is_none() { - types.insert(NONE_TYPE); - } else { - types.insert(self.type_id_of(else_block.expect("Already checked"))); - } - - // the condition should always evaluate to a boolean - if self.type_of(condition) != Type::Bool { - self.error("The condition for if branch is not a boolean", condition); - self.set_node_type_id(node_id, ERROR_TYPE); - } else if types.len() > 1 { - self.oneof_types.push(types); - self.set_node_type(node_id, Type::OneOf(OneOfId(self.oneof_types.len() - 1))); + let else_type_id = + if let AstNode::Block(_) = self.compiler.ast_nodes[else_blk.0] { + self.typecheck_block(else_blk, expected) + } else { + self.typecheck_expr(else_blk, expected) + }; + let mut types = HashSet::new(); + types.insert(then_type_id); + types.insert(else_type_id); + self.create_oneof(types) } else { - self.set_node_type_id(node_id, *types.iter().next().expect("Can't be empty")); + // If there's no else block, the if expression is a statement + NONE_TYPE } } - AstNode::Def { - name, - params, - in_out_types, - block, - } => self.typecheck_def(name, params, in_out_types, block, node_id), - AstNode::Alias { new_name, old_name } => { - self.typecheck_alias(new_name, old_name, node_id) - } AstNode::Call { ref parts } => self.typecheck_call(parts, node_id), - AstNode::For { - variable, - range, - block, - } => { - // We don't need to typecheck variable after this - self.typecheck_node(range); - - let var_id = self - .compiler - .var_resolution - .get(&variable) - .expect("missing resolved variable"); - if let Type::List(type_id) = self.type_of(range) { - self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable, type_id); - } else { - self.variable_types[var_id.0] = ANY_TYPE; - self.set_node_type_id(variable, ERROR_TYPE); - self.error("For loop range is not a list", range); - } - - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); - } - - if self.type_id_of(node_id) != ERROR_TYPE { - self.set_node_type_id(node_id, NONE_TYPE); - } - } - AstNode::While { condition, block } => { - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); - } - - self.typecheck_node(condition); - - // the condition should always evaluate to a boolean - if self.type_of(condition) != Type::Bool { - self.error("The condition for while loop is not a boolean", condition); - self.set_node_type_id(node_id, ERROR_TYPE); - } else { - self.set_node_type_id(node_id, self.type_id_of(block)); - } - } AstNode::Match { ref target, ref match_arms, } => { // Check all the output types of match - let output_types = self.typecheck_match(target, match_arms); - match output_types.len().cmp(&1) { - Ordering::Greater => { - self.oneof_types.push(output_types); - self.set_node_type( - node_id, - Type::OneOf(OneOfId(self.oneof_types.len() - 1)), - ); - } - Ordering::Equal => { - self.set_node_type_id( - node_id, - *output_types - .iter() - .next() - .expect("Will contain one element"), - ); - } - Ordering::Less => { - self.set_node_type_id(node_id, NOTHING_TYPE); - } + let output_types = self.typecheck_match(target, match_arms, expected); + if output_types.is_empty() { + NOTHING_TYPE + } else { + self.create_oneof(output_types) } } - _ => self.error( + _ => { + self.error( + format!( + "Expected an expression to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ); + ERROR_TYPE + } + }; + self.set_node_type_id(node_id, ty_id); + + if !self.constrain_subtype(ty_id, expected) { + self.error( format!( - "unsupported ast node '{:?}' in typechecker", - self.compiler.ast_nodes[node_id.0] + "Expected {}, got {}", + self.type_to_string(expected), + self.type_to_string(ty_id) ), node_id, - ), + ); } + + ty_id + } + + fn is_expr(&mut self, node_id: NodeId) -> bool { + matches!( + self.compiler.ast_nodes[node_id.0], + AstNode::Null + | AstNode::Int + | AstNode::Float + | AstNode::True + | AstNode::False + | AstNode::String + | AstNode::Variable + | AstNode::List(_) + | AstNode::Record { .. } + | AstNode::Table { .. } + | AstNode::Pipeline(_) + | AstNode::Closure { .. } + | AstNode::BinaryOp { .. } + | AstNode::If { .. } + | AstNode::Call { .. } + | AstNode::Match { .. } + ) } fn typecheck_match( &mut self, target: &NodeId, match_arms: &Vec<(NodeId, NodeId)>, + expected: TypeId, ) -> HashSet { - self.typecheck_node(*target); + self.typecheck_expr(*target, TOP_TYPE); let mut output_types = HashSet::new(); // typecheck each node let target_id = self.type_id_of(*target); for (match_node, result_node) in match_arms { self.typecheck_node(*match_node); - self.typecheck_node(*result_node); + self.typecheck_expr(*result_node, expected); let match_id = self.type_id_of(*match_node); match (self.type_of(*target), self.type_of(*match_node)) { @@ -568,26 +620,30 @@ impl<'a> Typechecker<'a> { output_types } - fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId, node_id: NodeId) { - self.typecheck_node(lhs); - self.typecheck_node(rhs); + fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { self.set_node_type_id(op, FORBIDDEN_TYPE); - let lhs_type = self.type_of(lhs); - let rhs_type = self.type_of(rhs); - - let out_type = match self.compiler.ast_nodes[op.0] { - AstNode::Equal | AstNode::NotEqual => Some(Type::Bool), + // TODO: better error messages for type mismatches, the previous messages were better + match self.compiler.ast_nodes[op.0] { + AstNode::Equal | AstNode::NotEqual => { + let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); + let rhs_ty = self.typecheck_expr(rhs, TOP_TYPE); + if !(self.is_subtype(lhs_ty, rhs_ty) + || self.is_subtype(rhs_ty, lhs_ty) + || (self.is_subtype(lhs_ty, NUMBER_TYPE) + && self.is_subtype(rhs_ty, NUMBER_TYPE))) + { + self.binary_op_err("incompatible types for equal", lhs, op, rhs); + } + BOOL_TYPE + } AstNode::LessThan | AstNode::GreaterThan | AstNode::LessThanOrEqual | AstNode::GreaterThanOrEqual => { - if check_numeric_op(lhs_type, rhs_type) == Type::Unknown { - self.binary_op_err("comparison", lhs, op, rhs); - None - } else { - Some(Type::Bool) - } + self.typecheck_expr(lhs, NUMBER_TYPE); + self.typecheck_expr(rhs, NUMBER_TYPE); + BOOL_TYPE } AstNode::Minus | AstNode::Multiply @@ -595,104 +651,129 @@ impl<'a> Typechecker<'a> { | AstNode::FloorDiv | AstNode::Modulo | AstNode::Pow => { - let type_id = check_numeric_op(lhs_type, rhs_type); - - if type_id == Type::Unknown { - self.binary_op_err("math operation", lhs, op, rhs); - None - } else { - Some(type_id) + let lhs_ty = self.typecheck_expr(lhs, NUMBER_TYPE); + let rhs_ty = self.typecheck_expr(rhs, NUMBER_TYPE); + + match (self.types[lhs_ty.0], self.types[rhs_ty.0]) { + (Type::Int, Type::Int) => INT_TYPE, + (Type::Int, Type::Float) => FLOAT_TYPE, + (Type::Float, Type::Int) => FLOAT_TYPE, + (Type::Float, Type::Float) => FLOAT_TYPE, + _ => NUMBER_TYPE, } } - AstNode::RegexMatch | AstNode::NotRegexMatch => match (lhs_type, rhs_type) { - (Type::String | Type::Any, Type::String | Type::Any) => Some(Type::Bool), - _ => { - self.binary_op_err("string operation", lhs, op, rhs); - None - } - }, - AstNode::In => match rhs_type { - Type::String => match lhs_type { - Type::String | Type::Any => Some(Type::Bool), - _ => { - self.binary_op_err("string operation", lhs, op, rhs); - None - } - }, - Type::List(elem_ty) => { - if self.is_type_compatible(lhs_type, self.types[elem_ty.0]) { - Some(Type::Bool) - } else { - self.binary_op_err("list operation", lhs, op, rhs); - None - } - } - Type::Any => Some(Type::Bool), - _ => { - self.binary_op_err("list/string operation", lhs, op, rhs); - None - } - }, - AstNode::And | AstNode::Xor | AstNode::Or => match (lhs_type, rhs_type) { - (Type::Bool, Type::Bool) => Some(Type::Bool), - _ => { - self.binary_op_err("logical operation", lhs, op, rhs); - None - } - }, - AstNode::Plus => { - let ty = check_plus_op(lhs_type, rhs_type); - - if ty == Type::Unknown { - self.binary_op_err("addition", lhs, op, rhs); - None - } else { - Some(ty) - } + AstNode::RegexMatch | AstNode::NotRegexMatch => { + self.typecheck_expr(lhs, STRING_TYPE); + self.typecheck_expr(rhs, STRING_TYPE); + BOOL_TYPE } - AstNode::Append => { - let lhs_type = self.type_of(lhs); - let rhs_type = self.type_of(rhs); - - match (lhs_type, rhs_type) { - (Type::List(lhs_item_id), Type::List(rhs_item_id)) => { - let lhs_item_type = self.types[lhs_item_id.0]; - let rhs_item_type = self.types[rhs_item_id.0]; - let common_type = self.least_common_type(lhs_item_type, rhs_item_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) + AstNode::In => { + let rhs_type = self.typecheck_expr(rhs, TOP_TYPE); + match self.types[rhs_type.0] { + Type::String => { + self.typecheck_expr(lhs, STRING_TYPE); + BOOL_TYPE } - (Type::List(item_id), rhs_type) => { - let item_type = self.types[item_id.0]; - let common_type = self.least_common_type(item_type, rhs_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) + Type::List(elem_ty) => { + self.typecheck_expr(lhs, elem_ty); + BOOL_TYPE } - (lhs_type, Type::List(item_id)) => { - let item_type = self.types[item_id.0]; - let common_type = self.least_common_type(lhs_type, item_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) + Type::Any | Type::Bottom => { + self.typecheck_expr(lhs, TOP_TYPE); + BOOL_TYPE } _ => { - self.binary_op_err("append", lhs, op, rhs); - None + self.binary_op_err("list/string operation", lhs, op, rhs); + ERROR_TYPE } } } - AstNode::Assignment - | AstNode::AddAssignment - | AstNode::SubtractAssignment - | AstNode::MultiplyAssignment - | AstNode::DivideAssignment - | AstNode::AppendAssignment => Some(Type::None), - _ => panic!("internal error: unsupported node passed as binary op: {op:?}"), - }; + AstNode::And | AstNode::Xor | AstNode::Or => { + self.typecheck_expr(lhs, BOOL_TYPE); + self.typecheck_expr(rhs, BOOL_TYPE); + BOOL_TYPE + } + AstNode::Plus => { + let mut types = HashSet::new(); + types.insert(STRING_TYPE); + types.insert(NUMBER_TYPE); + let common_ty = self.create_oneof(types); + + let lhs_ty = self.typecheck_expr(lhs, common_ty); + let lhs_bottom = self.is_subtype(lhs_ty, BOTTOM_TYPE); + if !lhs_bottom && self.is_subtype(lhs_ty, STRING_TYPE) { + self.typecheck_expr(rhs, STRING_TYPE); + STRING_TYPE + } else if !lhs_bottom && self.is_subtype(lhs_ty, NUMBER_TYPE) { + let rhs_ty = self.typecheck_expr(rhs, NUMBER_TYPE); + self.numeric_op_type(lhs_ty, rhs_ty) + } else { + let rhs_ty = self.typecheck_expr(rhs, common_ty); + let rhs_bottom = self.is_subtype(rhs_ty, BOTTOM_TYPE); + if !rhs_bottom && self.is_subtype(rhs_ty, STRING_TYPE) { + if !self.constrain_subtype(lhs_ty, STRING_TYPE) { + self.error( + format!("Expected string, got {}", self.type_to_string(lhs_ty)), + lhs, + ); + } + STRING_TYPE + } else if !rhs_bottom && self.is_subtype(rhs_ty, NUMBER_TYPE) { + if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { + self.error( + format!("Expected number, got {}", self.type_to_string(lhs_ty)), + lhs, + ); + } + self.numeric_op_type(lhs_ty, rhs_ty) + } else if lhs_bottom && rhs_bottom { + common_ty + } else { + ERROR_TYPE + } + } + } + AstNode::Append => { + // TODO cache these two types + let top_list = self.push_type(Type::List(TOP_TYPE)); + let bottom_list = self.push_type(Type::List(BOTTOM_TYPE)); - if let Some(ty) = out_type { - self.set_node_type(node_id, ty); - } else { - self.set_node_type_id(node_id, ERROR_TYPE); + let res_var = self.new_typevar(bottom_list, top_list); + let res_type = self.push_type(Type::Var(res_var)); + let lhs_type = self.typecheck_expr(lhs, res_type); + let rhs_type = self.typecheck_expr(rhs, res_type); + + if self.is_subtype(lhs_type, LIST_ANY_TYPE) + && self.is_subtype(rhs_type, LIST_ANY_TYPE) + { + res_type + } else { + self.binary_op_err("append", lhs, op, rhs); + ERROR_TYPE + } + } + AstNode::Assignment + | AstNode::AddAssignment + | AstNode::SubtractAssignment + | AstNode::MultiplyAssignment + | AstNode::DivideAssignment + | AstNode::AppendAssignment => { + // TODO: actually check if operands are right for operator + self.typecheck_expr(lhs, TOP_TYPE); + self.typecheck_expr(rhs, TOP_TYPE); + NONE_TYPE + } + _ => panic!("internal error: unsupported node passed as binary op: {op:?}"), + } + } + + fn numeric_op_type(&self, lhs: TypeId, rhs: TypeId) -> TypeId { + match (self.types[lhs.0], self.types[rhs.0]) { + (Type::Int, Type::Int) => INT_TYPE, + (Type::Int, Type::Float) => FLOAT_TYPE, + (Type::Float, Type::Int) => FLOAT_TYPE, + (Type::Float, Type::Float) => FLOAT_TYPE, + _ => NUMBER_TYPE, } } @@ -715,25 +796,9 @@ impl<'a> Typechecker<'a> { let AstNode::InOutType(in_ty, out_ty) = self.compiler.get_node(*ty) else { panic!("internal error: return type is not a return type"); }; - let AstNode::Type { - name: in_name, - args: in_args, - optional: in_optional, - } = *self.compiler.get_node(*in_ty) - else { - panic!("internal error: type is not a type"); - }; - let AstNode::Type { - name: out_name, - args: out_args, - optional: out_optional, - } = *self.compiler.get_node(*out_ty) - else { - panic!("internal error: type is not a type"); - }; InOutType { - in_type: self.typecheck_type(in_name, in_args, in_optional), - out_type: self.typecheck_type(out_name, out_args, out_optional), + in_type: self.typecheck_type(*in_ty), + out_type: self.typecheck_type(*out_ty), } }) .collect::>() @@ -783,24 +848,88 @@ impl<'a> Typechecker<'a> { ); } - fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) { - let num_name_parts = if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { - // TODO: The type should be `oneof` - self.set_node_type_id(node_id, ANY_TYPE); + fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) -> TypeId { + if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { + let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); + let decl_node_id = self.compiler.decl_nodes[decl_id.0]; + let AstNode::Def { + type_params, + params, + .. + } = self.compiler.get_node(decl_node_id) + else { + panic!("Internal error: Expected def") + }; + let AstNode::Params(params) = self.compiler.get_node(*params) else { + panic!("Internal error: Expected params") + }; + + let type_substs = if let Some(type_params) = type_params { + let AstNode::Params(type_params) = self.compiler.get_node(*type_params) else { + panic!("Internal error: expected type params"); + }; + let mut type_substs = HashMap::new(); + for type_param in type_params.iter() { + let type_decl_id = self.compiler.type_resolution[type_param]; + let var = self.new_typevar(BOTTOM_TYPE, TOP_TYPE); + type_substs.insert(type_decl_id, var); + } + type_substs + } else { + HashMap::new() + }; + + let num_args = parts.len() - num_name_parts; + if params.len() != num_args { + self.error( + format!("Expected {} argument(s), got {}", params.len(), num_args), + node_id, + ); + } + for (param, arg) in params.iter().zip(&parts[num_name_parts..]) { + let expected = self.type_id_of(*param); + let expected = self.subst(expected, &type_substs); + if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { + self.set_node_type_id(*arg, STRING_TYPE); + if !self.constrain_subtype(STRING_TYPE, expected) { + self.error( + format!("Expected {}, got string", self.type_to_string(expected)), + *arg, + ); + } + } else { + self.typecheck_expr(*arg, expected); + } + } + if num_args > params.len() { + // Typecheck extra arguments too + for arg in &parts[num_name_parts + params.len()..] { + if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { + self.set_node_type_id(*arg, STRING_TYPE); + } else { + self.typecheck_expr(*arg, TOP_TYPE); + } + } + } - self.compiler.decls[decl_id.0].name().split(' ').count() + // TODO base this on pipeline input type + let out_types = self.decl_types[decl_id.0] + .clone() + .iter() + .map(|io| self.subst(io.out_type, &type_substs)) + .collect(); + self.create_oneof(out_types) } else { // external call - self.node_types[node_id.0] = BYTE_STREAM_TYPE; - 1 - }; - - for part in &parts[num_name_parts..] { - if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { - self.set_node_type_id(*part, STRING_TYPE); - } else { - self.typecheck_node(*part); + for part in &parts[1..] { + if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { + self.set_node_type_id(*part, STRING_TYPE); + } else { + self.typecheck_expr(*part, TOP_TYPE); + } } + + BYTE_STREAM_TYPE } } @@ -811,15 +940,13 @@ impl<'a> Typechecker<'a> { initializer: NodeId, node_id: NodeId, ) { - self.typecheck_node(initializer); - - if let Some(ty) = ty { - self.typecheck_node(ty); - - if !self.is_type_compatible(self.type_of(ty), self.type_of(initializer)) { - self.error("initializer does not match declared type", initializer) - } - } + let type_id = if let Some(ty) = ty { + let ty_id = self.typecheck_type(ty); + self.typecheck_expr(initializer, ty_id); + ty_id + } else { + self.typecheck_expr(initializer, TOP_TYPE) + }; let var_id = self .compiler @@ -827,18 +954,63 @@ impl<'a> Typechecker<'a> { .get(&variable_name) .expect("missing declared variable"); - let type_id = if let Some(ty) = ty { - self.type_id_of(ty) - } else { - self.type_id_of(initializer) - }; - self.variable_types[var_id.0] = type_id; self.set_node_type_id(variable_name, type_id); self.set_node_type_id(node_id, NONE_TYPE); } - fn typecheck_type( + fn typecheck_type(&mut self, node_id: NodeId) -> TypeId { + let ty_id = match self.compiler.ast_nodes[node_id.0] { + AstNode::Type { + name, + args, + optional, + } => self.typecheck_type_ref(name, args, optional), + AstNode::RecordType { + fields, + optional: _, // TODO handle optional record types + } => { + let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { + panic!("internal error: record fields aren't Params"); + }; + let mut fields = field_nodes + .iter() + .map(|field| { + let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { + panic!("internal error: record field isn't Param"); + }; + let ty_id = match ty { + Some(ty) => { + self.typecheck_type(*ty); + self.type_id_of(*ty) + } + None => ANY_TYPE, + }; + (*name, ty_id) + }) + .collect::>(); + // Store fields sorted by name + fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) + } + _ => { + self.error( + format!( + "Internal error: expected type, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ); + ERROR_TYPE + } + }; + self.set_node_type_id(node_id, ty_id); + ty_id + } + + fn typecheck_type_ref( &mut self, name_id: NodeId, args_id: Option, @@ -902,7 +1074,11 @@ impl<'a> Typechecker<'a> { // if bytes.contains(&b'@') { // // type with completion // } else { - UNKNOWN_TYPE + if let Some(type_decl) = self.compiler.type_resolution.get(&name_id) { + self.push_type(Type::Ref(*type_decl)) + } else { + UNKNOWN_TYPE + } // } } } @@ -914,7 +1090,10 @@ impl<'a> Typechecker<'a> { match ty { Type::Unknown => UNKNOWN_TYPE, Type::Forbidden => FORBIDDEN_TYPE, + Type::Error => ERROR_TYPE, Type::None => NONE_TYPE, + Type::Top => TOP_TYPE, + Type::Bottom => BOTTOM_TYPE, Type::Any => ANY_TYPE, Type::Number => NUMBER_TYPE, Type::Nothing => NOTHING_TYPE, @@ -931,64 +1110,359 @@ impl<'a> Typechecker<'a> { } } - fn set_node_type(&mut self, node_id: NodeId, ty: Type) { - let type_id = self.push_type(ty); - self.node_types[node_id.0] = type_id; + /// Replace type parameters (universal type variables) with existential type variables that can be solved + fn subst(&mut self, ty_id: TypeId, substs: &HashMap) -> TypeId { + if substs.is_empty() { + return ty_id; + } + match self.types[ty_id.0] { + Type::Unknown + | Type::Forbidden + | Type::Error + | Type::None + | Type::Top + | Type::Bottom + | Type::Any + | Type::Number + | Type::Nothing + | Type::Int + | Type::Float + | Type::Bool + | Type::String + | Type::Binary + | Type::Var(_) => ty_id, + Type::Closure => todo!(), + Type::List(elem_ty) => { + let new_elem = self.subst(elem_ty, substs); + if elem_ty == new_elem { + ty_id + } else { + self.push_type(Type::List(new_elem)) + } + } + Type::Stream(elem_ty) => { + let new_elem = self.subst(elem_ty, substs); + if elem_ty == new_elem { + ty_id + } else { + self.push_type(Type::Stream(new_elem)) + } + } + Type::Record(record_type_id) => { + let mut fields = self.record_types[record_type_id.0].clone(); + for (_, ty) in fields.iter_mut() { + *ty = self.subst(*ty, substs); + } + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) + } + Type::OneOf(id) => { + let orig_types = self.oneof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.subst(*ty, substs)); + } + self.oneof_types.push(new_types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } + Type::AllOf(id) => { + let orig_types = self.allof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.subst(*ty, substs)); + } + self.allof_types.push(new_types); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + } + Type::Ref(type_decl_id) => { + if let Some(var) = substs.get(&type_decl_id) { + self.push_type(Type::Var(*var)) + } else { + ty_id + } + } + } + } + + /// Given the type for a list, extract the type of its elements + fn extract_elem_type(&mut self, list_ty: TypeId) -> Option { + match self.types[list_ty.0] { + Type::List(elem) => Some(elem), + Type::Top => Some(TOP_TYPE), + Type::Bottom => Some(BOTTOM_TYPE), + Type::Any => Some(ANY_TYPE), + Type::Unknown => Some(UNKNOWN_TYPE), + _ => None, + } } fn set_node_type_id(&mut self, node_id: NodeId, type_id: TypeId) { self.node_types[node_id.0] = type_id; } - /// Finds a "supertype" of two types (e.g., number for float and int) - fn least_common_type(&mut self, lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::List(lhs_id), Type::List(rhs_id)) => { - let item_type = self.least_common_type(self.types[lhs_id.0], self.types[rhs_id.0]); - let item_type_id = self.push_type(item_type); - Type::List(item_type_id) + fn new_typevar(&mut self, lower_bound: TypeId, upper_bound: TypeId) -> TypeVarId { + self.type_vars.push(TypeVar { + lower_bound, + upper_bound, + }); + TypeVarId(self.type_vars.len() - 1) + } + + /// Check if `sub` is a subtype of `supe` + /// + /// Returns `false` if there is a type mismatch, `true` otherwise + /// TODO: return a Result with a message about constraints not being solvable or something + fn constrain_subtype(&mut self, sub_id: TypeId, supe_id: TypeId) -> bool { + if sub_id == supe_id { + return true; + } + match (self.types[sub_id.0], self.types[supe_id.0]) { + (_, Type::Top | Type::Any | Type::Unknown) => true, + (Type::Bottom | Type::Any | Type::Unknown, _) => true, + (Type::Int | Type::Float | Type::Number, Type::Number) => true, + (Type::List(inner_sub), Type::List(inner_supe)) => { + self.constrain_subtype(inner_sub, inner_supe) } - (Type::Record(lhs_id), Type::Record(rhs_id)) => { - let mut common_fields = Vec::new(); + (Type::Record(sub_rec_id), Type::Record(supe_rec_id)) => { + let sub_fields = self.record_types[sub_rec_id.0].clone(); + let supe_fields = self.record_types[supe_rec_id.0].clone(); + + let mut i = 0; + let mut j = 0; + while i < sub_fields.len() && j < supe_fields.len() { + let (sub_name, sub_ty) = sub_fields[i]; + let (supe_name, supe_ty) = supe_fields[j]; + let sub_text = self.compiler.get_span_contents(sub_name); + let supe_text = self.compiler.get_span_contents(supe_name); + match sub_text.cmp(supe_text) { + Ordering::Less => { + i += 1; + } + Ordering::Greater => { + // The field is in the supertype but not the subtype + return false; + } + Ordering::Equal => { + if !self.constrain_subtype(sub_ty, supe_ty) { + return false; + } + i += 1; + j += 1; + } + } + } - let mut l = 0; - let mut r = 0; - while l < self.record_types[lhs_id.0].len() && r < self.record_types[rhs_id.0].len() - { - let (lhs_name, lhs_ty) = self.record_types[lhs_id.0][l]; - let (rhs_name, rhs_ty) = self.record_types[rhs_id.0][r]; - let lhs_text = self.compiler.get_span_contents(lhs_name); - let rhs_text = self.compiler.get_span_contents(rhs_name); - match lhs_text.cmp(rhs_text) { + true + } + (Type::Var(var_id), _) => { + let lb = self.type_vars[var_id.0].lower_bound; + let ub = self.type_vars[var_id.0].upper_bound; + let mut types = HashSet::new(); + types.insert(ub); + types.insert(supe_id); + let new_ub = self.create_allof(types); + // Prevent forward references/cycles + let new_ub = self.eliminate_type_vars(new_ub, var_id, true); + + if self.constrain_subtype(lb, new_ub) { + let var = self + .type_vars + .get_mut(var_id.0) + .expect("type variable must exist"); + var.upper_bound = new_ub; + true + } else { + false + } + } + (_, Type::Var(var_id)) => { + let lb = self.type_vars[var_id.0].lower_bound; + let ub = self.type_vars[var_id.0].upper_bound; + let mut types = HashSet::new(); + types.insert(lb); + types.insert(sub_id); + let new_lb = self.create_oneof(types); + // Prevent forward references/cycles + let new_lb = self.eliminate_type_vars(new_lb, var_id, false); + + if self.constrain_subtype(new_lb, ub) { + let var = self + .type_vars + .get_mut(var_id.0) + .expect("type variable must exist"); + var.lower_bound = new_lb; + true + } else { + false + } + } + (Type::OneOf(id), _) => self.oneof_types[id.0] + .clone() + .iter() + .all(|ty| self.constrain_subtype(*ty, supe_id)), + (_, Type::OneOf(id)) => { + // TODO: actually add constraints? + self.oneof_types[id.0] + .clone() + .iter() + .any(|ty| self.is_subtype(sub_id, *ty)) + } + (sub, supe) if sub == supe => true, + _ => false, + } + } + + /// Check if `sub` is a subtype of `supe` + /// TODO: reduce duplication between this and constrain_subtype + fn is_subtype(&self, sub: TypeId, supe: TypeId) -> bool { + if sub == supe { + return true; + } + match (self.types[sub.0], self.types[supe.0]) { + (_, Type::Top | Type::Any | Type::Unknown) => true, + (Type::Bottom | Type::Any | Type::Unknown, _) => true, + (Type::Int | Type::Float | Type::Number, Type::Number) => true, + (Type::List(inner_sub), Type::List(inner_supe)) => { + self.is_subtype(inner_sub, inner_supe) + } + (Type::Record(sub_rec_id), Type::Record(supe_rec_id)) => { + let sub_fields = self.record_types[sub_rec_id.0].clone(); + let supe_fields = self.record_types[supe_rec_id.0].clone(); + + let mut i = 0; + let mut j = 0; + while i < sub_fields.len() && j < supe_fields.len() { + let (sub_name, sub_ty) = sub_fields[i]; + let (supe_name, supe_ty) = supe_fields[j]; + let sub_text = self.compiler.get_span_contents(sub_name); + let supe_text = self.compiler.get_span_contents(supe_name); + match sub_text.cmp(supe_text) { Ordering::Less => { - l += 1; + i += 1; } Ordering::Greater => { - r += 1; + // The field is in the supertype but not the subtype + return false; } Ordering::Equal => { - let field_ty = - self.least_common_type(self.types[lhs_ty.0], self.types[rhs_ty.0]); - let field_ty_id = self.push_type(field_ty); - common_fields.push((lhs_name, field_ty_id)); - l += 1; - r += 1; + if !self.is_subtype(sub_ty, supe_ty) { + return false; + } + i += 1; + j += 1; } } } - self.record_types.push(common_fields); - Type::Record(RecordTypeId(self.record_types.len() - 1)) + true } - (Type::Int, Type::Float) => Type::Number, - (Type::Int, Type::Number) => Type::Number, - (Type::Float, Type::Float) => Type::Number, - (Type::Float, Type::Number) => Type::Number, - _ => { - if lhs == rhs { - lhs + (Type::Var(var_id), _) => { + let var = &self.type_vars[var_id.0]; + self.is_subtype(var.upper_bound, supe) + } + (_, Type::Var(var_id)) => { + let var = &self.type_vars[var_id.0]; + self.is_subtype(sub, var.lower_bound) + } + (Type::OneOf(id), _) => self.oneof_types[id.0] + .clone() + .iter() + .all(|ty| self.is_subtype(*ty, supe)), + (_, Type::OneOf(id)) => self.oneof_types[id.0] + .iter() + .any(|ty| self.is_subtype(sub, *ty)), + (sub, supe) => sub == supe, + } + } + + /// Eliminate all type variables that are greater than or equal to `max_var` + /// * `use_lower`: If true, replace type variables with their lower bound. + /// Otherwise, replace with their upper bound + fn eliminate_type_vars( + &mut self, + ty_id: TypeId, + max_var: TypeVarId, + use_lower: bool, + ) -> TypeId { + match self.types[ty_id.0] { + Type::Unknown + | Type::Forbidden + | Type::Error + | Type::None + | Type::Top + | Type::Bottom + | Type::Any + | Type::Number + | Type::Nothing + | Type::Int + | Type::Float + | Type::Bool + | Type::String + | Type::Binary + | Type::Ref(_) => ty_id, + Type::Closure => ty_id, + Type::List(inner_ty) => { + let new_inner = self.eliminate_type_vars(inner_ty, max_var, use_lower); + if inner_ty == new_inner { + ty_id } else { - Type::Any + self.push_type(Type::List(new_inner)) + } + } + Type::Stream(inner_ty) => { + let new_inner = self.eliminate_type_vars(inner_ty, max_var, use_lower); + if inner_ty == new_inner { + ty_id + } else { + self.push_type(Type::Stream(new_inner)) + } + } + Type::Record(record_type_id) => { + let mut changed = false; + let mut fields = self.record_types[record_type_id.0].clone(); + for (_, ty) in fields.iter_mut() { + let res = self.eliminate_type_vars(*ty, max_var, use_lower); + if res != *ty { + *ty = res; + changed = true; + } + } + if changed { + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) + } else { + ty_id + } + } + Type::OneOf(id) => { + let orig_types = self.oneof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.eliminate_type_vars(*ty, max_var, use_lower)); + } + self.oneof_types.push(new_types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } + Type::AllOf(id) => { + let orig_types = self.allof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.eliminate_type_vars(*ty, max_var, use_lower)); + } + self.allof_types.push(new_types); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + } + Type::Var(var_id) => { + if var_id.0 < max_var.0 { + ty_id + } else { + let var = &self.type_vars[var_id.0]; + let repl = if use_lower { + var.lower_bound + } else { + var.upper_bound + }; + self.eliminate_type_vars(repl, max_var, use_lower) } } } @@ -1001,6 +1475,9 @@ impl<'a> Typechecker<'a> { Type::Unknown => "unknown".to_string(), Type::Forbidden => "forbidden".to_string(), Type::None => "()".to_string(), + Type::Error => "error".to_string(), + Type::Top => "top".to_string(), + Type::Bottom => "bottom".to_string(), Type::Any => "any".to_string(), Type::Number => "number".to_string(), Type::Nothing => "nothing".to_string(), @@ -1049,7 +1526,37 @@ impl<'a> Typechecker<'a> { fmt.push('>'); fmt } - Type::Error => "error".to_string(), + Type::AllOf(id) => { + let mut fmt = "allof<".to_string(); + let mut types: Vec<_> = self.allof_types[id.0] + .iter() + .map(|ty| self.type_to_string(*ty) + ", ") + .collect(); + types.sort(); + for ty in &types { + fmt += ty; + } + if !types.is_empty() { + fmt.pop(); + fmt.pop(); + } + fmt.push('>'); + fmt + } + Type::Ref(type_decl_id) => match self.compiler.type_decls[type_decl_id.0] { + TypeDecl::Param(name_node) => { + String::from_utf8_lossy(self.compiler.get_span_contents(name_node)).to_string() + } + }, + Type::Var(type_var_id) => { + let var = &self.type_vars[type_var_id.0]; + format!( + "{} <: '{} <: {}", + self.type_to_string(var.lower_bound), + type_var_id.0, + self.type_to_string(var.upper_bound) + ) + } } } @@ -1124,31 +1631,264 @@ impl<'a> Typechecker<'a> { types.insert(*ty); } } -} -/// Check whether two types can perform common numeric operations -fn check_numeric_op(lhs: Type, rhs: Type) -> Type { - match (rhs, lhs) { - (Type::Int, Type::Int) => Type::Int, - (Type::Int, Type::Float) => Type::Float, - (Type::Int, Type::Number) => Type::Number, - (Type::Float, Type::Int) => Type::Float, - (Type::Float, Type::Float) => Type::Float, - (Type::Float, Type::Number) => Type::Float, - (Type::Number, Type::Int) => Type::Number, - (Type::Number, Type::Float) => Type::Float, - (Type::Number, Type::Number) => Type::Number, - (Type::Any, _) => Type::Number, - (_, Type::Any) => Type::Number, - // TODO: Differentiate error based on whether LHS supports the op or not (see type_check.rs) - _ => Type::Unknown, + /// Use this to create any union types, to ensure that the created union type + /// is as simple as possible + fn create_oneof(&mut self, types: HashSet) -> TypeId { + let mut flattened = HashSet::new(); + for ty_id in types { + match self.types[ty_id.0] { + Type::Top | Type::Any | Type::Unknown => return ty_id, + Type::Bottom => {} + Type::OneOf(id) => { + flattened.extend(&self.oneof_types[id.0]); + } + _ => { + flattened.insert(ty_id); + } + } + } + + if flattened.is_empty() { + return BOTTOM_TYPE; + } + + let mut simple_types = HashSet::::new(); + let mut list_elems = HashSet::new(); + let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + for ty_id in flattened { + if simple_types.contains(&ty_id) { + continue; + } + + let ty = self.types[ty_id.0]; + + if ty == Type::Int && simple_types.contains(&FLOAT_TYPE) { + simple_types.remove(&FLOAT_TYPE); + simple_types.insert(NUMBER_TYPE); + continue; + } + if ty == Type::Float && simple_types.contains(&INT_TYPE) { + simple_types.remove(&INT_TYPE); + simple_types.insert(NUMBER_TYPE); + continue; + } + + match ty { + Type::Int if simple_types.contains(&FLOAT_TYPE) => { + simple_types.remove(&FLOAT_TYPE); + simple_types.insert(NUMBER_TYPE); + } + Type::Float if simple_types.contains(&INT_TYPE) => { + simple_types.remove(&INT_TYPE); + simple_types.insert(NUMBER_TYPE); + } + Type::List(elem_ty) => { + list_elems.insert(elem_ty); + } + Type::Record(rec_ty_id) => { + let new_fields = &self.record_types[rec_ty_id.0]; + for (name_node, ty) in new_fields.iter() { + let name = self.compiler.get_span_contents(*name_node); + if let Some((_, types)) = record_fields.get_mut(&name) { + types.insert(*ty); + } else { + let mut types = HashSet::new(); + types.insert(*ty); + record_fields.insert(name, (*name_node, types)); + } + } + } + _ => { + let mut add = true; + let mut remove = HashSet::new(); + for other_id in simple_types.iter() { + if self.is_subtype(ty_id, *other_id) { + add = false; + break; + } + if self.is_subtype(*other_id, ty_id) { + remove.insert(*other_id); + } + } + + if add { + simple_types.insert(ty_id); + for other in remove { + simple_types.remove(&other); + } + } + } + } + } + + if !list_elems.is_empty() { + let elem_oneof = self.create_oneof(list_elems); + simple_types.insert(self.push_type(Type::List(elem_oneof))); + } + + if !record_fields.is_empty() { + let mut fields = Vec::new(); + for (_, (node, types)) in record_fields.into_iter() { + fields.push((node, self.create_oneof(types))); + } + fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + + let rec_ty_id = RecordTypeId(self.record_types.len()); + self.record_types.push(fields); + simple_types.insert(self.push_type(Type::Record(rec_ty_id))); + } + + if simple_types.len() == 1 { + *simple_types + .iter() + .next() + .expect("should have exactly 1 element") + } else { + self.oneof_types.push(simple_types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } } -} -/// Check whether two types can perform addition -fn check_plus_op(lhs: Type, rhs: Type) -> Type { - match (rhs, lhs) { - (Type::String, Type::String) => Type::String, - _ => check_numeric_op(lhs, rhs), + /// Use this to create any intersection types, to ensure that the created intersection type + /// is as simple as possible + fn create_allof(&mut self, types: HashSet) -> TypeId { + let mut flattened = HashSet::new(); + for ty_id in types { + match self.types[ty_id.0] { + Type::AllOf(id) => { + flattened.extend(&self.allof_types[id.0]); + } + _ => { + flattened.insert(ty_id); + } + } + } + + if flattened.is_empty() { + return TOP_TYPE; + } + + let mut vars = HashMap::::new(); + let mut refs = HashMap::::new(); + let mut simple_type: Option = None; + let mut list_elems = HashSet::new(); + let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut oneof_ids = Vec::new(); + for ty_id in flattened { + let ty = self.types[ty_id.0]; + + match ty { + Type::Any => return ANY_TYPE, + Type::Unknown => return UNKNOWN_TYPE, + Type::Top => {} + Type::Bottom => return BOTTOM_TYPE, + Type::Var(var_id) => { + vars.insert(var_id, ty_id); + } + Type::Ref(decl_id) => { + refs.insert(decl_id, ty_id); + } + Type::List(elem_ty) => { + if simple_type.is_some() || !record_fields.is_empty() { + return BOTTOM_TYPE; + } + list_elems.insert(elem_ty); + } + Type::Record(rec_ty_id) => { + if simple_type.is_some() || !list_elems.is_empty() { + return BOTTOM_TYPE; + } + let new_fields = &self.record_types[rec_ty_id.0]; + for (name_node, ty) in new_fields.iter() { + let name = self.compiler.get_span_contents(*name_node); + if let Some((_, types)) = record_fields.get_mut(&name) { + types.insert(*ty); + } else { + let mut types = HashSet::new(); + types.insert(*ty); + record_fields.insert(name, (*name_node, types)); + } + } + } + Type::OneOf(id) => { + oneof_ids.push(id); + } + _ => { + if !list_elems.is_empty() && !record_fields.is_empty() { + return BOTTOM_TYPE; + } + if let Some(other_id) = &simple_type { + if self.is_subtype(ty_id, *other_id) { + simple_type = Some(ty_id); + } else if self.is_subtype(*other_id, ty_id) { + } else { + return BOTTOM_TYPE; + } + } else { + simple_type = Some(ty_id); + } + } + } + } + + let mut res = HashSet::new(); + res.extend(vars.values()); + res.extend(refs.values()); + + if let Some(ty) = simple_type { + res.insert(ty); + } + if !list_elems.is_empty() { + let elem_allof = self.create_allof(list_elems); + res.insert(self.push_type(Type::List(elem_allof))); + } + if !record_fields.is_empty() { + let mut fields = Vec::new(); + for (_, (node, types)) in record_fields.into_iter() { + fields.push((node, self.create_oneof(types))); + } + fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + + let rec_ty_id = RecordTypeId(self.record_types.len()); + self.record_types.push(fields); + res.insert(self.push_type(Type::Record(rec_ty_id))); + } + + let single_res = if res.is_empty() { + TOP_TYPE + } else if res.len() == 1 { + *res.iter().next().expect("should have exactly 1 element") + } else { + self.allof_types.push(res); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + }; + + if oneof_ids.is_empty() { + return single_res; + } + + let mut first_inter = HashSet::new(); + first_inter.insert(single_res); + let mut inters = vec![first_inter]; + + for oneof_id in oneof_ids { + let mut new_inters = vec![]; + let types = &self.oneof_types[oneof_id.0]; + for ty in types.iter() { + for mut inter in inters.clone() { + inter.insert(*ty); + new_inters.push(inter); + } + } + inters = new_inters; + } + + let inters = inters + .into_iter() + .map(|inter| self.create_allof(inter)) + .collect::>(); + + self.create_oneof(inters) } } diff --git a/tests/binary_ops_exact.nu b/tests/binary_ops_exact.nu index 1977041..56c1598 100644 --- a/tests/binary_ops_exact.nu +++ b/tests/binary_ops_exact.nu @@ -1,5 +1,5 @@ 1 == 1 -[true] ++ false +[true] ++ [false] 1 + 1 1.0 + 1.0 true and false diff --git a/tests/binary_ops_subtypes.nu b/tests/binary_ops_subtypes.nu index 618d648..11723a0 100644 --- a/tests/binary_ops_subtypes.nu +++ b/tests/binary_ops_subtypes.nu @@ -1,8 +1,8 @@ 1 == 1.0 "a" == 1.0 1 + 1.0 -[1] ++ 1.0 -[1.0 1] ++ "a" +[1] ++ [1.0] +[1.0 1] ++ ["a"] [[1] [2]] ++ [[3]] [[1] [2]] ++ [[3.0]] 1 in [1.0, 1] diff --git a/tests/calls_invalid.nu b/tests/calls_invalid.nu new file mode 100644 index 0000000..f435080 --- /dev/null +++ b/tests/calls_invalid.nu @@ -0,0 +1,3 @@ +def foo [ a: int ] {} +foo 1 2 +foo "string" diff --git a/tests/infer_complex.nu b/tests/infer_complex.nu new file mode 100644 index 0000000..555f5b9 --- /dev/null +++ b/tests/infer_complex.nu @@ -0,0 +1,7 @@ +def f [ x: record, y: record ] : nothing -> record { + $x +} +def mysterious [ x: int ] : nothing -> T {} + +let m = mysterious 0 +let a: record = f { a: 123, b: $m } { a: 12.3, b: "foo" } diff --git a/tests/infer_generics.nu b/tests/infer_generics.nu new file mode 100644 index 0000000..549de42 --- /dev/null +++ b/tests/infer_generics.nu @@ -0,0 +1,6 @@ +def f [ x: T ] : nothing -> list { + let z: T = $x + [$z] +} + +f 1 diff --git a/tests/infer_plus.nu b/tests/infer_plus.nu new file mode 100644 index 0000000..6a11f01 --- /dev/null +++ b/tests/infer_plus.nu @@ -0,0 +1,6 @@ +def mysterious [ x: int ] : nothing -> T {} + +let m = mysterious 0 + +$m + "foo" +$m + 123