From 5fce79a989504ca73d26443528a729f40aaf6c95 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 29 Apr 2025 00:34:39 -0400 Subject: [PATCH 01/16] Parse type params on defs --- src/parser.rs | 34 +++++++++ src/resolver.rs | 1 + ...nu_parser__test__node_output@calls.nu.snap | 2 +- ...r__test__node_output@calls_invalid.nu.snap | 49 +++++++++++++ ...w_nu_parser__test__node_output@def.nu.snap | 4 +- ...er__test__node_output@def_generics.nu.snap | 71 +++++++++++++++++++ ..._test__node_output@def_return_type.nu.snap | 6 +- ...r__test__node_output@invalid_types.nu.snap | 6 +- src/typechecker.rs | 1 + tests/calls_invalid.nu | 3 + tests/def_generics.nu | 4 ++ 11 files changed, 172 insertions(+), 9 deletions(-) create mode 100644 src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap create mode 100644 src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap create mode 100644 tests/calls_invalid.nu create mode 100644 tests/def_generics.nu diff --git a/src/parser.rs b/src/parser.rs index c7b018a..b6f8f89 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -137,6 +137,7 @@ pub enum AstNode { // Definitions Def { name: NodeId, + type_params: Option, params: NodeId, in_out_types: Option, block: NodeId, @@ -935,6 +936,32 @@ impl Parser { self.create_node(AstNode::Params(param_list), span_start, span_end) } + pub fn type_params(&mut self) -> NodeId { + let _span = span!(); + let span_start = self.position(); + self.less_than(); + + let mut param_list = vec![]; + + while self.has_tokens() { + if self.is_greater_than() { + break; + } + + if self.is_comma() { + self.tokens.advance(); + continue; + } + + param_list.push(self.name()); + } + + let span_end = self.position() + 1; + self.greater_than(); + + self.create_node(AstNode::Params(param_list), span_start, span_end) + } + pub fn type_args(&mut self) -> NodeId { let _span = span!(); let span_start = self.position(); @@ -1076,6 +1103,12 @@ impl Parser { _ => return self.error("expected def name"), }; + let type_params = if self.is_less_than() { + Some(self.type_params()) + } else { + None + }; + let params = self.signature_params(ParamsContext::Squares); let in_out_types = if self.is_colon() { Some(self.in_out_types()) @@ -1089,6 +1122,7 @@ impl Parser { self.create_node( AstNode::Def { name, + type_params, params, in_out_types, block, diff --git a/src/resolver.rs b/src/resolver.rs index bbdb7ca..ebe2801 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -220,6 +220,7 @@ impl<'a> Resolver<'a> { } AstNode::Def { name, + type_params, params, in_out_types: _, block, diff --git a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap index d450b35..3993779 100644 --- a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap @@ -31,7 +31,7 @@ input_file: tests/calls.nu 24: Variable (80 to 82) "$c" 25: List([NodeId(22), NodeId(23), NodeId(24)]) (70 to 82) 26: Block(BlockId(0)) (68 to 85) -27: Def { name: NodeId(8), params: NodeId(21), in_out_types: None, block: NodeId(26) } (24 to 85) +27: Def { name: NodeId(8), type_params: None, params: NodeId(21), in_out_types: None, block: NodeId(26) } (24 to 85) 28: Name (86 to 94) "existing" 29: Name (95 to 98) "foo" 30: String (100 to 104) ""ba"" diff --git a/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap new file mode 100644 index 0000000..eb5ca46 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap @@ -0,0 +1,49 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/calls_invalid.nu +--- +==== COMPILER ==== +0: Name (4 to 7) "foo" +1: Name (10 to 11) "a" +2: Name (13 to 16) "int" +3: Type { name: NodeId(2), args: None, optional: false } (13 to 16) +4: Param { name: NodeId(1), ty: Some(NodeId(3)) } (10 to 16) +5: Params([NodeId(4)]) (8 to 18) +6: Block(BlockId(0)) (19 to 21) +7: Def { name: NodeId(0), type_params: None, params: NodeId(5), in_out_types: None, block: NodeId(6) } (0 to 21) +8: Name (22 to 25) "foo" +9: Int (26 to 27) "1" +10: Int (28 to 29) "2" +11: Call { parts: [NodeId(8), NodeId(9), NodeId(10)] } (26 to 29) +12: Name (30 to 33) "foo" +13: String (34 to 42) ""string"" +14: Call { parts: [NodeId(12), NodeId(13)] } (34 to 42) +15: Block(BlockId(1)) (0 to 43) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(15) + decls: [ foo: NodeId(0) ] +1: Frame Scope, node_id: NodeId(6) + variables: [ a: NodeId(1) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: int +4: int +5: forbidden +6: () +7: () +8: unknown +9: int +10: int +11: any +12: unknown +13: string +14: any +15: any +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 7): node Def { name: NodeId(0), type_params: None, params: NodeId(5), in_out_types: None, block: NodeId(6) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def.nu.snap index 12f3561..0baa77c 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def.nu.snap @@ -39,7 +39,7 @@ input_file: tests/def.nu 32: Variable (77 to 79) "$z" 33: List([NodeId(29), NodeId(30), NodeId(31), NodeId(32)]) (64 to 80) 34: Block(BlockId(0)) (62 to 83) -35: Def { name: NodeId(0), params: NodeId(28), in_out_types: None, block: NodeId(34) } (0 to 83) +35: Def { name: NodeId(0), type_params: None, params: NodeId(28), in_out_types: None, block: NodeId(34) } (0 to 83) 36: Block(BlockId(1)) (0 to 83) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(36) @@ -88,4 +88,4 @@ input_file: tests/def.nu register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 35): node Def { name: NodeId(0), params: NodeId(28), in_out_types: None, block: NodeId(34) } not suported yet +Error (NodeId 35): node Def { name: NodeId(0), type_params: None, params: NodeId(28), in_out_types: None, block: NodeId(34) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap new file mode 100644 index 0000000..2887fb7 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap @@ -0,0 +1,71 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/def_generics.nu +--- +==== COMPILER ==== +0: Name (4 to 5) "f" +1: Name (6 to 7) "T" +2: Params([NodeId(1)]) (5 to 8) +3: Name (11 to 12) "x" +4: Name (14 to 15) "T" +5: Type { name: NodeId(4), args: None, optional: false } (14 to 15) +6: Param { name: NodeId(3), ty: Some(NodeId(5)) } (11 to 15) +7: Params([NodeId(6)]) (9 to 17) +8: Name (20 to 27) "nothing" +9: Type { name: NodeId(8), args: None, optional: false } (20 to 27) +10: Name (31 to 35) "list" +11: Name (36 to 37) "T" +12: Type { name: NodeId(11), args: None, optional: false } (36 to 37) +13: TypeArgs([NodeId(12)]) (35 to 38) +14: Type { name: NodeId(10), args: Some(NodeId(13)), optional: false } (31 to 35) +15: InOutType(NodeId(9), NodeId(14)) (20 to 39) +16: InOutTypes([NodeId(15)]) (20 to 39) +17: Variable (47 to 48) "z" +18: Name (50 to 51) "T" +19: Type { name: NodeId(18), args: None, optional: false } (50 to 51) +20: Variable (54 to 56) "$x" +21: Let { variable_name: NodeId(17), ty: Some(NodeId(19)), initializer: NodeId(20), is_mutable: false } (43 to 56) +22: Variable (60 to 62) "$z" +23: List([NodeId(22)]) (59 to 62) +24: Block(BlockId(0)) (39 to 65) +25: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } (0 to 65) +26: Block(BlockId(1)) (0 to 66) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(26) + decls: [ f: NodeId(0) ] +1: Frame Scope, node_id: NodeId(24) + variables: [ x: NodeId(3), z: NodeId(17) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: unknown +6: unknown +7: forbidden +8: unknown +9: unknown +10: unknown +11: unknown +12: unknown +13: forbidden +14: unknown +15: unknown +16: unknown +17: unknown +18: unknown +19: unknown +20: unknown +21: () +22: unknown +23: list +24: list +25: () +26: () +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 25): node Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap index a73afae..afe68ad 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap @@ -17,7 +17,7 @@ input_file: tests/def_return_type.nu 10: InOutTypes([NodeId(9)]) (14 to 35) 11: List([]) (37 to 38) 12: Block(BlockId(0)) (35 to 41) -13: Def { name: NodeId(0), params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } (0 to 41) +13: Def { name: NodeId(0), type_params: None, params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } (0 to 41) 14: Name (46 to 49) "bar" 15: Params([]) (50 to 53) 16: Name (58 to 64) "string" @@ -39,7 +39,7 @@ input_file: tests/def_return_type.nu 32: InOutTypes([NodeId(23), NodeId(31)]) (56 to 101) 33: List([]) (103 to 104) 34: Block(BlockId(1)) (101 to 107) -35: Def { name: NodeId(14), params: NodeId(15), in_out_types: Some(NodeId(32)), block: NodeId(34) } (42 to 107) +35: Def { name: NodeId(14), type_params: None, params: NodeId(15), in_out_types: Some(NodeId(32)), block: NodeId(34) } (42 to 107) 36: Block(BlockId(2)) (0 to 108) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(36) @@ -88,4 +88,4 @@ input_file: tests/def_return_type.nu register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 13): node Def { name: NodeId(0), params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } not suported yet +Error (NodeId 13): node Def { name: NodeId(0), type_params: None, params: NodeId(1), in_out_types: Some(NodeId(10)), block: NodeId(12) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap b/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap index 89de622..2791175 100644 --- a/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@invalid_types.nu.snap @@ -17,7 +17,7 @@ input_file: tests/invalid_types.nu 10: Params([NodeId(9)]) (8 to 30) 11: Variable (33 to 35) "$x" 12: Block(BlockId(0)) (31 to 37) -13: Def { name: NodeId(0), params: NodeId(10), in_out_types: None, block: NodeId(12) } (0 to 37) +13: Def { name: NodeId(0), type_params: None, params: NodeId(10), in_out_types: None, block: NodeId(12) } (0 to 37) 14: Name (42 to 45) "bar" 15: Name (47 to 48) "y" 16: Name (50 to 54) "list" @@ -27,7 +27,7 @@ input_file: tests/invalid_types.nu 20: Params([NodeId(19)]) (46 to 57) 21: Variable (60 to 62) "$y" 22: Block(BlockId(1)) (58 to 64) -23: Def { name: NodeId(14), params: NodeId(20), in_out_types: None, block: NodeId(22) } (38 to 64) +23: Def { name: NodeId(14), type_params: None, params: NodeId(20), in_out_types: None, block: NodeId(22) } (38 to 64) 24: Block(BlockId(2)) (0 to 65) ==== SCOPE ==== 0: Frame Scope, node_id: NodeId(24) @@ -69,4 +69,4 @@ Error (NodeId 17): list must have one type argument register_count: 0 file_count: 0 ==== IR ERRORS ==== -Error (NodeId 13): node Def { name: NodeId(0), params: NodeId(10), in_out_types: None, block: NodeId(12) } not suported yet +Error (NodeId 13): node Def { name: NodeId(0), type_params: None, params: NodeId(10), in_out_types: None, block: NodeId(12) } not suported yet diff --git a/src/typechecker.rs b/src/typechecker.rs index 71077e8..4996d9e 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -403,6 +403,7 @@ impl<'a> Typechecker<'a> { } AstNode::Def { name, + type_params, params, in_out_types, block, diff --git a/tests/calls_invalid.nu b/tests/calls_invalid.nu new file mode 100644 index 0000000..f435080 --- /dev/null +++ b/tests/calls_invalid.nu @@ -0,0 +1,3 @@ +def foo [ a: int ] {} +foo 1 2 +foo "string" diff --git a/tests/def_generics.nu b/tests/def_generics.nu new file mode 100644 index 0000000..1e9b850 --- /dev/null +++ b/tests/def_generics.nu @@ -0,0 +1,4 @@ +def f [ x: T ] : nothing -> list { + let z: T = $x + [$z] +} From 8b8816b6d879ac150a0efc66f13458c6ba125267 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:04:19 -0400 Subject: [PATCH 02/16] Resolve references to type params --- src/resolver.rs | 149 +++++++++++++++++- ...er__test__node_output@def_generics.nu.snap | 1 + 2 files changed, 143 insertions(+), 7 deletions(-) diff --git a/src/resolver.rs b/src/resolver.rs index ebe2801..30f5f69 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -23,6 +23,7 @@ pub enum FrameType { pub struct Frame { pub frame_type: FrameType, pub variables: HashMap, NodeId>, + pub type_decls: HashMap, NodeId>, pub decls: HashMap, NodeId>, /// Node that defined the scope frame (e.g., a block or overlay) pub node_id: NodeId, @@ -33,6 +34,7 @@ impl Frame { Frame { frame_type: scope_type, variables: HashMap::new(), + type_decls: HashMap::new(), decls: HashMap::new(), node_id, } @@ -47,6 +49,16 @@ pub struct Variable { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct VarId(pub usize); +#[derive(Debug, Clone)] +pub enum TypeDecl { + /// A type parameter. Holds the parameter name node + Param(NodeId), + // In the future, we may have type aliases, user-defined classes, etc. +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct TypeDeclId(pub usize); + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct DeclId(pub usize); @@ -56,6 +68,8 @@ pub struct NameBindings { pub scope_stack: Vec, pub variables: Vec, pub var_resolution: HashMap, + pub type_decls: Vec, + pub type_resolution: HashMap, pub decls: Vec>, pub decl_resolution: HashMap, pub errors: Vec, @@ -68,6 +82,8 @@ impl NameBindings { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], decl_resolution: HashMap::new(), errors: vec![], @@ -93,6 +109,10 @@ pub struct Resolver<'a> { pub variables: Vec, /// Mapping of variable's name node -> Variable pub var_resolution: HashMap, + /// Type declarations, indexed by TypeDeclId + pub type_decls: Vec, + /// Mapping of type decl's name node -> TypeDecl + pub type_resolution: HashMap, /// Declarations (commands, aliases, etc.), indexed by DeclId pub decls: Vec>, /// Mapping of decl's name node -> Command @@ -109,6 +129,8 @@ impl<'a> Resolver<'a> { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], decl_resolution: HashMap::new(), errors: vec![], @@ -121,6 +143,8 @@ impl<'a> Resolver<'a> { scope_stack: self.scope_stack, variables: self.variables, var_resolution: self.var_resolution, + type_decls: self.type_decls, + type_resolution: self.type_resolution, decls: self.decls, decl_resolution: self.decl_resolution, errors: self.errors, @@ -149,13 +173,19 @@ impl<'a> Resolver<'a> { .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) .collect(); + let mut types: Vec = scope + .type_decls + .iter() + .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) + .collect(); + let mut decls: Vec = scope .decls .iter() .map(|(name, id)| format!("{0}: {id:?}", String::from_utf8_lossy(name))) .collect(); - if vars.is_empty() && decls.is_empty() { + if vars.is_empty() && types.is_empty() && decls.is_empty() { result.push_str(" (empty)\n"); continue; } @@ -168,6 +198,12 @@ impl<'a> Resolver<'a> { result.push_str(&line_var); } + if !types.is_empty() { + types.sort(); + let line_type = format!(" type decls: [ {0} ]\n", types.join(", ")); + result.push_str(&line_type); + } + if !decls.is_empty() { decls.sort(); let line_decl = format!(" decls: [ {0} ]\n", decls.join(", ")); @@ -222,7 +258,7 @@ impl<'a> Resolver<'a> { name, type_params, params, - in_out_types: _, + in_out_types, block, } => { // define the command before the block to enable recursive calls @@ -230,7 +266,18 @@ impl<'a> Resolver<'a> { // making sure the def parameters and body end up in the same scope frame self.enter_scope(block); + if let Some(type_params) = type_params { + let AstNode::Params(type_params) = self.compiler.get_node(type_params) else { + panic!("Internal error: expected type params") + }; + for type_param_id in type_params { + self.define_type_decl(*type_param_id, TypeDecl::Param(*type_param_id)); + } + } self.resolve_node(params); + if let Some(in_out_types) = in_out_types { + self.resolve_node(in_out_types); + } let def_scope = self.exit_scope(); let AstNode::Block(block_id) = self.compiler.ast_nodes[block.0] else { @@ -247,19 +294,24 @@ impl<'a> Resolver<'a> { } AstNode::Params(ref params) => { for param in params { - if let AstNode::Param { name, .. } = self.compiler.ast_nodes[param.0] { - self.define_variable(name, false); - } else { + let AstNode::Param { name, ty } = self.compiler.ast_nodes[param.0] else { panic!("param is not a param"); + }; + self.define_variable(name, false); + if let Some(ty) = ty { + self.resolve_node(ty); } } } AstNode::Let { variable_name, - ty: _, + ty, initializer, is_mutable, } => { + if let Some(ty) = ty { + self.resolve_node(ty); + } self.resolve_node(initializer); self.define_variable(variable_name, is_mutable) } @@ -339,8 +391,37 @@ impl<'a> Resolver<'a> { } } AstNode::Statement(node) => self.resolve_node(node), + AstNode::Type { name, args, .. } => { + self.resolve_type(name); + if let Some(args) = args { + self.resolve_node(args); + } + } + AstNode::RecordType { fields, .. } => { + let AstNode::Params(fields) = self.compiler.get_node(fields) else { + panic!("Internal error: expected params for record field types"); + }; + for field in fields { + if let AstNode::Param { ty: Some(ty), .. } = self.compiler.get_node(*field) { + self.resolve_node(*ty); + } + } + } + AstNode::TypeArgs(ref args) => { + for arg in args { + self.resolve_node(*arg); + } + } + AstNode::InOutTypes(ref in_out_types) => { + for in_out_ty in in_out_types { + self.resolve_node(*in_out_ty); + } + } + AstNode::InOutType(in_ty, out_ty) => { + self.resolve_node(in_ty); + self.resolve_node(out_ty); + } AstNode::Param { .. } => (/* seems unused for now */), - AstNode::Type { .. } => ( /* probably doesn't make sense to resolve? */ ), AstNode::NamedValue { .. } => (/* seems unused for now */), // All remaining matches do not contain NodeId => there is nothing to resolve _ => (), @@ -366,6 +447,31 @@ impl<'a> Resolver<'a> { } } + pub fn resolve_type(&mut self, unbound_node_id: NodeId) { + let type_name = self.compiler.get_span_contents(unbound_node_id); + + match type_name { + b"any" | b"list" | b"bool" | b"closure" | b"float" | b"int" | b"nothing" + | b"number" | b"string" => return, + _ => {} + } + + if let Some(node_id) = self.find_type(type_name) { + let type_id = self + .type_resolution + .get(&node_id) + .expect("internal error: missing resolved type"); + + self.type_resolution.insert(unbound_node_id, *type_id); + } else { + self.errors.push(SourceError { + message: format!("type `{}` not found", String::from_utf8_lossy(type_name)), + node_id: unbound_node_id, + severity: Severity::Error, + }) + } + } + pub fn resolve_call(&mut self, unbound_node_id: NodeId, parts: &[NodeId]) { // Find out the potentially longest command name let max_name_parts = parts @@ -477,6 +583,25 @@ impl<'a> Resolver<'a> { self.var_resolution.insert(var_name_id, var_id); } + pub fn define_type_decl(&mut self, type_name_id: NodeId, type_decl: TypeDecl) { + let type_name = self.compiler.get_span_contents(type_name_id).to_vec(); + + let current_scope_id = self + .scope_stack + .last() + .expect("internal error: missing scope frame id"); + + self.scope[current_scope_id.0] + .type_decls + .insert(type_name, type_name_id); + + self.type_decls.push(type_decl); + let type_id = TypeDeclId(self.type_decls.len() - 1); + + // let the definition of a type also count as its use + self.type_resolution.insert(type_name_id, type_id); + } + pub fn define_decl(&mut self, decl_name_id: NodeId) { // TODO: Deduplicate code with define_variable() let decl_name = self.compiler.get_span_contents(decl_name_id); @@ -509,6 +634,16 @@ impl<'a> Resolver<'a> { None } + pub fn find_type(&self, type_name: &[u8]) -> Option { + for scope_id in self.scope_stack.iter().rev() { + if let Some(id) = self.scope[scope_id.0].type_decls.get(type_name) { + return Some(*id); + } + } + + None + } + pub fn find_decl(&self, var_name: &[u8]) -> Option { // TODO: Deduplicate code with find_variable() for scope_id in self.scope_stack.iter().rev() { diff --git a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap index 2887fb7..ea6e0a4 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap @@ -36,6 +36,7 @@ input_file: tests/def_generics.nu decls: [ f: NodeId(0) ] 1: Frame Scope, node_id: NodeId(24) variables: [ x: NodeId(3), z: NodeId(17) ] + type decls: [ T: NodeId(1) ] ==== TYPES ==== 0: unknown 1: unknown From 5c628eea3c46dbc56054644ec36c007142a11757 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 29 Apr 2025 22:03:39 -0400 Subject: [PATCH 03/16] Basic typechecking (no generics) --- src/compiler.rs | 19 +- src/resolver.rs | 15 +- ...nu_parser__test__node_output@calls.nu.snap | 2 +- ...r__test__node_output@calls_invalid.nu.snap | 9 +- ...st__node_output@for_break_continue.nu.snap | 16 +- ...rser__test__node_output@invalid_if.nu.snap | 7 +- ...er__test__node_output@let_mismatch.nu.snap | 16 +- ..._nu_parser__test__node_output@loop.nu.snap | 31 +- ...nu_parser__test__node_output@table.nu.snap | 7 +- ...u_parser__test__node_output@table2.nu.snap | 7 +- src/typechecker.rs | 688 ++++++++++++------ 11 files changed, 519 insertions(+), 298 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 98b6c3e..97e20c5 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,7 @@ use crate::errors::SourceError; use crate::parser::{AstNode, Block, NodeId}; use crate::protocol::Command; -use crate::resolver::{DeclId, Frame, NameBindings, ScopeId, VarId, Variable}; +use crate::resolver::{DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable}; use crate::typechecker::{TypeId, Types}; use std::collections::HashMap; @@ -57,8 +57,14 @@ pub struct Compiler { pub variables: Vec, /// Mapping of variable's name node -> Variable pub var_resolution: HashMap, - /// Declarations (commands, aliases, externs), indexed by VarId + /// Type declarations, indexed by TypeDeclId + pub type_decls: Vec, + /// Mapping of type decl's name node -> TypeDecl + pub type_resolution: HashMap, + /// Declarations (commands, aliases, externs), indexed by DeclId pub decls: Vec>, + /// Declaration NodeIds, indexed by DeclId + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command pub decl_resolution: HashMap, @@ -70,7 +76,6 @@ pub struct Compiler { // Use/def // pub call_resolution: HashMap, - // pub type_resolution: HashMap, pub errors: Vec, } @@ -94,7 +99,10 @@ impl Compiler { scope_stack: vec![], variables: vec![], var_resolution: HashMap::new(), + type_decls: vec![], + type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), // variables: vec![], @@ -102,8 +110,6 @@ impl Compiler { // types: vec![], // call_resolution: HashMap::new(), - // var_resolution: HashMap::new(), - // type_resolution: HashMap::new(), errors: vec![], } } @@ -155,7 +161,10 @@ impl Compiler { self.scope_stack.extend(name_bindings.scope_stack); self.variables.extend(name_bindings.variables); self.var_resolution.extend(name_bindings.var_resolution); + self.type_decls.extend(name_bindings.type_decls); + self.type_resolution.extend(name_bindings.type_resolution); self.decls.extend(name_bindings.decls); + self.decl_nodes.extend(name_bindings.decl_nodes); self.decl_resolution.extend(name_bindings.decl_resolution); self.errors.extend(name_bindings.errors); } diff --git a/src/resolver.rs b/src/resolver.rs index 30f5f69..958366f 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -71,6 +71,7 @@ pub struct NameBindings { pub type_decls: Vec, pub type_resolution: HashMap, pub decls: Vec>, + pub decl_nodes: Vec, pub decl_resolution: HashMap, pub errors: Vec, } @@ -85,6 +86,7 @@ impl NameBindings { type_decls: vec![], type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), errors: vec![], } @@ -115,6 +117,8 @@ pub struct Resolver<'a> { pub type_resolution: HashMap, /// Declarations (commands, aliases, etc.), indexed by DeclId pub decls: Vec>, + /// Declaration nodes, indexed by DeclId + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command pub decl_resolution: HashMap, /// Errors encountered during name binding @@ -132,6 +136,7 @@ impl<'a> Resolver<'a> { type_decls: vec![], type_resolution: HashMap::new(), decls: vec![], + decl_nodes: vec![], decl_resolution: HashMap::new(), errors: vec![], } @@ -146,6 +151,7 @@ impl<'a> Resolver<'a> { type_decls: self.type_decls, type_resolution: self.type_resolution, decls: self.decls, + decl_nodes: self.decl_nodes, decl_resolution: self.decl_resolution, errors: self.errors, } @@ -262,7 +268,7 @@ impl<'a> Resolver<'a> { block, } => { // define the command before the block to enable recursive calls - self.define_decl(name); + self.define_decl(name, node_id); // making sure the def parameters and body end up in the same scope frame self.enter_scope(block); @@ -290,7 +296,7 @@ impl<'a> Resolver<'a> { new_name, old_name: _, } => { - self.define_decl(new_name); + self.define_decl(new_name, node_id); } AstNode::Params(ref params) => { for param in params { @@ -602,7 +608,7 @@ impl<'a> Resolver<'a> { self.type_resolution.insert(type_name_id, type_id); } - pub fn define_decl(&mut self, decl_name_id: NodeId) { + pub fn define_decl(&mut self, decl_name_id: NodeId, decl_node_id: NodeId) { // TODO: Deduplicate code with define_variable() let decl_name = self.compiler.get_span_contents(decl_name_id); let decl_name = trim_decl_name(decl_name).to_vec(); @@ -618,8 +624,9 @@ impl<'a> Resolver<'a> { .insert(decl_name, decl_name_id); self.decls.push(Box::new(decl)); - let decl_id = DeclId(self.decls.len() - 1); + self.decl_nodes.push(decl_node_id); + let decl_id = DeclId(self.decls.len() - 1); // let the definition of a decl also count as its use self.decl_resolution.insert(decl_name_id, decl_id); } diff --git a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap index 3993779..60d535f 100644 --- a/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@calls.nu.snap @@ -85,7 +85,7 @@ input_file: tests/calls.nu 32: string 33: string 34: int -35: any +35: list 36: unknown 37: stream 38: stream diff --git a/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap index eb5ca46..8bca912 100644 --- a/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@calls_invalid.nu.snap @@ -37,11 +37,14 @@ input_file: tests/calls_invalid.nu 8: unknown 9: int 10: int -11: any +11: () 12: unknown 13: string -14: any -15: any +14: () +15: () +==== TYPE ERRORS ==== +Error (NodeId 11): Expected 1 argument(s), got 2 +Error (NodeId 13): Expected int, got string ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap b/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap index 86ad7e5..81769c8 100644 --- a/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/for_break_continue.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Variable (4 to 5) "x" @@ -57,16 +56,16 @@ snapshot_kind: text 9: forbidden 10: int 11: bool -12: unknown -13: unknown -14: oneof<(), unknown> +12: () +13: () +14: () 15: int 16: forbidden 17: int 18: bool -19: unknown -20: unknown -21: oneof<(), unknown> +19: () +20: () +21: () 22: int 23: forbidden 24: int @@ -77,9 +76,6 @@ snapshot_kind: text 29: () 30: () 31: () -==== TYPE ERRORS ==== -Error (NodeId 12): unsupported ast node 'Break' in typechecker -Error (NodeId 19): unsupported ast node 'Continue' in typechecker ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap b/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap index e96d1b5..0adf8c9 100644 --- a/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/invalid_if.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (3 to 4) "1" @@ -22,10 +21,10 @@ snapshot_kind: text 2: int 3: int 4: int -5: error -6: error +5: int +6: int ==== TYPE ERRORS ==== -Error (NodeId 0): The condition for if branch is not a boolean +Error (NodeId 0): Expected bool, got int ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap index 429b553..b2dab3d 100644 --- a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap @@ -49,22 +49,22 @@ input_file: tests/let_mismatch.nu 0: Frame Scope, node_id: NodeId(40) variables: [ v: NodeId(28), w: NodeId(15), x: NodeId(0), y: NodeId(5), z: NodeId(10) ] ==== TYPES ==== -0: number +0: int 1: unknown 2: number 3: int 4: () -5: any +5: string 6: unknown 7: any 8: string 9: () -10: string +10: int 11: unknown 12: string 13: int 14: () -15: list> +15: list> 16: unknown 17: unknown 18: unknown @@ -77,7 +77,7 @@ input_file: tests/let_mismatch.nu 25: list 26: list> 27: () -28: record +28: record 29: unknown 30: unknown 31: unknown @@ -91,9 +91,9 @@ input_file: tests/let_mismatch.nu 39: () 40: () ==== TYPE ERRORS ==== -Error (NodeId 13): initializer does not match declared type -Error (NodeId 26): initializer does not match declared type -Error (NodeId 38): initializer does not match declared type +Error (NodeId 13): Expected string, got int +Error (NodeId 26): Expected list>, got list> +Error (NodeId 38): Expected record, got record ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap b/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap index 8b4fb48..0e8bd92 100644 --- a/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@loop.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/loop.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Variable (4 to 5) "x" @@ -31,22 +30,20 @@ snapshot_kind: text 0: int 1: int 2: () -3: unknown -4: unknown -5: unknown -6: unknown -7: unknown -8: unknown -9: unknown -10: unknown -11: unknown -12: unknown -13: unknown -14: unknown -15: unknown -16: unknown -==== TYPE ERRORS ==== -Error (NodeId 15): unsupported ast node 'Loop { block: NodeId(14) }' in typechecker +3: int +4: forbidden +5: int +6: bool +7: () +8: () +9: () +10: int +11: forbidden +12: int +13: () +14: () +15: () +16: () ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@table.nu.snap b/src/snapshots/new_nu_parser__test__node_output@table.nu.snap index 9cc85a8..02ca62a 100644 --- a/src/snapshots/new_nu_parser__test__node_output@table.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@table.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/table.nu -snapshot_kind: text --- ==== COMPILER ==== 0: String (7 to 10) ""a"" @@ -28,10 +27,10 @@ snapshot_kind: text 6: unknown 7: unknown 8: unknown -9: unknown -10: unknown +9: error +10: error ==== TYPE ERRORS ==== -Error (NodeId 9): unsupported ast node 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' in typechecker +Error (NodeId 9): Expected an expression to typecheck, got 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap b/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap index 46ed72c..ad9d253 100644 --- a/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@table2.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/table2.nu -snapshot_kind: text --- ==== COMPILER ==== 0: String (7 to 8) "a" @@ -28,10 +27,10 @@ snapshot_kind: text 6: unknown 7: unknown 8: unknown -9: unknown -10: unknown +9: error +10: error ==== TYPE ERRORS ==== -Error (NodeId 9): unsupported ast node 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' in typechecker +Error (NodeId 9): Expected an expression to typecheck, got 'Table { header: NodeId(2), rows: [NodeId(5), NodeId(8)] }' ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/typechecker.rs b/src/typechecker.rs index 4996d9e..6d5235b 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,6 +1,7 @@ use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; use crate::parser::{AstNode, NodeId}; +use crate::resolver::{TypeDecl, TypeDeclId}; use std::cmp::Ordering; use std::collections::HashSet; @@ -14,6 +15,15 @@ pub struct InOutType { pub out_type: TypeId, } +/// A type variable used for type inference +pub struct TypeVar { + lower_bound: TypeId, + upper_bound: TypeId, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TypeVarId(pub usize); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordTypeId(pub usize); @@ -27,9 +37,12 @@ pub enum Type { /// Some nodes shouldn't be directly evaluated (like operators). These will have a "forbidden" /// to differentiate them from the "unknown" type. Forbidden, + Error, /// None type means that a node has no type. For example, statements like let x = ... do not /// output anything and thus don't have any type. None, + Top, + Bottom, Any, Number, Nothing, @@ -43,7 +56,8 @@ pub enum Type { Stream(TypeId), Record(RecordTypeId), OneOf(OneOfId), - Error, + Ref(TypeDeclId), + Var(TypeVarId), } pub struct Types { @@ -73,6 +87,8 @@ pub const CLOSURE_TYPE: TypeId = TypeId(11); pub const LIST_ANY_TYPE: TypeId = TypeId(12); pub const BYTE_STREAM_TYPE: TypeId = TypeId(13); pub const ERROR_TYPE: TypeId = TypeId(14); +pub const TOP_TYPE: TypeId = TypeId(15); +pub const BOTTOM_TYPE: TypeId = TypeId(16); pub struct Typechecker<'a> { /// Immutable reference to a compiler after the name binding pass @@ -88,6 +104,8 @@ pub struct Typechecker<'a> { pub record_types: Vec>, /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId pub oneof_types: Vec>, + /// Type variables, indexed by TypeVarId + pub type_vars: Vec, /// Type of each Variable in compiler.variables, indexed by VarId pub variable_types: Vec, /// Input/output type pairs of each declaration in compiler.decls, indexed by DeclId @@ -117,10 +135,13 @@ impl<'a> Typechecker<'a> { Type::List(ANY_TYPE), Type::Stream(BINARY_TYPE), Type::Error, + Type::Top, + Type::Bottom, ], node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], record_types: Vec::new(), oneof_types: Vec::new(), + type_vars: Vec::new(), variable_types: vec![UNKNOWN_TYPE; compiler.variables.len()], decl_types: vec![ vec![InOutType { @@ -194,21 +215,6 @@ impl<'a> Typechecker<'a> { fn typecheck_node(&mut self, node_id: NodeId) { match self.compiler.ast_nodes[node_id.0] { - AstNode::Null => { - self.set_node_type_id(node_id, NOTHING_TYPE); - } - AstNode::Int => { - self.set_node_type_id(node_id, INT_TYPE); - } - AstNode::Float => { - self.set_node_type_id(node_id, FLOAT_TYPE); - } - AstNode::True | AstNode::False => { - self.set_node_type_id(node_id, BOOL_TYPE); - } - AstNode::String => { - self.set_node_type_id(node_id, STRING_TYPE); - } AstNode::Params(ref params) => { for param in params { self.typecheck_node(*param); @@ -218,74 +224,159 @@ impl<'a> Typechecker<'a> { } AstNode::Param { name, ty } => { if let Some(ty) = ty { - self.typecheck_node(ty); + let ty_id = self.typecheck_type(ty); let var_id = self .compiler .var_resolution .get(&name) .expect("missing resolved variable"); - self.variable_types[var_id.0] = self.type_id_of(ty); - self.set_node_type_id(node_id, self.type_id_of(ty)); + self.variable_types[var_id.0] = ty_id; + self.set_node_type_id(node_id, ty_id); } else { self.set_node_type_id(node_id, ANY_TYPE); } } - AstNode::Type { - name, - args, - optional, - } => { - let ty_id = self.typecheck_type(name, args, optional); - self.set_node_type_id(node_id, ty_id); - } - AstNode::RecordType { - fields, - optional: _, // TODO handle optional record types - } => { - let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { - panic!("internal error: record fields aren't Params"); - }; - let mut fields = field_nodes - .iter() - .map(|field| { - let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { - panic!("internal error: record field isn't Param"); - }; - let ty_id = match ty { - Some(ty) => { - self.typecheck_node(*ty); - self.type_id_of(*ty) - } - None => ANY_TYPE, - }; - (*name, ty_id) - }) - .collect::>(); - // Store fields sorted by name - fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); - - self.record_types.push(fields); - let ty_id = self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))); - self.set_node_type_id(node_id, ty_id); - } AstNode::TypeArgs(ref args) => { for arg in args { - self.typecheck_node(*arg); + self.typecheck_type(*arg); } // Type argument lists are not supposed to be evaluated self.set_node_type_id(node_id, FORBIDDEN_TYPE); } + AstNode::Block(_) => { + self.typecheck_block(node_id, TOP_TYPE); + } + _ => self.error( + format!( + "unsupported/unexpected ast node '{:?}' in typechecker", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ), + } + } + + fn typecheck_block(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { + let AstNode::Block(block_id) = self.compiler.ast_nodes[node_id.0] else { + panic!( + "Expected block to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ); + }; + let block = &self.compiler.blocks[block_id.0]; + + for (i, inner_node_id) in block.nodes.iter().enumerate() { + if i == block.nodes.len() - 1 && self.is_expr(*inner_node_id) { + self.typecheck_expr(*inner_node_id, expected); + } else { + self.typecheck_stmt(*inner_node_id); + } + } + + // Block type is the type of the last statement, since blocks + // by themselves aren't supposed to be typed + let block_type = block + .nodes + .last() + .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); + self.set_node_type_id(node_id, block_type); + block_type + } + + fn typecheck_stmt(&mut self, node_id: NodeId) { + match self.compiler.ast_nodes[node_id.0] { + AstNode::Let { + variable_name, + ty, + initializer, + is_mutable: _, + } => self.typecheck_let(variable_name, ty, initializer, node_id), + AstNode::Def { + name, + type_params, + params, + in_out_types, + block, + } => self.typecheck_def(name, type_params, params, in_out_types, block, node_id), + AstNode::Alias { new_name, old_name } => { + self.typecheck_alias(new_name, old_name, node_id) + } + AstNode::For { + variable, + range, + block, + } => { + // We don't need to typecheck variable after this + self.typecheck_expr(range, TOP_TYPE); + + let var_id = self + .compiler + .var_resolution + .get(&variable) + .expect("missing resolved variable"); + if let Type::List(type_id) = self.type_of(range) { + self.variable_types[var_id.0] = type_id; + self.set_node_type_id(variable, type_id); + } else { + self.variable_types[var_id.0] = ANY_TYPE; + self.set_node_type_id(variable, ERROR_TYPE); + self.error("For loop range is not a list", range); + } + + self.typecheck_node(block); + if self.type_id_of(block) != NONE_TYPE { + self.error("Blocks in looping constructs cannot return values", block); + } + + if self.type_id_of(node_id) != ERROR_TYPE { + self.set_node_type_id(node_id, NONE_TYPE); + } + } + AstNode::While { condition, block } => { + self.typecheck_expr(condition, BOOL_TYPE); + self.typecheck_node(block); + self.set_node_type_id(node_id, NONE_TYPE); + } + AstNode::Loop { block } => { + self.typecheck_node(block); + self.set_node_type_id(node_id, NONE_TYPE); + } + AstNode::Break | AstNode::Continue => { + // TODO make sure we're in a loop + self.set_node_type_id(node_id, NONE_TYPE); + } + _ if self.is_expr(node_id) => { + self.typecheck_expr(node_id, TOP_TYPE); + } + _ => self.error( + format!( + "Expected statement to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ), + } + } + + fn typecheck_expr(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { + let ty_id = match self.compiler.ast_nodes[node_id.0] { + AstNode::Null => NOTHING_TYPE, + AstNode::Int => INT_TYPE, + AstNode::Float => FLOAT_TYPE, + AstNode::True | AstNode::False => BOOL_TYPE, + AstNode::String => STRING_TYPE, AstNode::List(ref items) => { + // TODO inspect the expected type and infer a union type instead if let Some(first_id) = items.first() { - self.typecheck_node(*first_id); + self.typecheck_expr(*first_id, TOP_TYPE); let first_type = self.type_of(*first_id); let mut all_numbers = self.is_type_compatible(first_type, Type::Number); let mut all_same = true; for item_id in items.iter().skip(1) { - self.typecheck_node(*item_id); + self.typecheck_expr(*item_id, TOP_TYPE); let item_type = self.type_of(*item_id); if all_numbers && !self.is_type_compatible(item_type, Type::Number) { @@ -298,45 +389,26 @@ impl<'a> Typechecker<'a> { } if all_same { - self.set_node_type(node_id, Type::List(self.type_id_of(*first_id))); + self.push_type(Type::List(self.type_id_of(*first_id))) } else if all_numbers { - self.set_node_type(node_id, Type::List(NUMBER_TYPE)); + self.push_type(Type::List(NUMBER_TYPE)) } else { - self.set_node_type_id(node_id, LIST_ANY_TYPE); + LIST_ANY_TYPE } } else { - self.set_node_type_id(node_id, LIST_ANY_TYPE); + LIST_ANY_TYPE } } AstNode::Record { ref pairs } => { + // TODO take expected type into account let mut field_types = pairs .iter() - .map(|(name, value)| { - self.typecheck_node(*value); - (*name, self.type_id_of(*value)) - }) + .map(|(name, value)| (*name, self.typecheck_expr(*value, TOP_TYPE))) .collect::>(); field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); self.record_types.push(field_types); - let ty_id = self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))); - self.set_node_type_id(node_id, ty_id); - } - AstNode::Block(block_id) => { - let block = &self.compiler.blocks[block_id.0]; - - for inner_node_id in &block.nodes { - self.typecheck_node(*inner_node_id); - } - - // Block type is the type of the last statement, since blocks - // by themselves aren't supposed to be typed - let block_type = block - .nodes - .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - - self.set_node_type_id(node_id, block_type); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } AstNode::Closure { params, block } => { // TODO: input/output types @@ -345,15 +417,9 @@ impl<'a> Typechecker<'a> { } self.typecheck_node(block); - self.set_node_type_id(node_id, CLOSURE_TYPE); + CLOSURE_TYPE } - AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs, node_id), - AstNode::Let { - variable_name, - ty, - initializer, - is_mutable: _, - } => self.typecheck_let(variable_name, ty, initializer, node_id), + AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs), AstNode::Variable => { let var_id = self .compiler @@ -361,155 +427,121 @@ impl<'a> Typechecker<'a> { .get(&node_id) .expect("missing resolved variable"); - self.set_node_type_id(node_id, self.variable_types[var_id.0]); + self.variable_types[var_id.0] } AstNode::If { condition, then_block, else_block, } => { - self.typecheck_node(condition); - self.typecheck_node(then_block); + self.typecheck_expr(condition, BOOL_TYPE); - let then_type_id = self.type_id_of(then_block); - let mut else_type = None; + let then_type_id = self.typecheck_block(then_block, expected); if let Some(else_blk) = else_block { - self.typecheck_node(else_blk); - else_type = Some(self.type_of(else_blk)); - } - - let mut types = HashSet::new(); - self.add_resolved_types(&mut types, &then_type_id); - - if let Some(Type::OneOf(id)) = else_type { - types.extend(self.oneof_types[id.0].iter()); - } else if else_type.is_none() { - types.insert(NONE_TYPE); - } else { - types.insert(self.type_id_of(else_block.expect("Already checked"))); - } - - // the condition should always evaluate to a boolean - if self.type_of(condition) != Type::Bool { - self.error("The condition for if branch is not a boolean", condition); - self.set_node_type_id(node_id, ERROR_TYPE); - } else if types.len() > 1 { - self.oneof_types.push(types); - self.set_node_type(node_id, Type::OneOf(OneOfId(self.oneof_types.len() - 1))); + let else_type_id = + if let AstNode::Block(_) = self.compiler.ast_nodes[else_blk.0] { + self.typecheck_block(else_blk, expected) + } else { + self.typecheck_expr(else_blk, expected) + }; + let mut types = HashSet::new(); + self.add_resolved_types(&mut types, &then_type_id); + self.add_resolved_types(&mut types, &else_type_id); + if types.len() > 1 { + self.oneof_types.push(types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } else { + *types.iter().next().expect("Can't be empty") + } } else { - self.set_node_type_id(node_id, *types.iter().next().expect("Can't be empty")); + // If there's no else block, the if expression is a statement + NONE_TYPE } } - AstNode::Def { - name, - type_params, - params, - in_out_types, - block, - } => self.typecheck_def(name, params, in_out_types, block, node_id), - AstNode::Alias { new_name, old_name } => { - self.typecheck_alias(new_name, old_name, node_id) - } AstNode::Call { ref parts } => self.typecheck_call(parts, node_id), - AstNode::For { - variable, - range, - block, - } => { - // We don't need to typecheck variable after this - self.typecheck_node(range); - - let var_id = self - .compiler - .var_resolution - .get(&variable) - .expect("missing resolved variable"); - if let Type::List(type_id) = self.type_of(range) { - self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable, type_id); - } else { - self.variable_types[var_id.0] = ANY_TYPE; - self.set_node_type_id(variable, ERROR_TYPE); - self.error("For loop range is not a list", range); - } - - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); - } - - if self.type_id_of(node_id) != ERROR_TYPE { - self.set_node_type_id(node_id, NONE_TYPE); - } - } - AstNode::While { condition, block } => { - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); - } - - self.typecheck_node(condition); - - // the condition should always evaluate to a boolean - if self.type_of(condition) != Type::Bool { - self.error("The condition for while loop is not a boolean", condition); - self.set_node_type_id(node_id, ERROR_TYPE); - } else { - self.set_node_type_id(node_id, self.type_id_of(block)); - } - } AstNode::Match { ref target, ref match_arms, } => { // Check all the output types of match - let output_types = self.typecheck_match(target, match_arms); + let output_types = self.typecheck_match(target, match_arms, expected); match output_types.len().cmp(&1) { Ordering::Greater => { self.oneof_types.push(output_types); - self.set_node_type( - node_id, - Type::OneOf(OneOfId(self.oneof_types.len() - 1)), - ); - } - Ordering::Equal => { - self.set_node_type_id( - node_id, - *output_types - .iter() - .next() - .expect("Will contain one element"), - ); - } - Ordering::Less => { - self.set_node_type_id(node_id, NOTHING_TYPE); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) } + Ordering::Equal => *output_types + .iter() + .next() + .expect("Will contain one element"), + Ordering::Less => NOTHING_TYPE, } } - _ => self.error( + _ => { + self.error( + format!( + "Expected an expression to typecheck, got '{:?}'", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ); + ERROR_TYPE + } + }; + self.set_node_type_id(node_id, ty_id); + + let got = self.types[ty_id.0]; + let exp = self.types[expected.0]; + if !self.is_subtype(got, exp) { + self.error( format!( - "unsupported ast node '{:?}' in typechecker", - self.compiler.ast_nodes[node_id.0] + "Expected {}, got {}", + self.type_to_string(expected), + self.type_to_string(ty_id) ), node_id, - ), + ); } + + ty_id + } + + fn is_expr(&mut self, node_id: NodeId) -> bool { + matches!( + self.compiler.ast_nodes[node_id.0], + AstNode::Null + | AstNode::Int + | AstNode::Float + | AstNode::True + | AstNode::False + | AstNode::String + | AstNode::Variable + | AstNode::List(_) + | AstNode::Record { .. } + | AstNode::Table { .. } + | AstNode::Closure { .. } + | AstNode::BinaryOp { .. } + | AstNode::If { .. } + | AstNode::Call { .. } + | AstNode::Match { .. } + ) } fn typecheck_match( &mut self, target: &NodeId, match_arms: &Vec<(NodeId, NodeId)>, + expected: TypeId, ) -> HashSet { - self.typecheck_node(*target); + self.typecheck_expr(*target, TOP_TYPE); let mut output_types = HashSet::new(); // typecheck each node let target_id = self.type_id_of(*target); for (match_node, result_node) in match_arms { self.typecheck_node(*match_node); - self.typecheck_node(*result_node); + self.typecheck_expr(*result_node, expected); let match_id = self.type_id_of(*match_node); match (self.type_of(*target), self.type_of(*match_node)) { @@ -554,9 +586,9 @@ impl<'a> Typechecker<'a> { output_types } - fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId, node_id: NodeId) { - self.typecheck_node(lhs); - self.typecheck_node(rhs); + fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { + self.typecheck_expr(lhs, TOP_TYPE); + self.typecheck_expr(rhs, TOP_TYPE); self.set_node_type_id(op, FORBIDDEN_TYPE); let lhs_type = self.type_of(lhs); @@ -676,15 +708,16 @@ impl<'a> Typechecker<'a> { }; if let Some(ty) = out_type { - self.set_node_type(node_id, ty); + self.push_type(ty) } else { - self.set_node_type_id(node_id, ERROR_TYPE); + ERROR_TYPE } } fn typecheck_def( &mut self, name: NodeId, + type_params: Option, params: NodeId, in_out_types: Option, block: NodeId, @@ -718,8 +751,8 @@ impl<'a> Typechecker<'a> { panic!("internal error: type is not a type"); }; InOutType { - in_type: self.typecheck_type(in_name, in_args, in_optional), - out_type: self.typecheck_type(out_name, out_args, out_optional), + in_type: self.typecheck_type_ref(in_name, in_args, in_optional), + out_type: self.typecheck_type_ref(out_name, out_args, out_optional), } }) .collect::>() @@ -769,24 +802,69 @@ impl<'a> Typechecker<'a> { ); } - fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) { - let num_name_parts = if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { + fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) -> TypeId { + if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { // TODO: The type should be `oneof` self.set_node_type_id(node_id, ANY_TYPE); - self.compiler.decls[decl_id.0].name().split(' ').count() + let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); + let decl_node_id = self.compiler.decl_nodes[decl_id.0]; + let AstNode::Def { params, .. } = self.compiler.get_node(decl_node_id) else { + panic!("Internal error: Expected def") + }; + let AstNode::Params(params) = self.compiler.get_node(*params) else { + panic!("Internal error: Expected def") + }; + let num_args = parts.len() - num_name_parts; + if params.len() != num_args { + self.error( + format!("Expected {} argument(s), got {}", params.len(), num_args), + node_id, + ); + } + for (param, arg) in params.iter().zip(&parts[num_name_parts..]) { + let expected = self.type_id_of(*param); + if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { + self.set_node_type_id(*arg, STRING_TYPE); + if !self.is_subtype(Type::String, self.types[expected.0]) { + self.error( + format!("Expected {}, got string", self.type_to_string(expected)), + *arg, + ); + } + } else { + self.typecheck_expr(*arg, expected); + } + } + if num_args > params.len() { + // Typecheck extra arguments too + for arg in &parts[num_name_parts + params.len()..] { + if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { + self.set_node_type_id(*arg, STRING_TYPE); + } else { + self.typecheck_expr(*arg, TOP_TYPE); + } + } + } + + // TODO base this on pipeline input type + self.create_oneof( + self.decl_types[decl_id.0] + .iter() + .map(|io| io.out_type) + .collect(), + ) } else { // external call - self.node_types[node_id.0] = BYTE_STREAM_TYPE; - 1 - }; - - for part in &parts[num_name_parts..] { - if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { - self.set_node_type_id(*part, STRING_TYPE); - } else { - self.typecheck_node(*part); + for part in &parts[1..] { + if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { + self.set_node_type_id(*part, STRING_TYPE); + } else { + self.typecheck_expr(*part, TOP_TYPE); + } } + + BYTE_STREAM_TYPE } } @@ -797,15 +875,12 @@ impl<'a> Typechecker<'a> { initializer: NodeId, node_id: NodeId, ) { - self.typecheck_node(initializer); - - if let Some(ty) = ty { - self.typecheck_node(ty); - - if !self.is_type_compatible(self.type_of(ty), self.type_of(initializer)) { - self.error("initializer does not match declared type", initializer) - } - } + let type_id = if let Some(ty) = ty { + let ty_id = self.typecheck_type(ty); + self.typecheck_expr(initializer, ty_id) + } else { + self.typecheck_expr(initializer, TOP_TYPE) + }; let var_id = self .compiler @@ -813,18 +888,63 @@ impl<'a> Typechecker<'a> { .get(&variable_name) .expect("missing declared variable"); - let type_id = if let Some(ty) = ty { - self.type_id_of(ty) - } else { - self.type_id_of(initializer) - }; - self.variable_types[var_id.0] = type_id; self.set_node_type_id(variable_name, type_id); self.set_node_type_id(node_id, NONE_TYPE); } - fn typecheck_type( + fn typecheck_type(&mut self, node_id: NodeId) -> TypeId { + let ty_id = match self.compiler.ast_nodes[node_id.0] { + AstNode::Type { + name, + args, + optional, + } => self.typecheck_type_ref(name, args, optional), + AstNode::RecordType { + fields, + optional: _, // TODO handle optional record types + } => { + let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { + panic!("internal error: record fields aren't Params"); + }; + let mut fields = field_nodes + .iter() + .map(|field| { + let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { + panic!("internal error: record field isn't Param"); + }; + let ty_id = match ty { + Some(ty) => { + self.typecheck_type(*ty); + self.type_id_of(*ty) + } + None => ANY_TYPE, + }; + (*name, ty_id) + }) + .collect::>(); + // Store fields sorted by name + fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) + } + _ => { + self.error( + format!( + "unsupported/unexpected ast node '{:?}' in typechecker", + self.compiler.ast_nodes[node_id.0] + ), + node_id, + ); + ERROR_TYPE + } + }; + self.set_node_type_id(node_id, ty_id); + ty_id + } + + fn typecheck_type_ref( &mut self, name_id: NodeId, args_id: Option, @@ -900,7 +1020,10 @@ impl<'a> Typechecker<'a> { match ty { Type::Unknown => UNKNOWN_TYPE, Type::Forbidden => FORBIDDEN_TYPE, + Type::Error => ERROR_TYPE, Type::None => NONE_TYPE, + Type::Top => TOP_TYPE, + Type::Bottom => BOTTOM_TYPE, Type::Any => ANY_TYPE, Type::Number => NUMBER_TYPE, Type::Nothing => NOTHING_TYPE, @@ -926,6 +1049,14 @@ impl<'a> Typechecker<'a> { self.node_types[node_id.0] = type_id; } + fn new_typevar(&mut self, lower_bound: TypeId, upper_bound: TypeId) -> TypeVarId { + self.type_vars.push(TypeVar { + lower_bound, + upper_bound, + }); + TypeVarId(self.type_vars.len() - 1) + } + /// Finds a "supertype" of two types (e.g., number for float and int) fn least_common_type(&mut self, lhs: Type, rhs: Type) -> Type { match (lhs, rhs) { @@ -980,6 +1111,22 @@ impl<'a> Typechecker<'a> { } } + /// Check if `sub` is a subtype of `supe` + fn is_subtype(&self, sub: Type, supe: Type) -> bool { + match (sub, supe) { + (_, Type::Top | Type::Any | Type::Unknown) => true, + (Type::Any | Type::Unknown, _) => true, + (Type::Int | Type::Float | Type::Number, Type::Number) => true, + (Type::List(inner_sub), Type::List(inner_supe)) => { + self.is_subtype(self.types[inner_sub.0], self.types[inner_supe.0]) + } + (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] + .iter() + .any(|ty| self.is_subtype(sub, self.types[ty.0])), + (_, _) => sub == supe, + } + } + fn type_to_string(&self, type_id: TypeId) -> String { let ty = &self.types[type_id.0]; @@ -987,6 +1134,9 @@ impl<'a> Typechecker<'a> { Type::Unknown => "unknown".to_string(), Type::Forbidden => "forbidden".to_string(), Type::None => "()".to_string(), + Type::Error => "error".to_string(), + Type::Top => "top".to_string(), + Type::Bottom => "bottom".to_string(), Type::Any => "any".to_string(), Type::Number => "number".to_string(), Type::Nothing => "nothing".to_string(), @@ -1035,7 +1185,12 @@ impl<'a> Typechecker<'a> { fmt.push('>'); fmt } - Type::Error => "error".to_string(), + Type::Ref(type_decl_id) => match self.compiler.type_decls[type_decl_id.0] { + TypeDecl::Param(name_node) => { + String::from_utf8_lossy(self.compiler.get_span_contents(name_node)).to_string() + } + }, + Type::Var(type_var_id) => format!("'{}", type_var_id.0), } } @@ -1110,6 +1265,63 @@ impl<'a> Typechecker<'a> { types.insert(*ty); } } + + /// Use this to create any union types, to ensure that the created union type + /// is as simple as possible + fn create_oneof(&mut self, types: HashSet) -> TypeId { + if types.is_empty() { + // TODO return bottom type instead? + return ANY_TYPE; + } + + let mut res = HashSet::::new(); + + let mut flattened = HashSet::new(); + for ty_id in types { + match self.types[ty_id.0] { + Type::OneOf(oneof_id) => { + flattened.extend(&self.oneof_types[oneof_id.0]); + } + _ => { + flattened.insert(ty_id); + } + } + } + + for ty_id in flattened { + if res.contains(&ty_id) { + continue; + } + + let ty = self.types[ty_id.0]; + let mut add = true; + let mut remove = HashSet::new(); + for other_id in res.iter() { + let other = self.types[other_id.0]; + if self.is_subtype(ty, other) { + add = false; + break; + } + if self.is_subtype(other, ty) { + remove.insert(*other_id); + } + } + + if add { + res.insert(ty_id); + for other in remove { + res.remove(&other); + } + } + } + + if res.len() == 1 { + *res.iter().next().unwrap() + } else { + self.oneof_types.push(res); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } + } } /// Check whether two types can perform common numeric operations From adcb7b0ab01f199b91aaedf3968169d3d331f955 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Wed, 30 Apr 2025 02:38:52 -0400 Subject: [PATCH 04/16] Type inference with generics --- ...er__test__node_output@def_generics.nu.snap | 61 ++++- src/typechecker.rs | 211 ++++++++++++++++-- tests/def_generics.nu | 5 + 3 files changed, 244 insertions(+), 33 deletions(-) diff --git a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap index ea6e0a4..7b6aa1f 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap @@ -30,9 +30,28 @@ input_file: tests/def_generics.nu 23: List([NodeId(22)]) (59 to 62) 24: Block(BlockId(0)) (39 to 65) 25: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } (0 to 65) -26: Block(BlockId(1)) (0 to 66) +26: Variable (71 to 73) "l1" +27: Name (76 to 77) "f" +28: Int (78 to 79) "1" +29: Call { parts: [NodeId(27), NodeId(28)] } (78 to 79) +30: Let { variable_name: NodeId(26), ty: None, initializer: NodeId(29), is_mutable: false } (67 to 79) +31: Variable (85 to 87) "l2" +32: Name (90 to 91) "f" +33: Int (92 to 93) "2" +34: Call { parts: [NodeId(32), NodeId(33)] } (92 to 93) +35: Let { variable_name: NodeId(31), ty: None, initializer: NodeId(34), is_mutable: false } (81 to 93) +36: Variable (98 to 100) "l3" +37: Name (102 to 106) "list" +38: Name (107 to 113) "number" +39: Type { name: NodeId(38), args: None, optional: false } (107 to 113) +40: TypeArgs([NodeId(39)]) (106 to 114) +41: Type { name: NodeId(37), args: Some(NodeId(40)), optional: false } (102 to 106) +42: Variable (117 to 120) "$l2" +43: Let { variable_name: NodeId(36), ty: Some(NodeId(41)), initializer: NodeId(42), is_mutable: false } (94 to 120) +44: Block(BlockId(1)) (0 to 121) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(26) +0: Frame Scope, node_id: NodeId(44) + variables: [ l1: NodeId(26), l2: NodeId(31), l3: NodeId(36) ] decls: [ f: NodeId(0) ] 1: Frame Scope, node_id: NodeId(24) variables: [ x: NodeId(3), z: NodeId(17) ] @@ -43,28 +62,46 @@ input_file: tests/def_generics.nu 2: unknown 3: unknown 4: unknown -5: unknown -6: unknown +5: T +6: T 7: forbidden 8: unknown 9: unknown 10: unknown 11: unknown -12: unknown +12: T 13: forbidden 14: unknown 15: unknown 16: unknown -17: unknown +17: T 18: unknown -19: unknown -20: unknown +19: T +20: T 21: () -22: unknown -23: list -24: list +22: T +23: list +24: list 25: () -26: () +26: list +27: unknown +28: int +29: list +30: () +31: list +32: unknown +33: int +34: list +35: () +36: list +37: unknown +38: unknown +39: number +40: forbidden +41: list +42: list +43: () +44: () ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/typechecker.rs b/src/typechecker.rs index 6d5235b..bfb2a0a 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -3,7 +3,7 @@ use crate::errors::{Severity, SourceError}; use crate::parser::{AstNode, NodeId}; use crate::resolver::{TypeDecl, TypeDeclId}; use std::cmp::Ordering; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeId(pub usize); @@ -198,7 +198,15 @@ impl<'a> Typechecker<'a> { if !self.compiler.ast_nodes.is_empty() { let last = self.compiler.ast_nodes.len() - 1; let last_node_id = NodeId(last); - self.typecheck_node(last_node_id) + self.typecheck_node(last_node_id); + + for i in 0..self.types.len() { + if let Type::Var(var_id) = self.types[i] { + let ub = self.type_vars[var_id.0].upper_bound; + let ub = self.types[ub.0]; + self.types[i] = ub; + } + } } } @@ -294,11 +302,11 @@ impl<'a> Typechecker<'a> { } => self.typecheck_let(variable_name, ty, initializer, node_id), AstNode::Def { name, - type_params, params, in_out_types, block, - } => self.typecheck_def(name, type_params, params, in_out_types, block, node_id), + .. + } => self.typecheck_def(name, params, in_out_types, block, node_id), AstNode::Alias { new_name, old_name } => { self.typecheck_alias(new_name, old_name, node_id) } @@ -491,9 +499,7 @@ impl<'a> Typechecker<'a> { }; self.set_node_type_id(node_id, ty_id); - let got = self.types[ty_id.0]; - let exp = self.types[expected.0]; - if !self.is_subtype(got, exp) { + if !self.constrain_subtype(ty_id, expected) { self.error( format!( "Expected {}, got {}", @@ -717,7 +723,6 @@ impl<'a> Typechecker<'a> { fn typecheck_def( &mut self, name: NodeId, - type_params: Option, params: NodeId, in_out_types: Option, block: NodeId, @@ -804,17 +809,35 @@ impl<'a> Typechecker<'a> { fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) -> TypeId { if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { - // TODO: The type should be `oneof` - self.set_node_type_id(node_id, ANY_TYPE); - let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); let decl_node_id = self.compiler.decl_nodes[decl_id.0]; - let AstNode::Def { params, .. } = self.compiler.get_node(decl_node_id) else { + let AstNode::Def { + type_params, + params, + .. + } = self.compiler.get_node(decl_node_id) + else { panic!("Internal error: Expected def") }; let AstNode::Params(params) = self.compiler.get_node(*params) else { - panic!("Internal error: Expected def") + panic!("Internal error: Expected params") }; + + let type_substs = if let Some(type_params) = type_params { + let AstNode::Params(type_params) = self.compiler.get_node(*type_params) else { + panic!("Internal error: expected type params"); + }; + let mut type_substs = HashMap::new(); + for type_param in type_params.iter() { + let type_decl_id = self.compiler.type_resolution[type_param]; + let var = self.new_typevar(BOTTOM_TYPE, TOP_TYPE); + type_substs.insert(type_decl_id, var); + } + type_substs + } else { + HashMap::new() + }; + let num_args = parts.len() - num_name_parts; if params.len() != num_args { self.error( @@ -824,9 +847,10 @@ impl<'a> Typechecker<'a> { } for (param, arg) in params.iter().zip(&parts[num_name_parts..]) { let expected = self.type_id_of(*param); + let expected = self.subst(expected, &type_substs); if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { self.set_node_type_id(*arg, STRING_TYPE); - if !self.is_subtype(Type::String, self.types[expected.0]) { + if !self.constrain_subtype(STRING_TYPE, expected) { self.error( format!("Expected {}, got string", self.type_to_string(expected)), *arg, @@ -848,12 +872,12 @@ impl<'a> Typechecker<'a> { } // TODO base this on pipeline input type - self.create_oneof( - self.decl_types[decl_id.0] - .iter() - .map(|io| io.out_type) - .collect(), - ) + let out_types = self.decl_types[decl_id.0] + .clone() + .iter() + .map(|io| self.subst(io.out_type, &type_substs)) + .collect(); + self.create_oneof(out_types) } else { // external call for part in &parts[1..] { @@ -1008,7 +1032,11 @@ impl<'a> Typechecker<'a> { // if bytes.contains(&b'@') { // // type with completion // } else { - UNKNOWN_TYPE + if let Some(type_decl) = self.compiler.type_resolution.get(&name_id) { + self.push_type(Type::Ref(*type_decl)) + } else { + UNKNOWN_TYPE + } // } } } @@ -1040,6 +1068,56 @@ impl<'a> Typechecker<'a> { } } + /// Replace type parameters (universal type variables) with existential type variables that can be solved + fn subst(&mut self, ty_id: TypeId, substs: &HashMap) -> TypeId { + if substs.is_empty() { + return ty_id; + } + match self.types[ty_id.0] { + Type::Unknown + | Type::Forbidden + | Type::Error + | Type::None + | Type::Top + | Type::Bottom + | Type::Any + | Type::Number + | Type::Nothing + | Type::Int + | Type::Float + | Type::Bool + | Type::String + | Type::Binary + | Type::Var(_) => ty_id, + Type::Closure => todo!(), + Type::List(elem_ty) => { + let new_elem = self.subst(elem_ty, substs); + if elem_ty == new_elem { + ty_id + } else { + self.push_type(Type::List(new_elem)) + } + } + Type::Stream(elem_ty) => { + let new_elem = self.subst(elem_ty, substs); + if elem_ty == new_elem { + ty_id + } else { + self.push_type(Type::Stream(new_elem)) + } + } + Type::Record(record_type_id) => todo!(), + Type::OneOf(one_of_id) => todo!(), + Type::Ref(type_decl_id) => { + if let Some(var) = substs.get(&type_decl_id) { + self.push_type(Type::Var(*var)) + } else { + ty_id + } + } + } + } + fn set_node_type(&mut self, node_id: NodeId, ty: Type) { let type_id = self.push_type(ty); self.node_types[node_id.0] = type_id; @@ -1111,6 +1189,75 @@ impl<'a> Typechecker<'a> { } } + /// Check if `sub` is a subtype of `supe` + /// + /// Returns `false` if there is a type mismatch, `true` otherwise + /// todo return a Result with a message about constraints not being solvable or something + fn constrain_subtype(&mut self, sub_id: TypeId, supe_id: TypeId) -> bool { + if sub_id == supe_id { + return true; + } + println!( + "Constraining {} to be a subtype of {}", + self.type_to_string(sub_id), + self.type_to_string(supe_id) + ); + match (self.types[sub_id.0], self.types[supe_id.0]) { + (_, Type::Top | Type::Any | Type::Unknown) => true, + (Type::Bottom | Type::Any | Type::Unknown, _) => true, + (Type::Int | Type::Float | Type::Number, Type::Number) => true, + (Type::List(inner_sub), Type::List(inner_supe)) => { + self.constrain_subtype(inner_sub, inner_supe) + } + (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] + .clone() + .iter() + .all(|ty| self.constrain_subtype(sub_id, *ty)), + (Type::Var(var_id), _) => { + let lb = self.type_vars[var_id.0].lower_bound; + let ub = self.type_vars[var_id.0].upper_bound; + let new_ub = self.create_intersection(ub, supe_id); + println!( + " New upper bound: {} (lb: {})", + self.type_to_string(new_ub), + self.type_to_string(lb) + ); + + // todo prevent infinite recursion here + if self.constrain_subtype(lb, new_ub) { + let var = self + .type_vars + .get_mut(var_id.0) + .expect("type variable must exist"); + var.upper_bound = new_ub; + true + } else { + println!(" New lower bound isn't a subtype of upper bound!"); + false + } + } + (_, Type::Var(var_id)) => { + let lb = self.type_vars[var_id.0].lower_bound; + let ub = self.type_vars[var_id.0].upper_bound; + let new_lb = self.create_intersection(lb, sub_id); + + // todo prevent infinite recursion here + if self.constrain_subtype(new_lb, ub) { + let var = self + .type_vars + .get_mut(var_id.0) + .expect("type variable must exist"); + var.lower_bound = new_lb; + true + } else { + false + } + } + (sub, supe) if sub == supe => true, + _ => false, + } + } + /// Check if `sub` is a subtype of `supe` fn is_subtype(&self, sub: Type, supe: Type) -> bool { match (sub, supe) { @@ -1322,6 +1469,28 @@ impl<'a> Typechecker<'a> { self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) } } + + fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId { + if lhs_id == rhs_id { + return lhs_id; + } + match (self.types[lhs_id.0], self.types[rhs_id.0]) { + (Type::Any | Type::Unknown, _) => lhs_id, + (_, Type::Any | Type::Unknown) => rhs_id, + (Type::Top, _) => rhs_id, + (_, Type::Top) => lhs_id, + (Type::Bottom, _) | (_, Type::Bottom) => BOTTOM_TYPE, + (Type::Number, Type::Int) | (Type::Int, Type::Number) => INT_TYPE, + (Type::Number, Type::Float) | (Type::Float, Type::Number) => FLOAT_TYPE, + (Type::List(lhs_inner), Type::List(rhs_inner)) => { + let new_inner = self.create_intersection(lhs_inner, rhs_inner); + self.push_type(Type::List(new_inner)) + } + (Type::Var(_) | Type::Ref(_), _) | (_, Type::Var(_) | Type::Ref(_)) => todo!(), + (lhs, rhs) if lhs == rhs => lhs_id, + _ => BOTTOM_TYPE, + } + } } /// Check whether two types can perform common numeric operations diff --git a/tests/def_generics.nu b/tests/def_generics.nu index 1e9b850..0784ce8 100644 --- a/tests/def_generics.nu +++ b/tests/def_generics.nu @@ -2,3 +2,8 @@ def f [ x: T ] : nothing -> list { let z: T = $x [$z] } + +let l1 = f 1 + +let l2 = f 2 +let l3: list = $l2 From 4b4acf5e6dc983d65089f60fbbcbb017ad851cb6 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 6 May 2025 06:11:51 -0400 Subject: [PATCH 05/16] Prevent cycles in constraints --- ...ser__test__node_output@def_complex.nu.snap | 206 ++++++++++++++++++ ...er__test__node_output@def_generics.nu.snap | 4 +- ..._test__node_output@def_return_type.nu.snap | 12 +- src/typechecker.rs | 195 ++++++++++++----- tests/def_complex.nu | 7 + 5 files changed, 365 insertions(+), 59 deletions(-) create mode 100644 src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap create mode 100644 tests/def_complex.nu diff --git a/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap new file mode 100644 index 0000000..103727c --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap @@ -0,0 +1,206 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/def_complex.nu +--- +==== COMPILER ==== +0: Name (4 to 5) "f" +1: Name (6 to 7) "A" +2: Name (9 to 10) "B" +3: Params([NodeId(1), NodeId(2)]) (5 to 11) +4: Name (14 to 15) "x" +5: Name (17 to 23) "record" +6: Name (24 to 25) "a" +7: Name (27 to 28) "A" +8: Type { name: NodeId(7), args: None, optional: false } (27 to 28) +9: Param { name: NodeId(6), ty: Some(NodeId(8)) } (24 to 28) +10: Name (30 to 31) "b" +11: Name (33 to 34) "B" +12: Type { name: NodeId(11), args: None, optional: false } (33 to 34) +13: Param { name: NodeId(10), ty: Some(NodeId(12)) } (30 to 34) +14: Params([NodeId(9), NodeId(13)]) (23 to 35) +15: RecordType { fields: NodeId(14), optional: false } (17 to 35) +16: Param { name: NodeId(4), ty: Some(NodeId(15)) } (14 to 35) +17: Name (37 to 38) "y" +18: Name (40 to 46) "record" +19: Name (47 to 48) "a" +20: Name (50 to 51) "A" +21: Type { name: NodeId(20), args: None, optional: false } (50 to 51) +22: Param { name: NodeId(19), ty: Some(NodeId(21)) } (47 to 51) +23: Name (53 to 54) "b" +24: Name (56 to 57) "B" +25: Type { name: NodeId(24), args: None, optional: false } (56 to 57) +26: Param { name: NodeId(23), ty: Some(NodeId(25)) } (53 to 57) +27: Params([NodeId(22), NodeId(26)]) (46 to 58) +28: RecordType { fields: NodeId(27), optional: false } (40 to 59) +29: Param { name: NodeId(17), ty: Some(NodeId(28)) } (37 to 59) +30: Params([NodeId(16), NodeId(29)]) (12 to 60) +31: Name (63 to 70) "nothing" +32: Type { name: NodeId(31), args: None, optional: false } (63 to 70) +33: Name (74 to 80) "record" +34: Name (81 to 82) "a" +35: Name (84 to 85) "A" +36: Type { name: NodeId(35), args: None, optional: false } (84 to 85) +37: Param { name: NodeId(34), ty: Some(NodeId(36)) } (81 to 85) +38: Name (87 to 88) "b" +39: Name (90 to 91) "B" +40: Type { name: NodeId(39), args: None, optional: false } (90 to 91) +41: Param { name: NodeId(38), ty: Some(NodeId(40)) } (87 to 91) +42: Params([NodeId(37), NodeId(41)]) (80 to 92) +43: RecordType { fields: NodeId(42), optional: false } (74 to 93) +44: InOutType(NodeId(32), NodeId(43)) (63 to 93) +45: InOutTypes([NodeId(44)]) (63 to 93) +46: Variable (97 to 99) "$x" +47: Block(BlockId(0)) (93 to 101) +48: Def { name: NodeId(0), type_params: Some(NodeId(3)), params: NodeId(30), in_out_types: Some(NodeId(45)), block: NodeId(47) } (0 to 101) +49: Name (106 to 116) "mysterious" +50: Name (117 to 118) "T" +51: Params([NodeId(50)]) (116 to 119) +52: Name (122 to 123) "x" +53: Name (125 to 128) "int" +54: Type { name: NodeId(53), args: None, optional: false } (125 to 128) +55: Param { name: NodeId(52), ty: Some(NodeId(54)) } (122 to 128) +56: Params([NodeId(55)]) (120 to 130) +57: Name (133 to 140) "nothing" +58: Type { name: NodeId(57), args: None, optional: false } (133 to 140) +59: Name (144 to 145) "T" +60: Type { name: NodeId(59), args: None, optional: false } (144 to 145) +61: InOutType(NodeId(58), NodeId(60)) (133 to 146) +62: InOutTypes([NodeId(61)]) (133 to 146) +63: Block(BlockId(1)) (146 to 148) +64: Def { name: NodeId(49), type_params: Some(NodeId(51)), params: NodeId(56), in_out_types: Some(NodeId(62)), block: NodeId(63) } (102 to 148) +65: Variable (154 to 155) "m" +66: Name (158 to 168) "mysterious" +67: Int (169 to 170) "0" +68: Call { parts: [NodeId(66), NodeId(67)] } (169 to 170) +69: Let { variable_name: NodeId(65), ty: None, initializer: NodeId(68), is_mutable: false } (150 to 170) +70: Variable (175 to 176) "a" +71: Name (178 to 184) "record" +72: Name (185 to 186) "a" +73: Name (188 to 194) "number" +74: Type { name: NodeId(73), args: None, optional: false } (188 to 194) +75: Param { name: NodeId(72), ty: Some(NodeId(74)) } (185 to 194) +76: Params([NodeId(75)]) (184 to 195) +77: RecordType { fields: NodeId(76), optional: false } (178 to 196) +78: Name (198 to 199) "f" +79: String (202 to 203) "a" +80: Int (205 to 208) "123" +81: String (210 to 211) "b" +82: Variable (213 to 215) "$m" +83: Record { pairs: [(NodeId(79), NodeId(80)), (NodeId(81), NodeId(82))] } (200 to 218) +84: String (220 to 221) "a" +85: Float (223 to 227) "12.3" +86: String (229 to 230) "b" +87: String (232 to 237) ""foo"" +88: Record { pairs: [(NodeId(84), NodeId(85)), (NodeId(86), NodeId(87))] } (218 to 239) +89: Call { parts: [NodeId(78), NodeId(83), NodeId(88)] } (200 to 239) +90: Let { variable_name: NodeId(70), ty: Some(NodeId(77)), initializer: NodeId(89), is_mutable: false } (171 to 239) +91: Block(BlockId(2)) (0 to 240) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(91) + variables: [ a: NodeId(70), m: NodeId(65) ] + decls: [ f: NodeId(0), mysterious: NodeId(49) ] +1: Frame Scope, node_id: NodeId(47) + variables: [ x: NodeId(4), y: NodeId(17) ] + type decls: [ A: NodeId(1), B: NodeId(2) ] +2: Frame Scope, node_id: NodeId(63) + variables: [ x: NodeId(52) ] + type decls: [ T: NodeId(50) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: unknown +6: unknown +7: unknown +8: A +9: unknown +10: unknown +11: unknown +12: B +13: unknown +14: unknown +15: record +16: record +17: unknown +18: unknown +19: unknown +20: unknown +21: A +22: unknown +23: unknown +24: unknown +25: B +26: unknown +27: unknown +28: record +29: record +30: forbidden +31: unknown +32: nothing +33: unknown +34: unknown +35: unknown +36: A +37: unknown +38: unknown +39: unknown +40: B +41: unknown +42: unknown +43: record +44: unknown +45: unknown +46: record +47: record +48: () +49: unknown +50: unknown +51: unknown +52: unknown +53: unknown +54: int +55: int +56: forbidden +57: unknown +58: nothing +59: unknown +60: T +61: unknown +62: unknown +63: () +64: () +65: bottom +66: unknown +67: int +68: bottom +69: () +70: record +71: unknown +72: unknown +73: unknown +74: number +75: unknown +76: unknown +77: record +78: unknown +79: unknown +80: int +81: unknown +82: bottom +83: record +84: unknown +85: float +86: unknown +87: string +88: record +89: record +90: () +91: () +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 48): node Def { name: NodeId(0), type_params: Some(NodeId(3)), params: NodeId(30), in_out_types: Some(NodeId(45)), block: NodeId(47) } not suported yet diff --git a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap index 7b6aa1f..baf1eb4 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap @@ -66,12 +66,12 @@ input_file: tests/def_generics.nu 6: T 7: forbidden 8: unknown -9: unknown +9: nothing 10: unknown 11: unknown 12: T 13: forbidden -14: unknown +14: list 15: unknown 16: unknown 17: T diff --git a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap index afe68ad..76247ec 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap @@ -50,12 +50,12 @@ input_file: tests/def_return_type.nu 0: unknown 1: forbidden 2: unknown -3: unknown +3: nothing 4: unknown 5: unknown 6: any 7: forbidden -8: unknown +8: list 9: unknown 10: unknown 11: list @@ -64,20 +64,20 @@ input_file: tests/def_return_type.nu 14: unknown 15: forbidden 16: unknown -17: unknown +17: string 18: unknown 19: unknown 20: string 21: forbidden -22: unknown +22: list 23: unknown 24: unknown -25: unknown +25: int 26: unknown 27: unknown 28: int 29: forbidden -30: unknown +30: list 31: unknown 32: unknown 33: list diff --git a/src/typechecker.rs b/src/typechecker.rs index bfb2a0a..0cf265c 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -201,10 +201,9 @@ impl<'a> Typechecker<'a> { self.typecheck_node(last_node_id); for i in 0..self.types.len() { - if let Type::Var(var_id) = self.types[i] { - let ub = self.type_vars[var_id.0].upper_bound; - let ub = self.types[ub.0]; - self.types[i] = ub; + let res = self.eliminate_type_vars(TypeId(i), TypeVarId(0), false); + if res.0 != i { + self.types[i] = self.types[res.0]; } } } @@ -739,25 +738,9 @@ impl<'a> Typechecker<'a> { let AstNode::InOutType(in_ty, out_ty) = self.compiler.get_node(*ty) else { panic!("internal error: return type is not a return type"); }; - let AstNode::Type { - name: in_name, - args: in_args, - optional: in_optional, - } = *self.compiler.get_node(*in_ty) - else { - panic!("internal error: type is not a type"); - }; - let AstNode::Type { - name: out_name, - args: out_args, - optional: out_optional, - } = *self.compiler.get_node(*out_ty) - else { - panic!("internal error: type is not a type"); - }; InOutType { - in_type: self.typecheck_type_ref(in_name, in_args, in_optional), - out_type: self.typecheck_type_ref(out_name, out_args, out_optional), + in_type: self.typecheck_type(*in_ty), + out_type: self.typecheck_type(*out_ty), } }) .collect::>() @@ -956,7 +939,7 @@ impl<'a> Typechecker<'a> { _ => { self.error( format!( - "unsupported/unexpected ast node '{:?}' in typechecker", + "Internal error: expected type, got '{:?}'", self.compiler.ast_nodes[node_id.0] ), node_id, @@ -1106,7 +1089,14 @@ impl<'a> Typechecker<'a> { self.push_type(Type::Stream(new_elem)) } } - Type::Record(record_type_id) => todo!(), + Type::Record(record_type_id) => { + let mut fields = self.record_types[record_type_id.0].clone(); + for (_, ty) in fields.iter_mut() { + *ty = self.subst(*ty, substs); + } + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) + } Type::OneOf(one_of_id) => todo!(), Type::Ref(type_decl_id) => { if let Some(var) = substs.get(&type_decl_id) { @@ -1118,11 +1108,6 @@ impl<'a> Typechecker<'a> { } } - fn set_node_type(&mut self, node_id: NodeId, ty: Type) { - let type_id = self.push_type(ty); - self.node_types[node_id.0] = type_id; - } - fn set_node_type_id(&mut self, node_id: NodeId, type_id: TypeId) { self.node_types[node_id.0] = type_id; } @@ -1197,11 +1182,6 @@ impl<'a> Typechecker<'a> { if sub_id == supe_id { return true; } - println!( - "Constraining {} to be a subtype of {}", - self.type_to_string(sub_id), - self.type_to_string(supe_id) - ); match (self.types[sub_id.0], self.types[supe_id.0]) { (_, Type::Top | Type::Any | Type::Unknown) => true, (Type::Bottom | Type::Any | Type::Unknown, _) => true, @@ -1209,6 +1189,37 @@ impl<'a> Typechecker<'a> { (Type::List(inner_sub), Type::List(inner_supe)) => { self.constrain_subtype(inner_sub, inner_supe) } + (Type::Record(sub_rec_id), Type::Record(supe_rec_id)) => { + let sub_fields = self.record_types[sub_rec_id.0].clone(); + let supe_fields = self.record_types[supe_rec_id.0].clone(); + + let mut i = 0; + let mut j = 0; + while i < sub_fields.len() && j < supe_fields.len() { + let (sub_name, sub_ty) = sub_fields[i]; + let (supe_name, supe_ty) = supe_fields[j]; + let sub_text = self.compiler.get_span_contents(sub_name); + let supe_text = self.compiler.get_span_contents(supe_name); + match sub_text.cmp(supe_text) { + Ordering::Less => { + i += 1; + } + Ordering::Greater => { + // The field is in the supertype but not the subtype + return false; + } + Ordering::Equal => { + if !self.constrain_subtype(sub_ty, supe_ty) { + return false; + } + i += 1; + j += 1; + } + } + } + + true + } (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] .clone() .iter() @@ -1217,11 +1228,8 @@ impl<'a> Typechecker<'a> { let lb = self.type_vars[var_id.0].lower_bound; let ub = self.type_vars[var_id.0].upper_bound; let new_ub = self.create_intersection(ub, supe_id); - println!( - " New upper bound: {} (lb: {})", - self.type_to_string(new_ub), - self.type_to_string(lb) - ); + // Prevent forward references/cycles + let new_ub = self.eliminate_type_vars(new_ub, var_id, true); // todo prevent infinite recursion here if self.constrain_subtype(lb, new_ub) { @@ -1232,7 +1240,6 @@ impl<'a> Typechecker<'a> { var.upper_bound = new_ub; true } else { - println!(" New lower bound isn't a subtype of upper bound!"); false } } @@ -1240,8 +1247,9 @@ impl<'a> Typechecker<'a> { let lb = self.type_vars[var_id.0].lower_bound; let ub = self.type_vars[var_id.0].upper_bound; let new_lb = self.create_intersection(lb, sub_id); + // Prevent forward references/cycles + let new_lb = self.eliminate_type_vars(new_lb, var_id, false); - // todo prevent infinite recursion here if self.constrain_subtype(new_lb, ub) { let var = self .type_vars @@ -1259,18 +1267,105 @@ impl<'a> Typechecker<'a> { } /// Check if `sub` is a subtype of `supe` - fn is_subtype(&self, sub: Type, supe: Type) -> bool { - match (sub, supe) { + fn is_subtype(&self, sub: TypeId, supe: TypeId) -> bool { + if sub == supe { + return true; + } + match (self.types[sub.0], self.types[supe.0]) { (_, Type::Top | Type::Any | Type::Unknown) => true, - (Type::Any | Type::Unknown, _) => true, + (Type::Bottom | Type::Any | Type::Unknown, _) => true, (Type::Int | Type::Float | Type::Number, Type::Number) => true, (Type::List(inner_sub), Type::List(inner_supe)) => { - self.is_subtype(self.types[inner_sub.0], self.types[inner_supe.0]) + self.is_subtype(inner_sub, inner_supe) + } + (Type::Var(var_id), _) => { + let var = &self.type_vars[var_id.0]; + self.is_subtype(var.upper_bound, supe) + } + (_, Type::Var(var_id)) => { + let var = &self.type_vars[var_id.0]; + self.is_subtype(sub, var.lower_bound) } (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] .iter() - .any(|ty| self.is_subtype(sub, self.types[ty.0])), - (_, _) => sub == supe, + .any(|ty| self.is_subtype(sub, *ty)), + (sub, supe) => sub == supe, + } + } + + /// Eliminate all type variables that are greater than or equal to `max_var` + /// * `use_lower`: If true, replace type variables with their lower bound. + /// Otherwise, replace with their upper bound + fn eliminate_type_vars( + &mut self, + ty_id: TypeId, + max_var: TypeVarId, + use_lower: bool, + ) -> TypeId { + match self.types[ty_id.0] { + Type::Unknown + | Type::Forbidden + | Type::Error + | Type::None + | Type::Top + | Type::Bottom + | Type::Any + | Type::Number + | Type::Nothing + | Type::Int + | Type::Float + | Type::Bool + | Type::String + | Type::Binary + | Type::Ref(_) => ty_id, + Type::Closure => ty_id, + Type::List(inner_ty) => { + let new_inner = self.eliminate_type_vars(inner_ty, max_var, use_lower); + if inner_ty == new_inner { + ty_id + } else { + self.push_type(Type::List(new_inner)) + } + } + Type::Stream(inner_ty) => { + let new_inner = self.eliminate_type_vars(inner_ty, max_var, use_lower); + if inner_ty == new_inner { + ty_id + } else { + self.push_type(Type::Stream(new_inner)) + } + } + Type::Record(record_type_id) => { + let mut changed = false; + let mut fields = self.record_types[record_type_id.0].clone(); + for (_, ty) in fields.iter_mut() { + let res = self.eliminate_type_vars(*ty, max_var, use_lower); + if res != *ty { + *ty = res; + changed = true; + } + } + if changed { + self.record_types.push(fields); + self.push_type(Type::Record(RecordTypeId(self.record_types.len()))) + } else { + ty_id + } + } + Type::OneOf(one_of_id) => todo!(), + Type::Var(var_id) => { + if var_id.0 < max_var.0 { + ty_id + } else { + let var = &self.type_vars[var_id.0]; + let repl = if use_lower { + var.lower_bound + } else { + var.upper_bound + }; + self.eliminate_type_vars(repl, max_var, use_lower) + } + } } } @@ -1440,16 +1535,14 @@ impl<'a> Typechecker<'a> { continue; } - let ty = self.types[ty_id.0]; let mut add = true; let mut remove = HashSet::new(); for other_id in res.iter() { - let other = self.types[other_id.0]; - if self.is_subtype(ty, other) { + if self.is_subtype(ty_id, *other_id) { add = false; break; } - if self.is_subtype(other, ty) { + if self.is_subtype(*other_id, ty_id) { remove.insert(*other_id); } } diff --git a/tests/def_complex.nu b/tests/def_complex.nu new file mode 100644 index 0000000..555f5b9 --- /dev/null +++ b/tests/def_complex.nu @@ -0,0 +1,7 @@ +def f [ x: record, y: record ] : nothing -> record { + $x +} +def mysterious [ x: int ] : nothing -> T {} + +let m = mysterious 0 +let a: record = f { a: 123, b: $m } { a: 12.3, b: "foo" } From cb3fc781587598fdad4613fb4582a3df70671791 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 6 May 2025 18:59:18 -0400 Subject: [PATCH 06/16] Improve create_oneof; rewrite bin ops --- ...test__node_output@binary_ops_exact.nu.snap | 91 ++-- ...t__node_output@binary_ops_mismatch.nu.snap | 22 +- ...t__node_output@binary_ops_subtypes.nu.snap | 208 +++++---- src/typechecker.rs | 441 +++++++++--------- tests/binary_ops_exact.nu | 2 +- tests/binary_ops_subtypes.nu | 4 +- 6 files changed, 393 insertions(+), 375 deletions(-) diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap index 9233e08..da32b54 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_exact.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/binary_ops_exact.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (0 to 1) "1" @@ -12,33 +11,34 @@ snapshot_kind: text 4: True (8 to 12) 5: List([NodeId(4)]) (7 to 12) 6: Append (14 to 16) -7: False (17 to 22) -8: BinaryOp { lhs: NodeId(5), op: NodeId(6), rhs: NodeId(7) } (7 to 22) -9: Int (23 to 24) "1" -10: Plus (25 to 26) -11: Int (27 to 28) "1" -12: BinaryOp { lhs: NodeId(9), op: NodeId(10), rhs: NodeId(11) } (23 to 28) -13: Float (29 to 32) "1.0" -14: Plus (33 to 34) -15: Float (35 to 38) "1.0" -16: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(15) } (29 to 38) -17: True (39 to 43) -18: And (44 to 47) -19: False (48 to 53) -20: BinaryOp { lhs: NodeId(17), op: NodeId(18), rhs: NodeId(19) } (39 to 53) -21: String (54 to 59) ""foo"" -22: RegexMatch (60 to 62) -23: String (63 to 68) "".*o"" -24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (54 to 68) -25: Int (69 to 70) "1" -26: In (71 to 73) -27: Int (75 to 76) "1" -28: Int (78 to 79) "2" -29: List([NodeId(27), NodeId(28)]) (74 to 79) -30: BinaryOp { lhs: NodeId(25), op: NodeId(26), rhs: NodeId(29) } (69 to 79) -31: Block(BlockId(0)) (0 to 81) +7: False (18 to 23) +8: List([NodeId(7)]) (17 to 23) +9: BinaryOp { lhs: NodeId(5), op: NodeId(6), rhs: NodeId(8) } (7 to 23) +10: Int (25 to 26) "1" +11: Plus (27 to 28) +12: Int (29 to 30) "1" +13: BinaryOp { lhs: NodeId(10), op: NodeId(11), rhs: NodeId(12) } (25 to 30) +14: Float (31 to 34) "1.0" +15: Plus (35 to 36) +16: Float (37 to 40) "1.0" +17: BinaryOp { lhs: NodeId(14), op: NodeId(15), rhs: NodeId(16) } (31 to 40) +18: True (41 to 45) +19: And (46 to 49) +20: False (50 to 55) +21: BinaryOp { lhs: NodeId(18), op: NodeId(19), rhs: NodeId(20) } (41 to 55) +22: String (56 to 61) ""foo"" +23: RegexMatch (62 to 64) +24: String (65 to 70) "".*o"" +25: BinaryOp { lhs: NodeId(22), op: NodeId(23), rhs: NodeId(24) } (56 to 70) +26: Int (71 to 72) "1" +27: In (73 to 75) +28: Int (77 to 78) "1" +29: Int (80 to 81) "2" +30: List([NodeId(28), NodeId(29)]) (76 to 81) +31: BinaryOp { lhs: NodeId(26), op: NodeId(27), rhs: NodeId(30) } (71 to 81) +32: Block(BlockId(0)) (0 to 83) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(31) (empty) +0: Frame Scope, node_id: NodeId(32) (empty) ==== TYPES ==== 0: int 1: forbidden @@ -49,29 +49,30 @@ snapshot_kind: text 6: forbidden 7: bool 8: list -9: int -10: forbidden -11: int +9: list +10: int +11: forbidden 12: int -13: float -14: forbidden -15: float +13: int +14: float +15: forbidden 16: float -17: bool -18: forbidden -19: bool +17: float +18: bool +19: forbidden 20: bool -21: string -22: forbidden -23: string -24: bool -25: int -26: forbidden -27: int +21: bool +22: string +23: forbidden +24: string +25: bool +26: int +27: forbidden 28: int -29: list -30: bool +29: int +30: list 31: bool +32: bool ==== IR ==== register_count: 2 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap index fa39f5a..70bf21f 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap @@ -25,27 +25,29 @@ input_file: tests/binary_ops_mismatch.nu 0: Frame Scope, node_id: NodeId(16) (empty) ==== TYPES ==== 0: string -1: error +1: forbidden 2: float -3: error +3: string 4: string 5: error 6: float 7: error 8: bool -9: error +9: forbidden 10: string -11: error +11: bool 12: bool -13: error +13: forbidden 14: string -15: error -16: error +15: bool +16: bool ==== TYPE ERRORS ==== -Error (NodeId 1): type mismatch: unsupported addition between string and float +Error (NodeId 2): Expected string, got float +Error (NodeId 4): Expected list, got string +Error (NodeId 6): Expected list, got float Error (NodeId 5): type mismatch: unsupported append between string and float -Error (NodeId 9): type mismatch: unsupported logical operation between bool and string -Error (NodeId 13): type mismatch: unsupported string operation between bool and string +Error (NodeId 10): Expected bool, got string +Error (NodeId 12): Expected string, got bool ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap index 9c76012..39dcc85 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap @@ -2,7 +2,6 @@ source: src/test.rs expression: evaluate_example(path) input_file: tests/binary_ops_subtypes.nu -snapshot_kind: text --- ==== COMPILER ==== 0: Int (0 to 1) "1" @@ -20,63 +19,65 @@ snapshot_kind: text 12: Int (29 to 30) "1" 13: List([NodeId(12)]) (28 to 30) 14: Append (32 to 34) -15: Float (35 to 38) "1.0" -16: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(15) } (28 to 38) -17: Float (40 to 43) "1.0" -18: Int (44 to 45) "1" -19: List([NodeId(17), NodeId(18)]) (39 to 45) -20: Append (47 to 49) -21: String (50 to 53) ""a"" -22: BinaryOp { lhs: NodeId(19), op: NodeId(20), rhs: NodeId(21) } (39 to 53) -23: Int (56 to 57) "1" -24: List([NodeId(23)]) (55 to 57) -25: Int (60 to 61) "2" +15: Float (36 to 39) "1.0" +16: List([NodeId(15)]) (35 to 39) +17: BinaryOp { lhs: NodeId(13), op: NodeId(14), rhs: NodeId(16) } (28 to 39) +18: Float (42 to 45) "1.0" +19: Int (46 to 47) "1" +20: List([NodeId(18), NodeId(19)]) (41 to 47) +21: Append (49 to 51) +22: String (53 to 56) ""a"" +23: List([NodeId(22)]) (52 to 56) +24: BinaryOp { lhs: NodeId(20), op: NodeId(21), rhs: NodeId(23) } (41 to 56) +25: Int (60 to 61) "1" 26: List([NodeId(25)]) (59 to 61) -27: List([NodeId(24), NodeId(26)]) (54 to 62) -28: Append (64 to 66) -29: Int (69 to 70) "3" -30: List([NodeId(29)]) (68 to 70) -31: List([NodeId(30)]) (67 to 71) -32: BinaryOp { lhs: NodeId(27), op: NodeId(28), rhs: NodeId(31) } (54 to 71) -33: Int (75 to 76) "1" -34: List([NodeId(33)]) (74 to 76) -35: Int (79 to 80) "2" +27: Int (64 to 65) "2" +28: List([NodeId(27)]) (63 to 65) +29: List([NodeId(26), NodeId(28)]) (58 to 66) +30: Append (68 to 70) +31: Int (73 to 74) "3" +32: List([NodeId(31)]) (72 to 74) +33: List([NodeId(32)]) (71 to 75) +34: BinaryOp { lhs: NodeId(29), op: NodeId(30), rhs: NodeId(33) } (58 to 75) +35: Int (79 to 80) "1" 36: List([NodeId(35)]) (78 to 80) -37: List([NodeId(34), NodeId(36)]) (73 to 81) -38: Append (83 to 85) -39: Float (88 to 91) "3.0" -40: List([NodeId(39)]) (87 to 91) -41: List([NodeId(40)]) (86 to 92) -42: BinaryOp { lhs: NodeId(37), op: NodeId(38), rhs: NodeId(41) } (73 to 92) -43: Int (94 to 95) "1" -44: In (96 to 98) -45: Float (100 to 103) "1.0" -46: Int (105 to 106) "1" -47: List([NodeId(45), NodeId(46)]) (99 to 106) -48: BinaryOp { lhs: NodeId(43), op: NodeId(44), rhs: NodeId(47) } (94 to 106) -49: Float (108 to 111) "2.3" -50: Modulo (112 to 115) -51: Int (116 to 117) "1" -52: BinaryOp { lhs: NodeId(49), op: NodeId(50), rhs: NodeId(51) } (108 to 117) -53: String (120 to 121) "b" -54: Int (123 to 124) "2" -55: String (126 to 127) "c" -56: Int (129 to 130) "3" -57: Record { pairs: [(NodeId(53), NodeId(54)), (NodeId(55), NodeId(56))] } (119 to 131) -58: List([NodeId(57)]) (118 to 131) -59: Append (133 to 135) -60: String (138 to 139) "a" -61: Int (141 to 142) "3" -62: String (144 to 145) "b" -63: Float (147 to 150) "1.5" -64: String (152 to 153) "c" -65: String (155 to 160) ""foo"" -66: Record { pairs: [(NodeId(60), NodeId(61)), (NodeId(62), NodeId(63)), (NodeId(64), NodeId(65))] } (137 to 161) -67: List([NodeId(66)]) (136 to 161) -68: BinaryOp { lhs: NodeId(58), op: NodeId(59), rhs: NodeId(67) } (118 to 161) -69: Block(BlockId(0)) (0 to 163) +37: Int (83 to 84) "2" +38: List([NodeId(37)]) (82 to 84) +39: List([NodeId(36), NodeId(38)]) (77 to 85) +40: Append (87 to 89) +41: Float (92 to 95) "3.0" +42: List([NodeId(41)]) (91 to 95) +43: List([NodeId(42)]) (90 to 96) +44: BinaryOp { lhs: NodeId(39), op: NodeId(40), rhs: NodeId(43) } (77 to 96) +45: Int (98 to 99) "1" +46: In (100 to 102) +47: Float (104 to 107) "1.0" +48: Int (109 to 110) "1" +49: List([NodeId(47), NodeId(48)]) (103 to 110) +50: BinaryOp { lhs: NodeId(45), op: NodeId(46), rhs: NodeId(49) } (98 to 110) +51: Float (112 to 115) "2.3" +52: Modulo (116 to 119) +53: Int (120 to 121) "1" +54: BinaryOp { lhs: NodeId(51), op: NodeId(52), rhs: NodeId(53) } (112 to 121) +55: String (124 to 125) "b" +56: Int (127 to 128) "2" +57: String (130 to 131) "c" +58: Int (133 to 134) "3" +59: Record { pairs: [(NodeId(55), NodeId(56)), (NodeId(57), NodeId(58))] } (123 to 135) +60: List([NodeId(59)]) (122 to 135) +61: Append (137 to 139) +62: String (142 to 143) "a" +63: Int (145 to 146) "3" +64: String (148 to 149) "b" +65: Float (151 to 154) "1.5" +66: String (156 to 157) "c" +67: String (159 to 164) ""foo"" +68: Record { pairs: [(NodeId(62), NodeId(63)), (NodeId(64), NodeId(65)), (NodeId(66), NodeId(67))] } (141 to 165) +69: List([NodeId(68)]) (140 to 165) +70: BinaryOp { lhs: NodeId(60), op: NodeId(61), rhs: NodeId(69) } (122 to 165) +71: Block(BlockId(0)) (0 to 167) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(69) (empty) +0: Frame Scope, node_id: NodeId(71) (empty) ==== TYPES ==== 0: int 1: forbidden @@ -94,60 +95,65 @@ snapshot_kind: text 13: list 14: forbidden 15: float -16: list -17: float -18: int -19: list -20: forbidden -21: string -22: list -23: int -24: list +16: list +17: list +18: float +19: int +20: list +21: forbidden +22: string +23: list +24: list> 25: int 26: list -27: list> -28: forbidden -29: int -30: list -31: list> -32: list> -33: int -34: list +27: int +28: list +29: list> +30: forbidden +31: int +32: list +33: list> +34: list> 35: int 36: list -37: list> -38: forbidden -39: float -40: list -41: list> -42: list> -43: int -44: forbidden -45: float -46: int -47: list -48: bool -49: float -50: forbidden -51: int -52: float -53: unknown -54: int +37: int +38: list +39: list> +40: forbidden +41: float +42: list +43: list> +44: list> +45: int +46: forbidden +47: float +48: int +49: list +50: bool +51: float +52: forbidden +53: int +54: float 55: unknown 56: int -57: record -58: list> -59: forbidden -60: unknown -61: int +57: unknown +58: int +59: record +60: list> +61: forbidden 62: unknown -63: float +63: int 64: unknown -65: string -66: record -67: list> -68: list> -69: list> +65: float +66: unknown +67: string +68: record +69: list> +70: list>> +71: list>> +==== TYPE ERRORS ==== +Error (NodeId 2): Expected int, got float +Error (NodeId 6): Expected string, got float ==== IR ==== register_count: 1 file_count: 0 diff --git a/src/typechecker.rs b/src/typechecker.rs index 0cf265c..0def748 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -102,7 +102,8 @@ pub struct Typechecker<'a> { /// Record fields used for `RecordType`. Each value in this vector matches with the index in RecordTypeId. /// The individual field lists are stored sorted by field name. pub record_types: Vec>, - /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId + /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. + /// oneof types should not be nested and should have at least two elements pub oneof_types: Vec>, /// Type variables, indexed by TypeVarId pub type_vars: Vec, @@ -453,14 +454,9 @@ impl<'a> Typechecker<'a> { self.typecheck_expr(else_blk, expected) }; let mut types = HashSet::new(); - self.add_resolved_types(&mut types, &then_type_id); - self.add_resolved_types(&mut types, &else_type_id); - if types.len() > 1 { - self.oneof_types.push(types); - self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) - } else { - *types.iter().next().expect("Can't be empty") - } + types.insert(then_type_id); + types.insert(else_type_id); + self.create_oneof(types) } else { // If there's no else block, the if expression is a statement NONE_TYPE @@ -473,16 +469,10 @@ impl<'a> Typechecker<'a> { } => { // Check all the output types of match let output_types = self.typecheck_match(target, match_arms, expected); - match output_types.len().cmp(&1) { - Ordering::Greater => { - self.oneof_types.push(output_types); - self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) - } - Ordering::Equal => *output_types - .iter() - .next() - .expect("Will contain one element"), - Ordering::Less => NOTHING_TYPE, + if output_types.is_empty() { + NOTHING_TYPE + } else { + self.create_oneof(output_types) } } _ => { @@ -592,25 +582,22 @@ impl<'a> Typechecker<'a> { } fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { - self.typecheck_expr(lhs, TOP_TYPE); - self.typecheck_expr(rhs, TOP_TYPE); self.set_node_type_id(op, FORBIDDEN_TYPE); - let lhs_type = self.type_of(lhs); - let rhs_type = self.type_of(rhs); - - let out_type = match self.compiler.ast_nodes[op.0] { - AstNode::Equal | AstNode::NotEqual => Some(Type::Bool), + // todo better error messages for type mismatches, the previous messages were better + match self.compiler.ast_nodes[op.0] { + AstNode::Equal | AstNode::NotEqual => { + let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); + self.typecheck_expr(rhs, lhs_ty); + BOOL_TYPE + } AstNode::LessThan | AstNode::GreaterThan | AstNode::LessThanOrEqual | AstNode::GreaterThanOrEqual => { - if check_numeric_op(lhs_type, rhs_type) == Type::Unknown { - self.binary_op_err("comparison", lhs, op, rhs); - None - } else { - Some(Type::Bool) - } + self.typecheck_expr(lhs, NUMBER_TYPE); + self.typecheck_expr(rhs, NUMBER_TYPE); + BOOL_TYPE } AstNode::Minus | AstNode::Multiply @@ -618,88 +605,86 @@ impl<'a> Typechecker<'a> { | AstNode::FloorDiv | AstNode::Modulo | AstNode::Pow => { - let type_id = check_numeric_op(lhs_type, rhs_type); - - if type_id == Type::Unknown { - self.binary_op_err("math operation", lhs, op, rhs); - None - } else { - Some(type_id) + let lhs_ty = self.typecheck_expr(lhs, NUMBER_TYPE); + let rhs_ty = self.typecheck_expr(rhs, NUMBER_TYPE); + + match (self.types[lhs_ty.0], self.types[rhs_ty.0]) { + (Type::Int, Type::Int) => INT_TYPE, + (Type::Int, Type::Float) => FLOAT_TYPE, + (Type::Float, Type::Int) => FLOAT_TYPE, + (Type::Float, Type::Float) => FLOAT_TYPE, + _ => NUMBER_TYPE, } } - AstNode::RegexMatch | AstNode::NotRegexMatch => match (lhs_type, rhs_type) { - (Type::String | Type::Any, Type::String | Type::Any) => Some(Type::Bool), - _ => { - self.binary_op_err("string operation", lhs, op, rhs); - None - } - }, - AstNode::In => match rhs_type { - Type::String => match lhs_type { - Type::String | Type::Any => Some(Type::Bool), - _ => { - self.binary_op_err("string operation", lhs, op, rhs); - None + AstNode::RegexMatch | AstNode::NotRegexMatch => { + self.typecheck_expr(lhs, STRING_TYPE); + self.typecheck_expr(rhs, STRING_TYPE); + BOOL_TYPE + } + AstNode::In => { + let rhs_type = self.typecheck_expr(rhs, TOP_TYPE); + match self.types[rhs_type.0] { + Type::String => { + self.typecheck_expr(lhs, STRING_TYPE); + BOOL_TYPE } - }, - Type::List(elem_ty) => { - if self.is_type_compatible(lhs_type, self.types[elem_ty.0]) { - Some(Type::Bool) - } else { - self.binary_op_err("list operation", lhs, op, rhs); - None + Type::List(elem_ty) => { + self.typecheck_expr(lhs, elem_ty); + BOOL_TYPE + } + Type::Any | Type::Bottom => { + self.typecheck_expr(lhs, TOP_TYPE); + BOOL_TYPE + } + _ => { + self.binary_op_err("list/string operation", lhs, op, rhs); + ERROR_TYPE } } - Type::Any => Some(Type::Bool), - _ => { - self.binary_op_err("list/string operation", lhs, op, rhs); - None - } - }, - AstNode::And | AstNode::Xor | AstNode::Or => match (lhs_type, rhs_type) { - (Type::Bool, Type::Bool) => Some(Type::Bool), - _ => { - self.binary_op_err("logical operation", lhs, op, rhs); - None - } - }, + } + AstNode::And | AstNode::Xor | AstNode::Or => { + self.typecheck_expr(lhs, BOOL_TYPE); + self.typecheck_expr(rhs, BOOL_TYPE); + BOOL_TYPE + } AstNode::Plus => { - let ty = check_plus_op(lhs_type, rhs_type); - - if ty == Type::Unknown { - self.binary_op_err("addition", lhs, op, rhs); - None + let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); + if self.is_subtype(lhs_ty, STRING_TYPE) { + self.typecheck_expr(rhs, STRING_TYPE); + STRING_TYPE + } else if self.is_subtype(lhs_ty, NUMBER_TYPE) { + let rhs_ty = self.typecheck_expr(rhs, NUMBER_TYPE); + match (self.types[lhs_ty.0], self.types[rhs_ty.0]) { + (Type::Int, Type::Int) => INT_TYPE, + (Type::Int, Type::Float) => FLOAT_TYPE, + (Type::Float, Type::Int) => FLOAT_TYPE, + (Type::Float, Type::Float) => FLOAT_TYPE, + _ => NUMBER_TYPE, + } } else { - Some(ty) + self.binary_op_err("string/number operation", lhs, op, rhs); + ERROR_TYPE } } AstNode::Append => { - let lhs_type = self.type_of(lhs); - let rhs_type = self.type_of(rhs); - - match (lhs_type, rhs_type) { - (Type::List(lhs_item_id), Type::List(rhs_item_id)) => { - let lhs_item_type = self.types[lhs_item_id.0]; - let rhs_item_type = self.types[rhs_item_id.0]; - let common_type = self.least_common_type(lhs_item_type, rhs_item_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) - } - (Type::List(item_id), rhs_type) => { - let item_type = self.types[item_id.0]; - let common_type = self.least_common_type(item_type, rhs_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) - } - (lhs_type, Type::List(item_id)) => { - let item_type = self.types[item_id.0]; - let common_type = self.least_common_type(lhs_type, item_type); - let common_type_id = self.push_type(common_type); - Some(Type::List(common_type_id)) + // TODO cache this type + let list_ty = self.push_type(Type::List(TOP_TYPE)); + let lhs_type = self.typecheck_expr(lhs, list_ty); + let rhs_type = self.typecheck_expr(rhs, list_ty); + + //todo account for any + match (self.types[lhs_type.0], self.types[rhs_type.0]) { + (Type::List(lhs_item), Type::List(rhs_item)) => { + let mut types = HashSet::new(); + types.insert(lhs_item); + types.insert(rhs_item); + let common_type = self.create_oneof(types); + self.push_type(Type::List(common_type)) } + (_, Type::Any | Type::Bottom) | (Type::Any | Type::Bottom, _) => ANY_TYPE, _ => { self.binary_op_err("append", lhs, op, rhs); - None + ERROR_TYPE } } } @@ -708,14 +693,13 @@ impl<'a> Typechecker<'a> { | AstNode::SubtractAssignment | AstNode::MultiplyAssignment | AstNode::DivideAssignment - | AstNode::AppendAssignment => Some(Type::None), + | AstNode::AppendAssignment => { + // todo actually check if operands are right for operator + self.typecheck_expr(lhs, TOP_TYPE); + self.typecheck_expr(rhs, TOP_TYPE); + NONE_TYPE + } _ => panic!("internal error: unsupported node passed as binary op: {op:?}"), - }; - - if let Some(ty) = out_type { - self.push_type(ty) - } else { - ERROR_TYPE } } @@ -1097,7 +1081,15 @@ impl<'a> Typechecker<'a> { self.record_types.push(fields); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } - Type::OneOf(one_of_id) => todo!(), + Type::OneOf(one_of_id) => { + let orig_types = self.oneof_types[one_of_id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.subst(*ty, substs)); + } + self.oneof_types.push(orig_types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } Type::Ref(type_decl_id) => { if let Some(var) = substs.get(&type_decl_id) { self.push_type(Type::Var(*var)) @@ -1120,60 +1112,6 @@ impl<'a> Typechecker<'a> { TypeVarId(self.type_vars.len() - 1) } - /// Finds a "supertype" of two types (e.g., number for float and int) - fn least_common_type(&mut self, lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::List(lhs_id), Type::List(rhs_id)) => { - let item_type = self.least_common_type(self.types[lhs_id.0], self.types[rhs_id.0]); - let item_type_id = self.push_type(item_type); - Type::List(item_type_id) - } - (Type::Record(lhs_id), Type::Record(rhs_id)) => { - let mut common_fields = Vec::new(); - - let mut l = 0; - let mut r = 0; - while l < self.record_types[lhs_id.0].len() && r < self.record_types[rhs_id.0].len() - { - let (lhs_name, lhs_ty) = self.record_types[lhs_id.0][l]; - let (rhs_name, rhs_ty) = self.record_types[rhs_id.0][r]; - let lhs_text = self.compiler.get_span_contents(lhs_name); - let rhs_text = self.compiler.get_span_contents(rhs_name); - match lhs_text.cmp(rhs_text) { - Ordering::Less => { - l += 1; - } - Ordering::Greater => { - r += 1; - } - Ordering::Equal => { - let field_ty = - self.least_common_type(self.types[lhs_ty.0], self.types[rhs_ty.0]); - let field_ty_id = self.push_type(field_ty); - common_fields.push((lhs_name, field_ty_id)); - l += 1; - r += 1; - } - } - } - - self.record_types.push(common_fields); - Type::Record(RecordTypeId(self.record_types.len() - 1)) - } - (Type::Int, Type::Float) => Type::Number, - (Type::Int, Type::Number) => Type::Number, - (Type::Float, Type::Float) => Type::Number, - (Type::Float, Type::Number) => Type::Number, - _ => { - if lhs == rhs { - lhs - } else { - Type::Any - } - } - } - } - /// Check if `sub` is a subtype of `supe` /// /// Returns `false` if there is a type mismatch, `true` otherwise @@ -1267,6 +1205,7 @@ impl<'a> Typechecker<'a> { } /// Check if `sub` is a subtype of `supe` + /// todo reduce duplication between this and constrain_subtype fn is_subtype(&self, sub: TypeId, supe: TypeId) -> bool { if sub == supe { return true; @@ -1278,6 +1217,37 @@ impl<'a> Typechecker<'a> { (Type::List(inner_sub), Type::List(inner_supe)) => { self.is_subtype(inner_sub, inner_supe) } + (Type::Record(sub_rec_id), Type::Record(supe_rec_id)) => { + let sub_fields = self.record_types[sub_rec_id.0].clone(); + let supe_fields = self.record_types[supe_rec_id.0].clone(); + + let mut i = 0; + let mut j = 0; + while i < sub_fields.len() && j < supe_fields.len() { + let (sub_name, sub_ty) = sub_fields[i]; + let (supe_name, supe_ty) = supe_fields[j]; + let sub_text = self.compiler.get_span_contents(sub_name); + let supe_text = self.compiler.get_span_contents(supe_name); + match sub_text.cmp(supe_text) { + Ordering::Less => { + i += 1; + } + Ordering::Greater => { + // The field is in the supertype but not the subtype + return false; + } + Ordering::Equal => { + if !self.is_subtype(sub_ty, supe_ty) { + return false; + } + i += 1; + j += 1; + } + } + } + + true + } (Type::Var(var_id), _) => { let var = &self.type_vars[var_id.0]; self.is_subtype(var.upper_bound, supe) @@ -1347,12 +1317,20 @@ impl<'a> Typechecker<'a> { } if changed { self.record_types.push(fields); - self.push_type(Type::Record(RecordTypeId(self.record_types.len()))) + self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } else { ty_id } } - Type::OneOf(one_of_id) => todo!(), + Type::OneOf(one_of_id) => { + let orig_types = self.oneof_types[one_of_id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.eliminate_type_vars(*ty, max_var, use_lower)); + } + self.oneof_types.push(orig_types); + self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) + } Type::Var(var_id) => { if var_id.0 < max_var.0 { ty_id @@ -1511,13 +1489,6 @@ impl<'a> Typechecker<'a> { /// Use this to create any union types, to ensure that the created union type /// is as simple as possible fn create_oneof(&mut self, types: HashSet) -> TypeId { - if types.is_empty() { - // TODO return bottom type instead? - return ANY_TYPE; - } - - let mut res = HashSet::::new(); - let mut flattened = HashSet::new(); for ty_id in types { match self.types[ty_id.0] { @@ -1530,35 +1501,100 @@ impl<'a> Typechecker<'a> { } } + if flattened.is_empty() { + return BOTTOM_TYPE; + } + + let mut simple_types = HashSet::::new(); + let mut list_elems = HashSet::new(); + let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); for ty_id in flattened { - if res.contains(&ty_id) { + if simple_types.contains(&ty_id) { + continue; + } + + let ty = self.types[ty_id.0]; + + if ty == Type::Int && simple_types.contains(&FLOAT_TYPE) { + simple_types.remove(&FLOAT_TYPE); + simple_types.insert(NUMBER_TYPE); + continue; + } + if ty == Type::Float && simple_types.contains(&INT_TYPE) { + simple_types.remove(&INT_TYPE); + simple_types.insert(NUMBER_TYPE); continue; } - let mut add = true; - let mut remove = HashSet::new(); - for other_id in res.iter() { - if self.is_subtype(ty_id, *other_id) { - add = false; - break; + match ty { + Type::Int if simple_types.contains(&FLOAT_TYPE) => { + simple_types.remove(&FLOAT_TYPE); + simple_types.insert(NUMBER_TYPE); } - if self.is_subtype(*other_id, ty_id) { - remove.insert(*other_id); + Type::Float if simple_types.contains(&INT_TYPE) => { + simple_types.remove(&INT_TYPE); + simple_types.insert(NUMBER_TYPE); } - } + Type::List(elem_ty) => { + list_elems.insert(elem_ty); + } + Type::Record(rec_ty_id) => { + let new_fields = &self.record_types[rec_ty_id.0]; + for (name_node, ty) in new_fields.iter() { + let name = self.compiler.get_span_contents(*name_node); + if let Some((_, types)) = record_fields.get_mut(&name) { + types.insert(*ty); + } else { + let mut types = HashSet::new(); + types.insert(*ty); + record_fields.insert(name, (*name_node, types)); + } + } + } + _ => { + let mut add = true; + let mut remove = HashSet::new(); + for other_id in simple_types.iter() { + if self.is_subtype(ty_id, *other_id) { + add = false; + break; + } + if self.is_subtype(*other_id, ty_id) { + remove.insert(*other_id); + } + } - if add { - res.insert(ty_id); - for other in remove { - res.remove(&other); + if add { + simple_types.insert(ty_id); + for other in remove { + simple_types.remove(&other); + } + } } } } - if res.len() == 1 { - *res.iter().next().unwrap() + if !list_elems.is_empty() { + let elem_oneof = self.create_oneof(list_elems); + simple_types.insert(self.push_type(Type::List(elem_oneof))); + } + + if !record_fields.is_empty() { + let mut fields = Vec::new(); + for (_, (node, types)) in record_fields.into_iter() { + fields.push((node, self.create_oneof(types))); + } + fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + + let rec_ty_id = RecordTypeId(self.record_types.len()); + self.record_types.push(fields); + simple_types.insert(self.push_type(Type::Record(rec_ty_id))); + } + + if simple_types.len() == 1 { + *simple_types.iter().next().unwrap() } else { - self.oneof_types.push(res); + self.oneof_types.push(simple_types); self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) } } @@ -1585,30 +1621,3 @@ impl<'a> Typechecker<'a> { } } } - -/// Check whether two types can perform common numeric operations -fn check_numeric_op(lhs: Type, rhs: Type) -> Type { - match (rhs, lhs) { - (Type::Int, Type::Int) => Type::Int, - (Type::Int, Type::Float) => Type::Float, - (Type::Int, Type::Number) => Type::Number, - (Type::Float, Type::Int) => Type::Float, - (Type::Float, Type::Float) => Type::Float, - (Type::Float, Type::Number) => Type::Float, - (Type::Number, Type::Int) => Type::Number, - (Type::Number, Type::Float) => Type::Float, - (Type::Number, Type::Number) => Type::Number, - (Type::Any, _) => Type::Number, - (_, Type::Any) => Type::Number, - // TODO: Differentiate error based on whether LHS supports the op or not (see type_check.rs) - _ => Type::Unknown, - } -} - -/// Check whether two types can perform addition -fn check_plus_op(lhs: Type, rhs: Type) -> Type { - match (rhs, lhs) { - (Type::String, Type::String) => Type::String, - _ => check_numeric_op(lhs, rhs), - } -} diff --git a/tests/binary_ops_exact.nu b/tests/binary_ops_exact.nu index 1977041..56c1598 100644 --- a/tests/binary_ops_exact.nu +++ b/tests/binary_ops_exact.nu @@ -1,5 +1,5 @@ 1 == 1 -[true] ++ false +[true] ++ [false] 1 + 1 1.0 + 1.0 true and false diff --git a/tests/binary_ops_subtypes.nu b/tests/binary_ops_subtypes.nu index 618d648..11723a0 100644 --- a/tests/binary_ops_subtypes.nu +++ b/tests/binary_ops_subtypes.nu @@ -1,8 +1,8 @@ 1 == 1.0 "a" == 1.0 1 + 1.0 -[1] ++ 1.0 -[1.0 1] ++ "a" +[1] ++ [1.0] +[1.0 1] ++ ["a"] [[1] [2]] ++ [[3]] [[1] [2]] ++ [[3.0]] 1 in [1.0, 1] From 170ab1cb319cc82a204d7a5f3f640ffc9006d3bd Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Tue, 6 May 2025 20:55:15 -0400 Subject: [PATCH 07/16] Constrain lhs of Plus based on rhs --- ...parser__test__node_output@def_plus.nu.snap | 71 ++++++++++++++++ src/typechecker.rs | 80 +++++++++++++++---- tests/def_plus.nu | 5 ++ 3 files changed, 139 insertions(+), 17 deletions(-) create mode 100644 src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap create mode 100644 tests/def_plus.nu diff --git a/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap new file mode 100644 index 0000000..d579492 --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap @@ -0,0 +1,71 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/def_plus.nu +--- +==== COMPILER ==== +0: Name (4 to 14) "mysterious" +1: Name (15 to 16) "T" +2: Params([NodeId(1)]) (14 to 17) +3: Name (20 to 21) "x" +4: Name (23 to 26) "int" +5: Type { name: NodeId(4), args: None, optional: false } (23 to 26) +6: Param { name: NodeId(3), ty: Some(NodeId(5)) } (20 to 26) +7: Params([NodeId(6)]) (18 to 28) +8: Name (31 to 38) "nothing" +9: Type { name: NodeId(8), args: None, optional: false } (31 to 38) +10: Name (42 to 43) "T" +11: Type { name: NodeId(10), args: None, optional: false } (42 to 43) +12: InOutType(NodeId(9), NodeId(11)) (31 to 44) +13: InOutTypes([NodeId(12)]) (31 to 44) +14: Block(BlockId(0)) (44 to 46) +15: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } (0 to 46) +16: Variable (52 to 53) "m" +17: Name (56 to 66) "mysterious" +18: Int (67 to 68) "0" +19: Call { parts: [NodeId(17), NodeId(18)] } (67 to 68) +20: Let { variable_name: NodeId(16), ty: None, initializer: NodeId(19), is_mutable: false } (48 to 68) +21: Variable (70 to 72) "$m" +22: Plus (73 to 74) +23: String (75 to 80) ""foo"" +24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (70 to 80) +25: Block(BlockId(1)) (0 to 81) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(25) + variables: [ m: NodeId(16) ] + decls: [ mysterious: NodeId(0) ] +1: Frame Scope, node_id: NodeId(14) + variables: [ x: NodeId(3) ] + type decls: [ T: NodeId(1) ] +==== TYPES ==== +0: unknown +1: unknown +2: unknown +3: unknown +4: unknown +5: int +6: int +7: forbidden +8: unknown +9: nothing +10: unknown +11: T +12: unknown +13: unknown +14: () +15: () +16: bottom +17: unknown +18: int +19: bottom +20: () +21: bottom +22: forbidden +23: string +24: string +25: string +==== IR ==== +register_count: 0 +file_count: 0 +==== IR ERRORS ==== +Error (NodeId 15): node Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } not suported yet diff --git a/src/typechecker.rs b/src/typechecker.rs index 0def748..f3f5397 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -648,22 +648,43 @@ impl<'a> Typechecker<'a> { BOOL_TYPE } AstNode::Plus => { - let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); - if self.is_subtype(lhs_ty, STRING_TYPE) { + let mut types = HashSet::new(); + types.insert(STRING_TYPE); + types.insert(NUMBER_TYPE); + let common_ty = self.create_oneof(types); + + let lhs_ty = self.typecheck_expr(lhs, common_ty); + let lhs_bottom = self.is_subtype(lhs_ty, BOTTOM_TYPE); + if !lhs_bottom && self.is_subtype(lhs_ty, STRING_TYPE) { self.typecheck_expr(rhs, STRING_TYPE); STRING_TYPE - } else if self.is_subtype(lhs_ty, NUMBER_TYPE) { + } else if !lhs_bottom && self.is_subtype(lhs_ty, NUMBER_TYPE) { let rhs_ty = self.typecheck_expr(rhs, NUMBER_TYPE); - match (self.types[lhs_ty.0], self.types[rhs_ty.0]) { - (Type::Int, Type::Int) => INT_TYPE, - (Type::Int, Type::Float) => FLOAT_TYPE, - (Type::Float, Type::Int) => FLOAT_TYPE, - (Type::Float, Type::Float) => FLOAT_TYPE, - _ => NUMBER_TYPE, - } + self.numeric_op_type(lhs_ty, rhs_ty) } else { - self.binary_op_err("string/number operation", lhs, op, rhs); - ERROR_TYPE + let rhs_ty = self.typecheck_expr(rhs, common_ty); + let rhs_bottom = self.is_subtype(rhs_ty, BOTTOM_TYPE); + if !rhs_bottom && self.is_subtype(rhs_ty, STRING_TYPE) { + if !self.constrain_subtype(lhs_ty, STRING_TYPE) { + self.error( + format!("Expected string, got {}", self.type_to_string(lhs_ty)), + lhs, + ); + } + STRING_TYPE + } else if !rhs_bottom && self.is_subtype(lhs_ty, NUMBER_TYPE) { + if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { + self.error( + format!("Expected number, got {}", self.type_to_string(lhs_ty)), + lhs, + ); + } + self.numeric_op_type(lhs_ty, rhs_ty) + } else if lhs_bottom && rhs_bottom { + common_ty + } else { + ERROR_TYPE + } } } AstNode::Append => { @@ -703,6 +724,16 @@ impl<'a> Typechecker<'a> { } } + fn numeric_op_type(&self, lhs: TypeId, rhs: TypeId) -> TypeId { + match (self.types[lhs.0], self.types[rhs.0]) { + (Type::Int, Type::Int) => INT_TYPE, + (Type::Int, Type::Float) => FLOAT_TYPE, + (Type::Float, Type::Int) => FLOAT_TYPE, + (Type::Float, Type::Float) => FLOAT_TYPE, + _ => NUMBER_TYPE, + } + } + fn typecheck_def( &mut self, name: NodeId, @@ -1158,10 +1189,6 @@ impl<'a> Typechecker<'a> { true } - (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] - .clone() - .iter() - .all(|ty| self.constrain_subtype(sub_id, *ty)), (Type::Var(var_id), _) => { let lb = self.type_vars[var_id.0].lower_bound; let ub = self.type_vars[var_id.0].upper_bound; @@ -1199,6 +1226,17 @@ impl<'a> Typechecker<'a> { false } } + (Type::OneOf(oneof_id), _) => self.oneof_types[oneof_id.0] + .clone() + .iter() + .all(|ty| self.constrain_subtype(*ty, supe_id)), + (_, Type::OneOf(oneof_id)) => { + // todo actually add constraints? + self.oneof_types[oneof_id.0] + .clone() + .iter() + .any(|ty| self.is_subtype(sub_id, *ty)) + } (sub, supe) if sub == supe => true, _ => false, } @@ -1410,7 +1448,15 @@ impl<'a> Typechecker<'a> { String::from_utf8_lossy(self.compiler.get_span_contents(name_node)).to_string() } }, - Type::Var(type_var_id) => format!("'{}", type_var_id.0), + Type::Var(type_var_id) => { + let var = &self.type_vars[type_var_id.0]; + format!( + "{} <: '{} <: {}", + self.type_to_string(var.lower_bound), + type_var_id.0, + self.type_to_string(var.upper_bound) + ) + } } } diff --git a/tests/def_plus.nu b/tests/def_plus.nu new file mode 100644 index 0000000..ed2b6bd --- /dev/null +++ b/tests/def_plus.nu @@ -0,0 +1,5 @@ +def mysterious [ x: int ] : nothing -> T {} + +let m = mysterious 0 + +$m + "foo" From 2b777a5da5e7ebc551d973b323dc44e64b0e882e Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Fri, 9 May 2025 19:29:55 -0400 Subject: [PATCH 08/16] Create allof type --- ...t__node_output@binary_ops_subtypes.nu.snap | 5 +- ...ser__test__node_output@def_complex.nu.snap | 2 +- ...parser__test__node_output@def_plus.nu.snap | 48 ++-- ...er__test__node_output@let_mismatch.nu.snap | 10 +- src/typechecker.rs | 230 +++++++++++++++--- tests/def_plus.nu | 3 +- 6 files changed, 235 insertions(+), 63 deletions(-) diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap index 39dcc85..bb726ad 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_subtypes.nu.snap @@ -84,7 +84,7 @@ input_file: tests/binary_ops_subtypes.nu 2: float 3: bool 4: string -5: forbidden +5: error 6: float 7: bool 8: int @@ -152,8 +152,7 @@ input_file: tests/binary_ops_subtypes.nu 70: list>> 71: list>> ==== TYPE ERRORS ==== -Error (NodeId 2): Expected int, got float -Error (NodeId 6): Expected string, got float +Error (NodeId 5): type mismatch: unsupported incompatible types for equal between string and float ==== IR ==== register_count: 1 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap index 103727c..07c77c1 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap @@ -177,7 +177,7 @@ input_file: tests/def_complex.nu 67: int 68: bottom 69: () -70: record +70: record 71: unknown 72: unknown 73: unknown diff --git a/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap index d579492..8b23803 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap @@ -21,17 +21,23 @@ input_file: tests/def_plus.nu 14: Block(BlockId(0)) (44 to 46) 15: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } (0 to 46) 16: Variable (52 to 53) "m" -17: Name (56 to 66) "mysterious" -18: Int (67 to 68) "0" -19: Call { parts: [NodeId(17), NodeId(18)] } (67 to 68) -20: Let { variable_name: NodeId(16), ty: None, initializer: NodeId(19), is_mutable: false } (48 to 68) -21: Variable (70 to 72) "$m" -22: Plus (73 to 74) -23: String (75 to 80) ""foo"" -24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (70 to 80) -25: Block(BlockId(1)) (0 to 81) +17: Name (55 to 58) "any" +18: Type { name: NodeId(17), args: None, optional: false } (55 to 58) +19: Name (61 to 71) "mysterious" +20: Int (72 to 73) "0" +21: Call { parts: [NodeId(19), NodeId(20)] } (72 to 73) +22: Let { variable_name: NodeId(16), ty: Some(NodeId(18)), initializer: NodeId(21), is_mutable: false } (48 to 73) +23: Variable (75 to 77) "$m" +24: Plus (78 to 79) +25: String (80 to 85) ""foo"" +26: BinaryOp { lhs: NodeId(23), op: NodeId(24), rhs: NodeId(25) } (75 to 85) +27: Variable (86 to 88) "$m" +28: Plus (89 to 90) +29: Int (91 to 94) "123" +30: BinaryOp { lhs: NodeId(27), op: NodeId(28), rhs: NodeId(29) } (86 to 94) +31: Block(BlockId(1)) (0 to 95) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(25) +0: Frame Scope, node_id: NodeId(31) variables: [ m: NodeId(16) ] decls: [ mysterious: NodeId(0) ] 1: Frame Scope, node_id: NodeId(14) @@ -54,16 +60,22 @@ input_file: tests/def_plus.nu 13: unknown 14: () 15: () -16: bottom +16: any 17: unknown -18: int -19: bottom -20: () -21: bottom -22: forbidden -23: string -24: string +18: any +19: unknown +20: int +21: top +22: () +23: any +24: forbidden 25: string +26: string +27: any +28: forbidden +29: int +30: number +31: number ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap index b2dab3d..78663a4 100644 --- a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap @@ -49,22 +49,22 @@ input_file: tests/let_mismatch.nu 0: Frame Scope, node_id: NodeId(40) variables: [ v: NodeId(28), w: NodeId(15), x: NodeId(0), y: NodeId(5), z: NodeId(10) ] ==== TYPES ==== -0: int +0: number 1: unknown 2: number 3: int 4: () -5: string +5: any 6: unknown 7: any 8: string 9: () -10: int +10: string 11: unknown 12: string 13: int 14: () -15: list> +15: list> 16: unknown 17: unknown 18: unknown @@ -77,7 +77,7 @@ input_file: tests/let_mismatch.nu 25: list 26: list> 27: () -28: record +28: record 29: unknown 30: unknown 31: unknown diff --git a/src/typechecker.rs b/src/typechecker.rs index f3f5397..db234c9 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -21,7 +21,7 @@ pub struct TypeVar { upper_bound: TypeId, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeVarId(pub usize); #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -30,6 +30,9 @@ pub struct RecordTypeId(pub usize); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OneOfId(pub usize); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AllOfId(pub usize); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { /// Any node that hasn't been touched by the typechecker will have this type @@ -56,6 +59,7 @@ pub enum Type { Stream(TypeId), Record(RecordTypeId), OneOf(OneOfId), + AllOf(AllOfId), Ref(TypeDeclId), Var(TypeVarId), } @@ -105,6 +109,9 @@ pub struct Typechecker<'a> { /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. /// oneof types should not be nested and should have at least two elements pub oneof_types: Vec>, + /// Types used for `AllOf`. Each value in this vector matches with the index in AllOfId. + /// allof types should not be nested and should have at least two elements + pub allof_types: Vec>, /// Type variables, indexed by TypeVarId pub type_vars: Vec, /// Type of each Variable in compiler.variables, indexed by VarId @@ -142,6 +149,7 @@ impl<'a> Typechecker<'a> { node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], record_types: Vec::new(), oneof_types: Vec::new(), + allof_types: Vec::new(), type_vars: Vec::new(), variable_types: vec![UNKNOWN_TYPE; compiler.variables.len()], decl_types: vec![ @@ -588,7 +596,14 @@ impl<'a> Typechecker<'a> { match self.compiler.ast_nodes[op.0] { AstNode::Equal | AstNode::NotEqual => { let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); - self.typecheck_expr(rhs, lhs_ty); + let rhs_ty = self.typecheck_expr(rhs, TOP_TYPE); + if !self.is_subtype(lhs_ty, rhs_ty) + && !self.is_subtype(rhs_ty, lhs_ty) + && !(self.is_subtype(lhs_ty, NUMBER_TYPE) + && self.is_subtype(rhs_ty, NUMBER_TYPE)) + { + self.binary_op_err("incompatible types for equal", lhs, op, rhs); + } BOOL_TYPE } AstNode::LessThan @@ -672,7 +687,7 @@ impl<'a> Typechecker<'a> { ); } STRING_TYPE - } else if !rhs_bottom && self.is_subtype(lhs_ty, NUMBER_TYPE) { + } else if !rhs_bottom && self.is_subtype(rhs_ty, NUMBER_TYPE) { if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { self.error( format!("Expected number, got {}", self.type_to_string(lhs_ty)), @@ -899,7 +914,8 @@ impl<'a> Typechecker<'a> { ) { let type_id = if let Some(ty) = ty { let ty_id = self.typecheck_type(ty); - self.typecheck_expr(initializer, ty_id) + self.typecheck_expr(initializer, ty_id); + ty_id } else { self.typecheck_expr(initializer, TOP_TYPE) }; @@ -1112,15 +1128,24 @@ impl<'a> Typechecker<'a> { self.record_types.push(fields); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } - Type::OneOf(one_of_id) => { - let orig_types = self.oneof_types[one_of_id.0].clone(); + Type::OneOf(id) => { + let orig_types = self.oneof_types[id.0].clone(); let mut new_types = HashSet::new(); for ty in orig_types.iter() { new_types.insert(self.subst(*ty, substs)); } - self.oneof_types.push(orig_types); + self.oneof_types.push(new_types); self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) } + Type::AllOf(id) => { + let orig_types = self.allof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.subst(*ty, substs)); + } + self.allof_types.push(new_types); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + } Type::Ref(type_decl_id) => { if let Some(var) = substs.get(&type_decl_id) { self.push_type(Type::Var(*var)) @@ -1196,7 +1221,6 @@ impl<'a> Typechecker<'a> { // Prevent forward references/cycles let new_ub = self.eliminate_type_vars(new_ub, var_id, true); - // todo prevent infinite recursion here if self.constrain_subtype(lb, new_ub) { let var = self .type_vars @@ -1226,13 +1250,13 @@ impl<'a> Typechecker<'a> { false } } - (Type::OneOf(oneof_id), _) => self.oneof_types[oneof_id.0] + (Type::OneOf(id), _) => self.oneof_types[id.0] .clone() .iter() .all(|ty| self.constrain_subtype(*ty, supe_id)), - (_, Type::OneOf(oneof_id)) => { + (_, Type::OneOf(id)) => { // todo actually add constraints? - self.oneof_types[oneof_id.0] + self.oneof_types[id.0] .clone() .iter() .any(|ty| self.is_subtype(sub_id, *ty)) @@ -1294,7 +1318,11 @@ impl<'a> Typechecker<'a> { let var = &self.type_vars[var_id.0]; self.is_subtype(sub, var.lower_bound) } - (_, Type::OneOf(oneof_id)) => self.oneof_types[oneof_id.0] + (Type::OneOf(id), _) => self.oneof_types[id.0] + .clone() + .iter() + .all(|ty| self.is_subtype(*ty, supe)), + (_, Type::OneOf(id)) => self.oneof_types[id.0] .iter() .any(|ty| self.is_subtype(sub, *ty)), (sub, supe) => sub == supe, @@ -1360,15 +1388,24 @@ impl<'a> Typechecker<'a> { ty_id } } - Type::OneOf(one_of_id) => { - let orig_types = self.oneof_types[one_of_id.0].clone(); + Type::OneOf(id) => { + let orig_types = self.oneof_types[id.0].clone(); let mut new_types = HashSet::new(); for ty in orig_types.iter() { new_types.insert(self.eliminate_type_vars(*ty, max_var, use_lower)); } - self.oneof_types.push(orig_types); + self.oneof_types.push(new_types); self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) } + Type::AllOf(id) => { + let orig_types = self.allof_types[id.0].clone(); + let mut new_types = HashSet::new(); + for ty in orig_types.iter() { + new_types.insert(self.eliminate_type_vars(*ty, max_var, use_lower)); + } + self.allof_types.push(new_types); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + } Type::Var(var_id) => { if var_id.0 < max_var.0 { ty_id @@ -1443,6 +1480,23 @@ impl<'a> Typechecker<'a> { fmt.push('>'); fmt } + Type::AllOf(id) => { + let mut fmt = "allof<".to_string(); + let mut types: Vec<_> = self.oneof_types[id.0] + .iter() + .map(|ty| self.type_to_string(*ty) + ", ") + .collect(); + types.sort(); + for ty in &types { + fmt += ty; + } + if !types.is_empty() { + fmt.pop(); + fmt.pop(); + } + fmt.push('>'); + fmt + } Type::Ref(type_decl_id) => match self.compiler.type_decls[type_decl_id.0] { TypeDecl::Param(name_node) => { String::from_utf8_lossy(self.compiler.get_span_contents(name_node)).to_string() @@ -1538,8 +1592,8 @@ impl<'a> Typechecker<'a> { let mut flattened = HashSet::new(); for ty_id in types { match self.types[ty_id.0] { - Type::OneOf(oneof_id) => { - flattened.extend(&self.oneof_types[oneof_id.0]); + Type::OneOf(id) => { + flattened.extend(&self.oneof_types[id.0]); } _ => { flattened.insert(ty_id); @@ -1645,25 +1699,131 @@ impl<'a> Typechecker<'a> { } } - fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId { - if lhs_id == rhs_id { - return lhs_id; + /// Use this to create any intersection types, to ensure that the created intersection type + /// is as simple as possible + fn create_allof(&mut self, types: HashSet) -> TypeId { + let mut flattened = HashSet::new(); + for ty_id in types { + match self.types[ty_id.0] { + Type::AllOf(id) => { + flattened.extend(&self.allof_types[id.0]); + } + _ => { + flattened.insert(ty_id); + } + } + } + + if flattened.is_empty() { + return TOP_TYPE; + } + + let mut vars = HashMap::::new(); + let mut refs = HashMap::::new(); + let mut simple_type: Option = None; + let mut list_elems = HashSet::new(); + let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut oneof_ids = Vec::new(); + for ty_id in flattened { + let ty = self.types[ty_id.0]; + + match ty { + Type::Any => return ANY_TYPE, + Type::Unknown => return UNKNOWN_TYPE, + Type::Top => {} + Type::Bottom => return BOTTOM_TYPE, + Type::Var(var_id) => { + vars.insert(var_id, ty_id); + } + Type::Ref(decl_id) => { + refs.insert(decl_id, ty_id); + } + Type::List(elem_ty) => { + if simple_type.is_some() || !record_fields.is_empty() { + return BOTTOM_TYPE; + } + list_elems.insert(elem_ty); + } + Type::Record(rec_ty_id) => { + if simple_type.is_some() || !list_elems.is_empty() { + return BOTTOM_TYPE; + } + let new_fields = &self.record_types[rec_ty_id.0]; + for (name_node, ty) in new_fields.iter() { + let name = self.compiler.get_span_contents(*name_node); + if let Some((_, types)) = record_fields.get_mut(&name) { + types.insert(*ty); + } else { + let mut types = HashSet::new(); + types.insert(*ty); + record_fields.insert(name, (*name_node, types)); + } + } + } + Type::OneOf(id) => { + oneof_ids.push(id); + } + _ => { + if !list_elems.is_empty() && !record_fields.is_empty() { + return BOTTOM_TYPE; + } + if let Some(other_id) = &simple_type { + if self.is_subtype(ty_id, *other_id) { + simple_type = Some(ty_id); + } else if self.is_subtype(*other_id, ty_id) { + } else { + return BOTTOM_TYPE; + } + } else { + simple_type = Some(ty_id); + } + } + } + } + + let mut res = HashSet::new(); + res.extend(vars.values()); + res.extend(refs.values()); + + if let Some(ty) = simple_type { + res.insert(ty); + } + if !list_elems.is_empty() { + let elem_allof = self.create_allof(list_elems); + res.insert(self.push_type(Type::List(elem_allof))); + } + if !record_fields.is_empty() { + let mut fields = Vec::new(); + for (_, (node, types)) in record_fields.into_iter() { + fields.push((node, self.create_oneof(types))); + } + fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + + let rec_ty_id = RecordTypeId(self.record_types.len()); + self.record_types.push(fields); + res.insert(self.push_type(Type::Record(rec_ty_id))); } - match (self.types[lhs_id.0], self.types[rhs_id.0]) { - (Type::Any | Type::Unknown, _) => lhs_id, - (_, Type::Any | Type::Unknown) => rhs_id, - (Type::Top, _) => rhs_id, - (_, Type::Top) => lhs_id, - (Type::Bottom, _) | (_, Type::Bottom) => BOTTOM_TYPE, - (Type::Number, Type::Int) | (Type::Int, Type::Number) => INT_TYPE, - (Type::Number, Type::Float) | (Type::Float, Type::Number) => FLOAT_TYPE, - (Type::List(lhs_inner), Type::List(rhs_inner)) => { - let new_inner = self.create_intersection(lhs_inner, rhs_inner); - self.push_type(Type::List(new_inner)) - } - (Type::Var(_) | Type::Ref(_), _) | (_, Type::Var(_) | Type::Ref(_)) => todo!(), - (lhs, rhs) if lhs == rhs => lhs_id, - _ => BOTTOM_TYPE, + + let oneofs = oneof_ids + .into_iter() + .map(|id| self.oneof_types[id.0].clone()) + .collect::>(); + // todo handle oneofs, need cartesian product + + if res.is_empty() { + TOP_TYPE + } else if res.len() == 1 { + *res.iter().next().unwrap() + } else { + self.allof_types.push(res); + self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) } } + + fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId { + let mut types = HashSet::new(); + types.insert(lhs_id); + types.insert(rhs_id); + self.create_allof(types) + } } diff --git a/tests/def_plus.nu b/tests/def_plus.nu index ed2b6bd..4ec8a3e 100644 --- a/tests/def_plus.nu +++ b/tests/def_plus.nu @@ -1,5 +1,6 @@ def mysterious [ x: int ] : nothing -> T {} -let m = mysterious 0 +let m: any = mysterious 0 $m + "foo" +$m + 123 From 6f240789d373eb0cbf770a9bc757f463ff54442f Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sat, 10 May 2025 14:22:11 -0400 Subject: [PATCH 09/16] Make allof work on oneof --- src/typechecker.rs | 53 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index db234c9..6b842f0 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -44,7 +44,9 @@ pub enum Type { /// None type means that a node has no type. For example, statements like let x = ... do not /// output anything and thus don't have any type. None, + /// Supertype of all types Top, + /// Subtype of all types Bottom, Any, Number, @@ -58,9 +60,15 @@ pub enum Type { List(TypeId), Stream(TypeId), Record(RecordTypeId), + /// Union type. OneOf types should not be nested and should have at least two elements. + /// They can contain allof types. OneOf(OneOfId), + /// Intersection type. AllOf types should not be nested and should have at least two elements. + /// They also cannot contain oneof types. AllOf(AllOfId), + /// A reference to a type declaration such as a type parameter Ref(TypeDeclId), + /// A type variable that must be solved Var(TypeVarId), } @@ -107,10 +115,8 @@ pub struct Typechecker<'a> { /// The individual field lists are stored sorted by field name. pub record_types: Vec>, /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. - /// oneof types should not be nested and should have at least two elements pub oneof_types: Vec>, /// Types used for `AllOf`. Each value in this vector matches with the index in AllOfId. - /// allof types should not be nested and should have at least two elements pub allof_types: Vec>, /// Type variables, indexed by TypeVarId pub type_vars: Vec, @@ -597,10 +603,10 @@ impl<'a> Typechecker<'a> { AstNode::Equal | AstNode::NotEqual => { let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); let rhs_ty = self.typecheck_expr(rhs, TOP_TYPE); - if !self.is_subtype(lhs_ty, rhs_ty) - && !self.is_subtype(rhs_ty, lhs_ty) - && !(self.is_subtype(lhs_ty, NUMBER_TYPE) - && self.is_subtype(rhs_ty, NUMBER_TYPE)) + if !(self.is_subtype(lhs_ty, rhs_ty) + || self.is_subtype(rhs_ty, lhs_ty) + || (self.is_subtype(lhs_ty, NUMBER_TYPE) + && self.is_subtype(rhs_ty, NUMBER_TYPE))) { self.binary_op_err("incompatible types for equal", lhs, op, rhs); } @@ -1804,20 +1810,41 @@ impl<'a> Typechecker<'a> { res.insert(self.push_type(Type::Record(rec_ty_id))); } - let oneofs = oneof_ids - .into_iter() - .map(|id| self.oneof_types[id.0].clone()) - .collect::>(); - // todo handle oneofs, need cartesian product - - if res.is_empty() { + let single_res = if res.is_empty() { TOP_TYPE } else if res.len() == 1 { *res.iter().next().unwrap() } else { self.allof_types.push(res); self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) + }; + + if oneof_ids.is_empty() { + return single_res; + } + + let mut first_inter = HashSet::new(); + first_inter.insert(single_res); + let mut inters = vec![first_inter]; + + for oneof_id in oneof_ids { + let mut new_inters = vec![]; + let types = &self.oneof_types[oneof_id.0]; + for ty in types.iter() { + for mut inter in inters.clone() { + inter.insert(*ty); + new_inters.push(inter); + } + } + inters = new_inters; } + + let inters = inters + .into_iter() + .map(|inter| self.create_allof(inter)) + .collect::>(); + + self.create_oneof(inters) } fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId { From 08f75c3210313cafa0007ab8f883f740549d29be Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sat, 10 May 2025 15:56:45 -0400 Subject: [PATCH 10/16] Correctly add subtype constraints to type vars --- src/snapshots/new_nu_parser__test__lexer.snap | 21 ----- ...t__node_output@binary_ops_mismatch.nu.snap | 4 +- ...__test__node_output@infer_complex.nu.snap} | 4 +- ..._test__node_output@infer_generics.nu.snap} | 51 +++-------- ...ser__test__node_output@infer_plus.nu.snap} | 62 +++++++------- ...er__test__node_output@let_mismatch.nu.snap | 2 + src/typechecker.rs | 84 +++++++++++-------- tests/{def_complex.nu => infer_complex.nu} | 0 tests/{def_generics.nu => infer_generics.nu} | 5 +- tests/{def_plus.nu => infer_plus.nu} | 2 +- 10 files changed, 99 insertions(+), 136 deletions(-) delete mode 100644 src/snapshots/new_nu_parser__test__lexer.snap rename src/snapshots/{new_nu_parser__test__node_output@def_complex.nu.snap => new_nu_parser__test__node_output@infer_complex.nu.snap} (98%) rename src/snapshots/{new_nu_parser__test__node_output@def_generics.nu.snap => new_nu_parser__test__node_output@infer_generics.nu.snap} (59%) rename src/snapshots/{new_nu_parser__test__node_output@def_plus.nu.snap => new_nu_parser__test__node_output@infer_plus.nu.snap} (62%) rename tests/{def_complex.nu => infer_complex.nu} (100%) rename tests/{def_generics.nu => infer_generics.nu} (55%) rename tests/{def_plus.nu => infer_plus.nu} (72%) diff --git a/src/snapshots/new_nu_parser__test__lexer.snap b/src/snapshots/new_nu_parser__test__lexer.snap deleted file mode 100644 index 87e09ba..0000000 --- a/src/snapshots/new_nu_parser__test__lexer.snap +++ /dev/null @@ -1,21 +0,0 @@ ---- -source: src/test.rs -expression: evaluate_lexer(path) -input_file: tests/lex/int.nu -snapshot_kind: text ---- -Token 0: Number span: 0 .. 1 '0' -Token 1: Newline span: 1 .. 2 '\n' -Token 2: Number span: 2 .. 4 '00' -Token 3: Newline span: 4 .. 5 '\n' -Token 4: Number span: 5 .. 10 '0x123' -Token 5: Newline span: 10 .. 11 '\n' -Token 6: Number span: 11 .. 16 '0b101' -Token 7: Newline span: 16 .. 17 '\n' -Token 8: Name span: 17 .. 19 '_0' -Token 9: Newline span: 19 .. 20 '\n' -Token 10: Number span: 20 .. 21 '0' -Token 11: Name span: 21 .. 22 '_' -Token 12: Newline span: 22 .. 23 '\n' -Token 13: Name span: 23 .. 24 '_' -Token 14: Newline span: 24 .. 25 '\n' diff --git a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap index 70bf21f..42a713d 100644 --- a/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap @@ -43,8 +43,8 @@ input_file: tests/binary_ops_mismatch.nu 16: bool ==== TYPE ERRORS ==== Error (NodeId 2): Expected string, got float -Error (NodeId 4): Expected list, got string -Error (NodeId 6): Expected list, got float +Error (NodeId 4): Expected list <: '0 <: list, got string +Error (NodeId 6): Expected list <: '0 <: list, got float Error (NodeId 5): type mismatch: unsupported append between string and float Error (NodeId 10): Expected bool, got string Error (NodeId 12): Expected string, got bool diff --git a/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap similarity index 98% rename from src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap rename to src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap index 07c77c1..b91bb9c 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap @@ -1,7 +1,7 @@ --- source: src/test.rs expression: evaluate_example(path) -input_file: tests/def_complex.nu +input_file: tests/infer_complex.nu --- ==== COMPILER ==== 0: Name (4 to 5) "f" @@ -196,7 +196,7 @@ input_file: tests/def_complex.nu 86: unknown 87: string 88: record -89: record +89: record 90: () 91: () ==== IR ==== diff --git a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap similarity index 59% rename from src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap rename to src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap index baf1eb4..3a202d4 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap @@ -1,7 +1,7 @@ --- source: src/test.rs expression: evaluate_example(path) -input_file: tests/def_generics.nu +input_file: tests/infer_generics.nu --- ==== COMPILER ==== 0: Name (4 to 5) "f" @@ -30,28 +30,12 @@ input_file: tests/def_generics.nu 23: List([NodeId(22)]) (59 to 62) 24: Block(BlockId(0)) (39 to 65) 25: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } (0 to 65) -26: Variable (71 to 73) "l1" -27: Name (76 to 77) "f" -28: Int (78 to 79) "1" -29: Call { parts: [NodeId(27), NodeId(28)] } (78 to 79) -30: Let { variable_name: NodeId(26), ty: None, initializer: NodeId(29), is_mutable: false } (67 to 79) -31: Variable (85 to 87) "l2" -32: Name (90 to 91) "f" -33: Int (92 to 93) "2" -34: Call { parts: [NodeId(32), NodeId(33)] } (92 to 93) -35: Let { variable_name: NodeId(31), ty: None, initializer: NodeId(34), is_mutable: false } (81 to 93) -36: Variable (98 to 100) "l3" -37: Name (102 to 106) "list" -38: Name (107 to 113) "number" -39: Type { name: NodeId(38), args: None, optional: false } (107 to 113) -40: TypeArgs([NodeId(39)]) (106 to 114) -41: Type { name: NodeId(37), args: Some(NodeId(40)), optional: false } (102 to 106) -42: Variable (117 to 120) "$l2" -43: Let { variable_name: NodeId(36), ty: Some(NodeId(41)), initializer: NodeId(42), is_mutable: false } (94 to 120) -44: Block(BlockId(1)) (0 to 121) +26: Name (67 to 68) "f" +27: Int (69 to 70) "1" +28: Call { parts: [NodeId(26), NodeId(27)] } (69 to 70) +29: Block(BlockId(1)) (0 to 71) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(44) - variables: [ l1: NodeId(26), l2: NodeId(31), l3: NodeId(36) ] +0: Frame Scope, node_id: NodeId(29) decls: [ f: NodeId(0) ] 1: Frame Scope, node_id: NodeId(24) variables: [ x: NodeId(3), z: NodeId(17) ] @@ -83,25 +67,10 @@ input_file: tests/def_generics.nu 23: list 24: list 25: () -26: list -27: unknown -28: int -29: list -30: () -31: list -32: unknown -33: int -34: list -35: () -36: list -37: unknown -38: unknown -39: number -40: forbidden -41: list -42: list -43: () -44: () +26: unknown +27: int +28: list +29: list ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap b/src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap similarity index 62% rename from src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap rename to src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap index 8b23803..ff021e2 100644 --- a/src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap @@ -1,7 +1,7 @@ --- source: src/test.rs expression: evaluate_example(path) -input_file: tests/def_plus.nu +input_file: tests/infer_plus.nu --- ==== COMPILER ==== 0: Name (4 to 14) "mysterious" @@ -21,23 +21,21 @@ input_file: tests/def_plus.nu 14: Block(BlockId(0)) (44 to 46) 15: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } (0 to 46) 16: Variable (52 to 53) "m" -17: Name (55 to 58) "any" -18: Type { name: NodeId(17), args: None, optional: false } (55 to 58) -19: Name (61 to 71) "mysterious" -20: Int (72 to 73) "0" -21: Call { parts: [NodeId(19), NodeId(20)] } (72 to 73) -22: Let { variable_name: NodeId(16), ty: Some(NodeId(18)), initializer: NodeId(21), is_mutable: false } (48 to 73) -23: Variable (75 to 77) "$m" -24: Plus (78 to 79) -25: String (80 to 85) ""foo"" -26: BinaryOp { lhs: NodeId(23), op: NodeId(24), rhs: NodeId(25) } (75 to 85) -27: Variable (86 to 88) "$m" -28: Plus (89 to 90) -29: Int (91 to 94) "123" -30: BinaryOp { lhs: NodeId(27), op: NodeId(28), rhs: NodeId(29) } (86 to 94) -31: Block(BlockId(1)) (0 to 95) +17: Name (56 to 66) "mysterious" +18: Int (67 to 68) "0" +19: Call { parts: [NodeId(17), NodeId(18)] } (67 to 68) +20: Let { variable_name: NodeId(16), ty: None, initializer: NodeId(19), is_mutable: false } (48 to 68) +21: Variable (70 to 72) "$m" +22: Plus (73 to 74) +23: String (75 to 80) ""foo"" +24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (70 to 80) +25: Variable (81 to 83) "$m" +26: Plus (84 to 85) +27: Int (86 to 89) "123" +28: BinaryOp { lhs: NodeId(25), op: NodeId(26), rhs: NodeId(27) } (81 to 89) +29: Block(BlockId(1)) (0 to 90) ==== SCOPE ==== -0: Frame Scope, node_id: NodeId(31) +0: Frame Scope, node_id: NodeId(29) variables: [ m: NodeId(16) ] decls: [ mysterious: NodeId(0) ] 1: Frame Scope, node_id: NodeId(14) @@ -60,22 +58,22 @@ input_file: tests/def_plus.nu 13: unknown 14: () 15: () -16: any +16: bottom 17: unknown -18: any -19: unknown -20: int -21: top -22: () -23: any -24: forbidden -25: string -26: string -27: any -28: forbidden -29: int -30: number -31: number +18: int +19: bottom +20: () +21: bottom +22: forbidden +23: string +24: string +25: bottom +26: forbidden +27: int +28: string +29: string +==== TYPE ERRORS ==== +Error (NodeId 27): Expected string, got int ==== IR ==== register_count: 0 file_count: 0 diff --git a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap index 78663a4..88f25c1 100644 --- a/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap +++ b/src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap @@ -92,6 +92,8 @@ input_file: tests/let_mismatch.nu 40: () ==== TYPE ERRORS ==== Error (NodeId 13): Expected string, got int +Error (NodeId 24): Expected int, got string +Error (NodeId 25): Expected list, got list Error (NodeId 26): Expected list>, got list> Error (NodeId 38): Expected record, got record ==== IR ==== diff --git a/src/typechecker.rs b/src/typechecker.rs index 6b842f0..e38c636 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -215,10 +215,17 @@ impl<'a> Typechecker<'a> { let last_node_id = NodeId(last); self.typecheck_node(last_node_id); + for i in 0..self.type_vars.len() { + let var = &self.type_vars[i]; + let bound = var.lower_bound; + let cleaned = self.eliminate_type_vars(bound, TypeVarId(0), true); + self.types[bound.0] = self.types[cleaned.0]; + } + for i in 0..self.types.len() { - let res = self.eliminate_type_vars(TypeId(i), TypeVarId(0), false); - if res.0 != i { - self.types[i] = self.types[res.0]; + if let Type::Var(var_id) = &self.types[i] { + let bound = self.type_vars[var_id.0].lower_bound; + self.types[i] = self.types[bound.0]; } } } @@ -389,9 +396,10 @@ impl<'a> Typechecker<'a> { AstNode::True | AstNode::False => BOOL_TYPE, AstNode::String => STRING_TYPE, AstNode::List(ref items) => { - // TODO inspect the expected type and infer a union type instead + // TODO infer a union type instead if let Some(first_id) = items.first() { - self.typecheck_expr(*first_id, TOP_TYPE); + let expected_elem = self.extract_elem_type(expected); + self.typecheck_expr(*first_id, expected_elem.unwrap_or(TOP_TYPE)); let first_type = self.type_of(*first_id); let mut all_numbers = self.is_type_compatible(first_type, Type::Number); @@ -709,25 +717,22 @@ impl<'a> Typechecker<'a> { } } AstNode::Append => { - // TODO cache this type - let list_ty = self.push_type(Type::List(TOP_TYPE)); - let lhs_type = self.typecheck_expr(lhs, list_ty); - let rhs_type = self.typecheck_expr(rhs, list_ty); - - //todo account for any - match (self.types[lhs_type.0], self.types[rhs_type.0]) { - (Type::List(lhs_item), Type::List(rhs_item)) => { - let mut types = HashSet::new(); - types.insert(lhs_item); - types.insert(rhs_item); - let common_type = self.create_oneof(types); - self.push_type(Type::List(common_type)) - } - (_, Type::Any | Type::Bottom) | (Type::Any | Type::Bottom, _) => ANY_TYPE, - _ => { - self.binary_op_err("append", lhs, op, rhs); - ERROR_TYPE - } + // TODO cache these two types + let top_list = self.push_type(Type::List(TOP_TYPE)); + let bottom_list = self.push_type(Type::List(BOTTOM_TYPE)); + + let res_var = self.new_typevar(bottom_list, top_list); + let res_type = self.push_type(Type::Var(res_var)); + let lhs_type = self.typecheck_expr(lhs, res_type); + let rhs_type = self.typecheck_expr(rhs, res_type); + + if self.is_subtype(lhs_type, LIST_ANY_TYPE) + && self.is_subtype(rhs_type, LIST_ANY_TYPE) + { + res_type + } else { + self.binary_op_err("append", lhs, op, rhs); + ERROR_TYPE } } AstNode::Assignment @@ -1162,6 +1167,18 @@ impl<'a> Typechecker<'a> { } } + /// Given the type for a list, extract the type of its elements + fn extract_elem_type(&mut self, list_ty: TypeId) -> Option { + match self.types[list_ty.0] { + Type::List(elem) => Some(elem), + Type::Top => Some(TOP_TYPE), + Type::Bottom => Some(BOTTOM_TYPE), + Type::Any => Some(ANY_TYPE), + Type::Unknown => Some(UNKNOWN_TYPE), + _ => None, + } + } + fn set_node_type_id(&mut self, node_id: NodeId, type_id: TypeId) { self.node_types[node_id.0] = type_id; } @@ -1223,7 +1240,10 @@ impl<'a> Typechecker<'a> { (Type::Var(var_id), _) => { let lb = self.type_vars[var_id.0].lower_bound; let ub = self.type_vars[var_id.0].upper_bound; - let new_ub = self.create_intersection(ub, supe_id); + let mut types = HashSet::new(); + types.insert(ub); + types.insert(supe_id); + let new_ub = self.create_allof(types); // Prevent forward references/cycles let new_ub = self.eliminate_type_vars(new_ub, var_id, true); @@ -1241,7 +1261,10 @@ impl<'a> Typechecker<'a> { (_, Type::Var(var_id)) => { let lb = self.type_vars[var_id.0].lower_bound; let ub = self.type_vars[var_id.0].upper_bound; - let new_lb = self.create_intersection(lb, sub_id); + let mut types = HashSet::new(); + types.insert(lb); + types.insert(sub_id); + let new_lb = self.create_oneof(types); // Prevent forward references/cycles let new_lb = self.eliminate_type_vars(new_lb, var_id, false); @@ -1598,6 +1621,8 @@ impl<'a> Typechecker<'a> { let mut flattened = HashSet::new(); for ty_id in types { match self.types[ty_id.0] { + Type::Top | Type::Any | Type::Unknown => return ty_id, + Type::Bottom => {} Type::OneOf(id) => { flattened.extend(&self.oneof_types[id.0]); } @@ -1846,11 +1871,4 @@ impl<'a> Typechecker<'a> { self.create_oneof(inters) } - - fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId { - let mut types = HashSet::new(); - types.insert(lhs_id); - types.insert(rhs_id); - self.create_allof(types) - } } diff --git a/tests/def_complex.nu b/tests/infer_complex.nu similarity index 100% rename from tests/def_complex.nu rename to tests/infer_complex.nu diff --git a/tests/def_generics.nu b/tests/infer_generics.nu similarity index 55% rename from tests/def_generics.nu rename to tests/infer_generics.nu index 0784ce8..549de42 100644 --- a/tests/def_generics.nu +++ b/tests/infer_generics.nu @@ -3,7 +3,4 @@ def f [ x: T ] : nothing -> list { [$z] } -let l1 = f 1 - -let l2 = f 2 -let l3: list = $l2 +f 1 diff --git a/tests/def_plus.nu b/tests/infer_plus.nu similarity index 72% rename from tests/def_plus.nu rename to tests/infer_plus.nu index 4ec8a3e..6a11f01 100644 --- a/tests/def_plus.nu +++ b/tests/infer_plus.nu @@ -1,6 +1,6 @@ def mysterious [ x: int ] : nothing -> T {} -let m: any = mysterious 0 +let m = mysterious 0 $m + "foo" $m + 123 From d36b699cf58beb53aba347cb231cbea01304d74a Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sat, 10 May 2025 16:18:41 -0400 Subject: [PATCH 11/16] Fix after merge --- src/compiler.rs | 4 +++- src/typechecker.rs | 9 ++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 62b467c..11136e0 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,9 @@ use crate::errors::SourceError; use crate::parser::{AstNode, Block, NodeId, Pipeline}; use crate::protocol::Command; -use crate::resolver::{DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable}; +use crate::resolver::{ + DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable, +}; use crate::typechecker::{TypeId, Types}; use std::collections::HashMap; diff --git a/src/typechecker.rs b/src/typechecker.rs index f747055..4d60620 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -444,16 +444,14 @@ impl<'a> Typechecker<'a> { let pipeline = &self.compiler.pipelines[pipeline_id.0]; let expressions = pipeline.get_expressions(); for inner in expressions { - self.typecheck_node(*inner) + self.typecheck_expr(*inner, TOP_TYPE); } // pipeline type is the type of the last expression, since blocks // by themselves aren't supposed to be typed - let pipeline_type = expressions + expressions .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - - self.set_node_type_id(node_id, pipeline_type); + .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)) } AstNode::Closure { params, block } => { // TODO: input/output types @@ -552,6 +550,7 @@ impl<'a> Typechecker<'a> { | AstNode::List(_) | AstNode::Record { .. } | AstNode::Table { .. } + | AstNode::Pipeline(_) | AstNode::Closure { .. } | AstNode::BinaryOp { .. } | AstNode::If { .. } From ad5aaf4d25a8b8f7f445724953cf1dc34f6aec55 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sat, 10 May 2025 16:21:57 -0400 Subject: [PATCH 12/16] .unwrap() -> .expect() --- src/typechecker.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index 4d60620..7f3f665 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1737,7 +1737,7 @@ impl<'a> Typechecker<'a> { } if simple_types.len() == 1 { - *simple_types.iter().next().unwrap() + *simple_types.iter().next().expect("should have exactly 1 element") } else { self.oneof_types.push(simple_types); self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) @@ -1852,7 +1852,7 @@ impl<'a> Typechecker<'a> { let single_res = if res.is_empty() { TOP_TYPE } else if res.len() == 1 { - *res.iter().next().unwrap() + *res.iter().next().expect("should have exactly 1 element") } else { self.allof_types.push(res); self.push_type(Type::AllOf(AllOfId(self.allof_types.len() - 1))) From 85d30f0ef2141cf2ed0ac0e0a80daaa4141a0021 Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sat, 10 May 2025 16:27:29 -0400 Subject: [PATCH 13/16] Format (I'm an idiot) --- src/typechecker.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index 7f3f665..954c332 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1737,7 +1737,10 @@ impl<'a> Typechecker<'a> { } if simple_types.len() == 1 { - *simple_types.iter().next().expect("should have exactly 1 element") + *simple_types + .iter() + .next() + .expect("should have exactly 1 element") } else { self.oneof_types.push(simple_types); self.push_type(Type::OneOf(OneOfId(self.oneof_types.len() - 1))) From e5de73c78ac2d9da818035452e3f5da0a91aff3e Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Fri, 20 Jun 2025 21:51:35 -0400 Subject: [PATCH 14/16] Correct allof type to string --- src/typechecker.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index 954c332..777f755 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1525,7 +1525,7 @@ impl<'a> Typechecker<'a> { } Type::AllOf(id) => { let mut fmt = "allof<".to_string(); - let mut types: Vec<_> = self.oneof_types[id.0] + let mut types: Vec<_> = self.allof_types[id.0] .iter() .map(|ty| self.type_to_string(*ty) + ", ") .collect(); From 09e3c4ab25e2fd8caf8ef5532b00d0726de84943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20=C5=BD=C3=A1dn=C3=ADk?= Date: Sat, 12 Jul 2025 14:25:31 +0300 Subject: [PATCH 15/16] Update case of TODOs Easier for case-sensitive searching. --- src/typechecker.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index 777f755..379c59a 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -620,7 +620,7 @@ impl<'a> Typechecker<'a> { fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { self.set_node_type_id(op, FORBIDDEN_TYPE); - // todo better error messages for type mismatches, the previous messages were better + // TODO: better error messages for type mismatches, the previous messages were better match self.compiler.ast_nodes[op.0] { AstNode::Equal | AstNode::NotEqual => { let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); @@ -755,7 +755,7 @@ impl<'a> Typechecker<'a> { | AstNode::MultiplyAssignment | AstNode::DivideAssignment | AstNode::AppendAssignment => { - // todo actually check if operands are right for operator + // TODO: actually check if operands are right for operator self.typecheck_expr(lhs, TOP_TYPE); self.typecheck_expr(rhs, TOP_TYPE); NONE_TYPE @@ -1208,7 +1208,7 @@ impl<'a> Typechecker<'a> { /// Check if `sub` is a subtype of `supe` /// /// Returns `false` if there is a type mismatch, `true` otherwise - /// todo return a Result with a message about constraints not being solvable or something + /// TODO: return a Result with a message about constraints not being solvable or something fn constrain_subtype(&mut self, sub_id: TypeId, supe_id: TypeId) -> bool { if sub_id == supe_id { return true; @@ -1298,7 +1298,7 @@ impl<'a> Typechecker<'a> { .iter() .all(|ty| self.constrain_subtype(*ty, supe_id)), (_, Type::OneOf(id)) => { - // todo actually add constraints? + // TODO: actually add constraints? self.oneof_types[id.0] .clone() .iter() @@ -1310,7 +1310,7 @@ impl<'a> Typechecker<'a> { } /// Check if `sub` is a subtype of `supe` - /// todo reduce duplication between this and constrain_subtype + /// TODO: reduce duplication between this and constrain_subtype fn is_subtype(&self, sub: TypeId, supe: TypeId) -> bool { if sub == supe { return true; From 93c7dc5b2b269d25b335eb3aefecf40a3411b66f Mon Sep 17 00:00:00 2001 From: ysthakur <45539777+ysthakur@users.noreply.github.com> Date: Sun, 27 Jul 2025 03:17:56 -0400 Subject: [PATCH 16/16] A little documentation for the typechecker --- contributing/typechecking.md | 36 ++++++++++++++++++++++++++++++++++++ src/typechecker.rs | 3 +++ 2 files changed, 39 insertions(+) create mode 100644 contributing/typechecking.md diff --git a/contributing/typechecking.md b/contributing/typechecking.md new file mode 100644 index 0000000..d194ead --- /dev/null +++ b/contributing/typechecking.md @@ -0,0 +1,36 @@ +# Typechecking + +The current typechecker is based on the paper [Complete and Easy Bidirectional Typechecking for Higher-Rank Polymorphism](https://arxiv.org/abs/1306.6032), although we've had to make some modifications because their type system is pretty different from ours. Mainly, we've added gradual typing and subtyping while ignoring the higher-rank polymorphism stuff. + +The following are types used by the typechecker: + +- The top type: the supertype of all types (similar to Java's `Object`, which is the supertype of all reference types) +- The bottom type: the subtype of all types (same as TypeScript's `never`). You can never construct an instance of the bottom type. + - This is useful for indicating that a command will never return (due to throwing an error or looping infinitely). +- The `oneof` type: for creating union types. `oneof` can be used as the type for a value that's either an integer or a string. +- The `allof` type: for creating intersection types. `allof` represents a type that is simultaneously the subtype of `A`, `B`, and `C`. + - This isn't particularly useful in Nushell, so users won't be able to construct `allof` types themselves. + - However, the typechecker does use this internally. + +## How type checking/inference work + +The `typecheck_expr` method takes the expected type of the expression it is currently processing and provides an inferred type for every expression it visits. This way, we both check against the expected type and infer a type for the expression at the same time. Below is an example program using generics. + +```nu +def f [ x: T ] : nothing -> list { + let z = $x + [$z] +} +f 123 # Inferred to be of type list +``` + +For generics to work, the algorithm requires creating and solving type variables. These type variables have a lower bound and an upper bound. As we move through the program, these bounds are tightened further. At the end of the program, the lower bound of each type variable is chosen as its value. + +Every time we come across a call to a custom command with type parameters, we instantiate new type variables corresponding to those type parameters. For example, for the expression `f 123` above, we instantiate a new type variable `'0` with lower bound `bottom` and upper bound `top` (essentially unbounded). Because of the signature of `f`, we know that `123: '0` and `f 123: list<'0>`. + +So we first call `typecheck_expr` to check/infer the type of `123`, providing `'0` as its expected type. Since it's just an integer literal, we infer the type to be `int`. Then, to ensure that this type matches the expected type (`'0`), we call the `constrain_subtype` method to ensure that `int <: '0`. The existing lower bound for `'0` was `bottom`, so we set the new lower bound to `oneof = int`. + +Then we set the type of `f 123` to `list<'0>`. After all expressions have been processed, we replace all occurrences of `'0` with `int`. So this becomes `list`. + +Recursive bounds are not allowed. The bounds for a type variable (say, `'5`) can only refer to type variables created before it (here, `'0` through `'4`). +- If, during typechecking, we find ourselves setting the upper bound of `'5` to `'6`, then we instead set the upper bound of `'5` to whatever the lower bound of `'6` is at that point. This behavior can be improved upon somewhat in the future (in this particular example, we could instead update `'6`'s lower bound to include `'5`), but I think it's good enough for now. diff --git a/src/typechecker.rs b/src/typechecker.rs index 379c59a..10990a1 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,3 +1,6 @@ +//! See typechecking.md in the contributing/ folder for more information on +//! how the typechecker works + use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; use crate::parser::{AstNode, NodeId};