Skip to content

Commit a900831

Browse files
authored
Improve negative narrowing for membership checks on tuples (#21456)
Related to #21411 Fixes #16093
1 parent c0cced3 commit a900831

4 files changed

Lines changed: 111 additions & 31 deletions

File tree

mypy/checker.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6771,25 +6771,45 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
67716771
else_map = {}
67726772

67736773
if left_index in narrowable_operand_index_to_hash:
6774-
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6775-
if collection_item_type is not None:
6776-
if_map, else_map = self.narrow_type_by_identity_equality(
6777-
"==",
6778-
operands=[operands[left_index], operands[right_index]],
6779-
operand_types=[item_type, collection_item_type],
6780-
expr_indices=[0, 1],
6781-
narrowable_indices={0},
6782-
)
6783-
if else_map and not (
6784-
isinstance(p_typ := get_proper_type(iterable_type), TupleType)
6785-
and all(
6786-
is_singleton_equality_type(get_proper_type(item))
6787-
for item in p_typ.items
6774+
p_iterable_type = get_proper_type(iterable_type)
6775+
if (
6776+
isinstance(p_iterable_type, TupleType)
6777+
and find_unpack_in_list(p_iterable_type.items) is None
6778+
):
6779+
# For some tuples, we can do negative narrowing, e.g. `x not in (None,)`
6780+
all_if_maps = []
6781+
all_else_maps = []
6782+
for known_item in p_iterable_type.items:
6783+
# Match the should_coerce_literals logic from narrow_type_by_identity_equality
6784+
p_known_item = get_proper_type(known_item)
6785+
if is_literal_type_like(p_known_item) or (
6786+
isinstance(p_known_item, Instance) and p_known_item.type.is_enum
6787+
):
6788+
known_item = coerce_to_literal(known_item)
6789+
if_map, else_map = self.narrow_type_by_identity_equality(
6790+
"==",
6791+
operands=[operands[left_index], operands[right_index]],
6792+
operand_types=[item_type, known_item],
6793+
expr_indices=[0, 1],
6794+
narrowable_indices={0},
67886795
)
6789-
):
6790-
# In general, we can't do negative narrowing, since e.g. the container
6791-
# could just be empty. However, we can do negative narrowing for some
6792-
# tuples e.g. `x not in (None,)`
6796+
all_if_maps.append(if_map)
6797+
if is_singleton_equality_type(get_proper_type(known_item)):
6798+
all_else_maps.append(else_map)
6799+
if_map = reduce_or_conditional_type_maps(all_if_maps)
6800+
else_map = reduce_and_conditional_type_maps(all_else_maps, use_meet=True)
6801+
else:
6802+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6803+
if collection_item_type is not None:
6804+
if_map, else_map = self.narrow_type_by_identity_equality(
6805+
"==",
6806+
operands=[operands[left_index], operands[right_index]],
6807+
operand_types=[item_type, collection_item_type],
6808+
expr_indices=[0, 1],
6809+
narrowable_indices={0},
6810+
)
6811+
# We can't do negative narrowing, since e.g. the container could
6812+
# just be empty.
67936813
else_map = {}
67946814

67956815
if right_index in narrowable_operand_index_to_hash:

test-data/unit/check-isinstance.test

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,18 +2294,16 @@ def f(x: Optional[int], lst: Optional[List[int]], nested_any: List[List[Any]]) -
22942294

22952295
[case testNarrowTypeAfterInTuple]
22962296
# flags: --warn-unreachable
2297-
from typing import Optional
22982297
class A: pass
22992298
class B(A): pass
23002299
class C(A): pass
23012300

2302-
y: Optional[B]
2303-
if y in (B(), C()):
2304-
reveal_type(y) # N: Revealed type is "__main__.B"
2305-
else:
2306-
reveal_type(y) # N: Revealed type is "__main__.B | None"
2301+
def f(y: B | None):
2302+
if y in (B(), C()):
2303+
reveal_type(y) # N: Revealed type is "__main__.B"
2304+
else:
2305+
reveal_type(y) # N: Revealed type is "__main__.B | None"
23072306
[builtins fixtures/tuple.pyi]
2308-
[out]
23092307

23102308
[case testNarrowTypeAfterInNamedTuple]
23112309
# flags: --warn-unreachable

test-data/unit/check-narrowing.test

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3201,7 +3201,7 @@ class X:
32013201
[builtins fixtures/dict.pyi]
32023202

32033203

3204-
[case testTypeNarrowingStringInLiteralContainer]
3204+
[case testNarrowStringInLiteralContainer]
32053205
# flags: --strict-equality --warn-unreachable
32063206
from typing import Literal
32073207

@@ -3235,6 +3235,69 @@ def narrow_set(x: str, t: set[Literal['a', 'b']]):
32353235
reveal_type(x) # N: Revealed type is "builtins.str"
32363236
[builtins fixtures/primitives.pyi]
32373237

3238+
[case testNarrowLiteralInLiteralContainer]
3239+
# flags: --strict-equality --warn-unreachable
3240+
from typing import Literal
3241+
3242+
def narrow_tuple_exact(x: Literal['a', 'b', 'c'], t: tuple[Literal['a'], Literal['b']]):
3243+
if x in t:
3244+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3245+
else:
3246+
reveal_type(x) # N: Revealed type is "Literal['c']"
3247+
3248+
if x not in t:
3249+
reveal_type(x) # N: Revealed type is "Literal['c']"
3250+
else:
3251+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3252+
3253+
def narrow_tuple_expression(x: Literal['a', 'b', 'c']):
3254+
# TODO: this should match narrow_tuple_exact
3255+
if x in ('a', 'b'):
3256+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3257+
else:
3258+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3259+
3260+
if x not in ('a', 'b'):
3261+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3262+
else:
3263+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3264+
3265+
def narrow_tuple_union(x: Literal['a', 'b', 'c'], t: tuple[Literal['a', 'b']]):
3266+
if x in t:
3267+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3268+
else:
3269+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3270+
3271+
if x not in t:
3272+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3273+
else:
3274+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3275+
3276+
def narrow_tuple_with_other_type(x: Literal['a', 'b', 'c'], t: tuple[Literal['a'], int]):
3277+
if x in t:
3278+
reveal_type(x) # N: Revealed type is "Literal['a']"
3279+
else:
3280+
reveal_type(x) # N: Revealed type is "Literal['b'] | Literal['c']"
3281+
3282+
def narrow_homo_tuple(x: Literal['a', 'b', 'c'], t: tuple[Literal['a', 'b'], ...]):
3283+
if x in t:
3284+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3285+
else:
3286+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3287+
3288+
def narrow_list(x: Literal['a', 'b', 'c'], t: list[Literal['a', 'b']]):
3289+
if x in t:
3290+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3291+
else:
3292+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3293+
3294+
def narrow_set(x: Literal['a', 'b', 'c'], t: set[Literal['a', 'b']]):
3295+
if x in t:
3296+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
3297+
else:
3298+
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
3299+
[builtins fixtures/primitives.pyi]
3300+
32383301

32393302
[case testNarrowingLiteralInLiteralContainer]
32403303
# flags: --strict-equality --warn-unreachable

test-data/unit/check-typevar-tuple.test

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,14 +2145,13 @@ match(b) # E: Argument 1 to "match" has incompatible type "Bad"; expected "PC[U
21452145
[builtins fixtures/tuple.pyi]
21462146

21472147
[case testVariadicTupleCollectionCheck]
2148-
from typing import Tuple, Optional
21492148
from typing_extensions import Unpack
21502149

2151-
allowed: Tuple[int, Unpack[Tuple[int, ...]]]
2150+
allowed: tuple[int, Unpack[tuple[int, ...]]]
21522151

2153-
x: Optional[int]
2154-
if x in allowed:
2155-
reveal_type(x) # N: Revealed type is "builtins.int"
2152+
def f(x: int | None):
2153+
if x in allowed:
2154+
reveal_type(x) # N: Revealed type is "builtins.int"
21562155
[builtins fixtures/tuple.pyi]
21572156

21582157
[case testJoinOfVariadicTupleCallablesNoCrash]

0 commit comments

Comments
 (0)