diff --git a/crates/pyrefly_types/src/types.rs b/crates/pyrefly_types/src/types.rs index 3fb01b3379..e139017ecc 100644 --- a/crates/pyrefly_types/src/types.rs +++ b/crates/pyrefly_types/src/types.rs @@ -1851,6 +1851,7 @@ impl Type { Type::Literal(lit) if let Lit::Int(x) = &lit.value => Some(x.as_bool()), Type::Literal(lit) if let Lit::Bytes(x) = &lit.value => Some(!x.is_empty()), Type::Literal(lit) if let Lit::Str(x) = &lit.value => Some(!x.is_empty()), + Type::Type(_) => Some(true), Type::None => Some(false), Type::Tuple(Tuple::Concrete(elements)) => Some(!elements.is_empty()), Type::Union(box Union { members, .. }) => { diff --git a/pyrefly/lib/alt/expr.rs b/pyrefly/lib/alt/expr.rs index 48cb021c05..7b3a12d86a 100644 --- a/pyrefly/lib/alt/expr.rs +++ b/pyrefly/lib/alt/expr.rs @@ -77,6 +77,7 @@ use crate::binding::binding::Key; use crate::binding::binding::KeyYield; use crate::binding::binding::KeyYieldFrom; use crate::binding::binding::LambdaParamId; +use crate::binding::narrow::AtomicNarrowOp; use crate::binding::narrow::int_from_slice; use crate::config::error_kind::ErrorKind; use crate::error::collector::ErrorCollector; @@ -89,7 +90,6 @@ use crate::types::callable::Params; use crate::types::callable::Required; use crate::types::class::Class; use crate::types::facet::FacetKind; -use crate::types::lit_int::LitInt; use crate::types::literal::Lit; use crate::types::param_spec::ParamSpec; use crate::types::quantified::Quantified; @@ -1231,6 +1231,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { BoolOp::And => false, BoolOp::Or => true, }; + let result_narrow = match op { + BoolOp::And => AtomicNarrowOp::IsFalsy, + BoolOp::Or => AtomicNarrowOp::IsTruthy, + }; let should_shortcircuit = |t: &Type, r: TextRange| self.as_bool(t, r, errors) == Some(target); let should_discard = |t: &Type, r: TextRange| self.as_bool(t, r, errors) == Some(!target); @@ -1264,21 +1268,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None => t.clone(), Some(acc) => self.union(acc, t.clone()), }); - // Narrow the type for the result of the boolop - let t = if i != last_index - && t == self.heap.mk_class_type(self.stdlib.bool().clone()) - { - Lit::Bool(target).to_implicit_type() - } else if i != last_index - && t == self.heap.mk_class_type(self.stdlib.int().clone()) - && !target - { - LitInt::new(0).to_implicit_type() - } else if i != last_index - && t == self.heap.mk_class_type(self.stdlib.str().clone()) - && !target - { - Lit::Str(Default::default()).to_implicit_type() + let t = if i != last_index { + self.atomic_narrow(&t, &result_narrow, value.range(), errors) } else { t }; diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index bb1e8fc57a..ff57f4a458 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -885,7 +885,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - fn atomic_narrow( + pub(crate) fn atomic_narrow( &self, ty: &Type, op: &AtomicNarrowOp, diff --git a/pyrefly/lib/test/literal.rs b/pyrefly/lib/test/literal.rs index 7436fedf3e..c196003935 100644 --- a/pyrefly/lib/test/literal.rs +++ b/pyrefly/lib/test/literal.rs @@ -378,6 +378,18 @@ def f(x1: list[str], x2: list[LiteralString]): "#, ); +testcase!( + test_str_join_boolop_narrowing, + r#" +from typing import assert_type + +def format_types(types: set[type | None]) -> str: + values = sorted((e and e.__name__) or "None" for e in types) + assert_type(values, list[str]) + return ", ".join(values) + "#, +); + testcase!( test_giant_literal_string, r#" diff --git a/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs b/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs index cf56f84412..62429ffd53 100644 --- a/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs +++ b/pyrefly/lib/test/lsp/lsp_interaction/pytorch_benchmark.rs @@ -48,7 +48,7 @@ fn test_pytorch_error_propagation_latency() { }; // Use all available cores for realistic benchmarking let mut interaction = - LspInteraction::new_with_args(args, NoTelemetry, Some(ThreadCount::AllThreads)); + LspInteraction::new_with_args(args, NoTelemetry, Some(ThreadCount::AllThreads), None); interaction.set_root(pytorch_root.clone()); interaction