From e2327bcb5146e072a5329d170d4d85b7ee6b5909 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Fri, 2 Jan 2026 14:07:22 +0100 Subject: [PATCH 1/5] Update pyrefly --- .pyre_configuration | 17 ----------------- pyproject.toml | 15 +++++++++++---- uv.lock | 26 +++++++++++++------------- 3 files changed, 24 insertions(+), 34 deletions(-) delete mode 100644 .pyre_configuration diff --git a/.pyre_configuration b/.pyre_configuration deleted file mode 100644 index 6c670b47..00000000 --- a/.pyre_configuration +++ /dev/null @@ -1,17 +0,0 @@ -{ - "site_package_search_strategy": "pep561", - "source_directories": [ - { - "import_root": ".", - "source": "sc2" - }, - { - "import_root": ".", - "source": "examples" - }, - { - "import_root": ".", - "source": "test" - } - ] -} diff --git a/pyproject.toml b/pyproject.toml index d3b38928..3187f86d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dev = [ "pyglet>=2.0.20", "pylint>=3.3.2", # Type checker - "pyrefly>=0.21.0", + "pyrefly>=0.40.0", "pytest>=8.3.4", "pytest-asyncio>=0.25.0", "pytest-benchmark>=5.1.0", @@ -93,7 +93,14 @@ dedent_closing_brackets = true allow_split_before_dict_value = false [tool.pyrefly] -project_includes = ["sc2", "examples", "test"] +project_includes = [ + "sc2", + # "examples", + # "test" +] + +[tool.pyrefly.errors] +bad-override = false [tool.ruff] target-version = 'py310' @@ -132,11 +139,11 @@ ignore = [ "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` ] -[tool.ruff.pyupgrade] +[tool.pyupgrade] # Preserve types, even if a file imports `from __future__ import annotations`. # Remove once support for py3.8 and 3.9 is dropped keep-runtime-typing = true -[tool.ruff.pep8-naming] +[tool.pep8-naming] # Allow Pydantic's `@validator` decorator to trigger class method treatment. classmethod-decorators = ["pydantic.validator", "classmethod"] diff --git a/uv.lock b/uv.lock index e7ee29bc..454fd21f 100644 --- a/uv.lock +++ b/uv.lock @@ -304,7 +304,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pyglet", specifier = ">=2.0.20" }, { name = "pylint", specifier = ">=3.3.2" }, - { name = "pyrefly", specifier = ">=0.21.0" }, + { name = "pyrefly", specifier = ">=0.40.0" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-benchmark", specifier = ">=5.1.0" }, @@ -2522,18 +2522,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.21.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/46/3f/8a30ed93cb027a18080d7e670b2bbf14135b031fe74443eaa850494d9aa8/pyrefly-0.21.0.tar.gz", hash = "sha256:e05a083047dcba25e730c7e0c70b3dc48ba420f17ef73265f169bc95f487a99d", size = 1056016, upload-time = "2025-06-23T17:45:22.033Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/dd/5b1a4a3a713be65e2af02563f8baa70ea69d594d681cb1c36b38319eee90/pyrefly-0.21.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b44e6172421de12fa14d380bcd416fa64d474aec9f1829e6985b18471b1fcee1", size = 5820971, upload-time = "2025-06-23T17:45:04.928Z" }, - { url = "https://files.pythonhosted.org/packages/5b/40/df55322e761b798903c951ad5699585046f9e926e6b3b6686cb0056024a4/pyrefly-0.21.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0d7ea3b86ac3c8680b290389ca3706cbbb046899247a963ac7384c57880e6f2f", size = 5404314, upload-time = "2025-06-23T17:45:07.007Z" }, - { url = "https://files.pythonhosted.org/packages/0b/a5/b3d526bf75ab8708cc85cd0db571af0179b566964fe2c9aae717f3a03090/pyrefly-0.21.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:101ecb84b0442e85a92bcb15a2dfd844f7a6abae8f980661849c71102447443e", size = 5611017, upload-time = "2025-06-23T17:45:08.989Z" }, - { url = "https://files.pythonhosted.org/packages/f6/02/4d4b0ddade7e2980a13e14062770535666a84dfab3293c33e154dfff6ae1/pyrefly-0.21.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b934004ddcdcbe55efeb732d0c0155a781a184e374fec9cd031e819ac1b92eff", size = 6285992, upload-time = "2025-06-23T17:45:11.991Z" }, - { url = "https://files.pythonhosted.org/packages/29/7c/2c3922ee3bdd82a827a893a80aeddf119e38ce8406035a6eda6ae480885e/pyrefly-0.21.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c66446bc5f7e912dab923adf93f636307c834ab53ddd9422e56d101910d960a", size = 6066455, upload-time = "2025-06-23T17:45:13.587Z" }, - { url = "https://files.pythonhosted.org/packages/e1/c8/c8d391f3535046cb79154d4dfcc01b0d022d5af2701627cbe9a518dd50bc/pyrefly-0.21.0-py3-none-win32.whl", hash = "sha256:a4cb8acf2dc831759cb43fa0326e07085cb21da7202212e70d9478bfc40b7a28", size = 5569924, upload-time = "2025-06-23T17:45:15.36Z" }, - { url = "https://files.pythonhosted.org/packages/85/79/58c94192acbc234ff9290c22063bf8f11d65a7b61eb61f9260fa6fb4b1f3/pyrefly-0.21.0-py3-none-win_amd64.whl", hash = "sha256:765d19f2b48d5dd3dae0752676e2d6e388025fa6337947031ff628160b8cf568", size = 5946646, upload-time = "2025-06-23T17:45:17.433Z" }, - { url = "https://files.pythonhosted.org/packages/75/6e/d476584e93e3c63609dde26a57bc23f732c22b0e9bd17462282268aec758/pyrefly-0.21.0-py3-none-win_arm64.whl", hash = "sha256:a3fc4fffb625a5610b68fc3bf07e4d33b5b1c279512e0251b3d3d4f7c3ed1541", size = 5595515, upload-time = "2025-06-23T17:45:19.4Z" }, +version = "0.46.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/c2/f92ab45699f4e2ce2291fd81ec183a3f96ec6c72a1d03056644fdf4aa702/pyrefly-0.46.3.tar.gz", hash = "sha256:6aeb90698b587bba38ec870a515cf3499756fc81d73852fd11eaa10abda0fea6", size = 4769402, upload-time = "2025-12-30T23:58:16.473Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/ae/8257e96b2e6396880280e96b1ccda0242f19bbedbcc933443d6f56f81843/pyrefly-0.46.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5626d345aa829dc17311c77c8019dfc7b9b6dd8da102b66d943d47862af2be59", size = 11662076, upload-time = "2025-12-30T23:57:55.765Z" }, + { url = "https://files.pythonhosted.org/packages/e5/7a/aba0dd3b0f9cb50fb3b39992960c1e04ae3498346c448ae041bff0d26337/pyrefly-0.46.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:13b553090ec9821a781198389564bcb3cb020679189272b1e88fddba9d613879", size = 11269927, upload-time = "2025-12-30T23:57:58.362Z" }, + { url = "https://files.pythonhosted.org/packages/17/82/2a3ad9107229893207979d25b5eb98d8b71d3f43b28f349286b3b6630514/pyrefly-0.46.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ad8612c2b8eaa36f3953ab3be8635bd91fd8179ddbdd9eb8bb4f1ad513ff3f0", size = 31501305, upload-time = "2025-12-30T23:58:00.998Z" }, + { url = "https://files.pythonhosted.org/packages/11/10/b5f2c1eea63bee42f45480a17a72c31ca60b6c15024f1085459f4cd5d638/pyrefly-0.46.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:873d0c6a71cdd0399e1c6de1a3edbab34c7a9ae48e7a9e9ac327b1010d8686a1", size = 33721653, upload-time = "2025-12-30T23:58:03.703Z" }, + { url = "https://files.pythonhosted.org/packages/af/7c/fe81b2c7e9e7edfe5ea9db00c79439e31c7110faad021ab451734722525e/pyrefly-0.46.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd55f0d864ead658607dc3be6649019cfb5d1c72ceb0818b4b3501e0eb81472", size = 34766036, upload-time = "2025-12-30T23:58:06.613Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c7/eff2558569533a11d1a21d73a366a85483bd38d80a12af1c7381218a6b62/pyrefly-0.46.3-py3-none-win32.whl", hash = "sha256:28e11eca9461fc892b19bdad799e280ab81cc00f5a712f89f7bc6b2e41a70194", size = 10738486, upload-time = "2025-12-30T23:58:09.316Z" }, + { url = "https://files.pythonhosted.org/packages/be/84/8f545321cc0bb555992d4e4af0905dc907ef3e9e864d68c5504a46560bc6/pyrefly-0.46.3-py3-none-win_amd64.whl", hash = "sha256:1eb510fa62960bc30137e39360ef68b7c691eb28a6c958560088bc99595d63be", size = 11426143, upload-time = "2025-12-30T23:58:11.546Z" }, + { url = "https://files.pythonhosted.org/packages/c6/1d/5f9b3f6eba2a90a7859dc21905248f78f23d2bdd9e13cec791acb544a4b4/pyrefly-0.46.3-py3-none-win_arm64.whl", hash = "sha256:bb5db31d7781edc3a590fec380d50832faecfc57171c054512009c8254d5ad0f", size = 10977207, upload-time = "2025-12-30T23:58:13.812Z" }, ] [[package]] From 724da69a5ed7776b67e6beec4eb6d0a243d38b9c Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Fri, 2 Jan 2026 14:09:23 +0100 Subject: [PATCH 2/5] Remove all pyre-ignore comments --- examples/arcade_bot.py | 2 -- examples/competitive/bot.py | 1 - examples/competitive/run.py | 1 - examples/distributed_workers.py | 1 - examples/protoss/find_adept_shades.py | 4 +--- examples/terran/onebase_battlecruiser.py | 1 - examples/worker_stack_bot.py | 4 ++-- examples/zerg/zerg_rush.py | 1 - sc2/bot_ai.py | 5 +---- sc2/bot_ai_internal.py | 3 +-- sc2/cache.py | 4 ++-- sc2/client.py | 3 +-- sc2/constants.py | 3 +-- sc2/data.py | 3 +-- sc2/expiring_dict.py | 2 -- sc2/game_data.py | 6 +----- sc2/game_info.py | 11 +++++------ sc2/game_state.py | 5 ++--- sc2/generate_ids.py | 2 +- sc2/ids/__init__.py | 1 - sc2/ids/ability_id.py | 1 - sc2/ids/buff_id.py | 1 - sc2/ids/effect_id.py | 1 - sc2/ids/unit_typeid.py | 1 - sc2/ids/upgrade_id.py | 1 - sc2/main.py | 1 - sc2/observer_ai.py | 1 - sc2/player.py | 1 - sc2/proxy.py | 1 - sc2/unit.py | 6 ++++-- sc2/units.py | 1 - sc2/wsl.py | 7 ++++--- test/autotest_bot.py | 1 - test/benchmark_distance_two_points.py | 1 - test/benchmark_distances_cdist.py | 1 - test/benchmark_distances_points_to_point.py | 1 - test/benchmark_distances_units.py | 1 - test/damagetest_bot.py | 1 - test/generate_pickle_files_bot.py | 1 - test/queries_test_bot.py | 1 - test/real_time_worker_production.py | 1 - test/run_example_bots_vs_computer.py | 2 -- test/run_example_bots_vs_each_other.py | 2 -- test/test_pickled_data.py | 1 - test/test_pickled_ramp.py | 6 +++--- test/travis_test_script.py | 6 ++---- test/upgradestest_bot.py | 1 - 47 files changed, 32 insertions(+), 81 deletions(-) diff --git a/examples/arcade_bot.py b/examples/arcade_bot.py index 811ee944..0c85286d 100644 --- a/examples/arcade_bot.py +++ b/examples/arcade_bot.py @@ -101,7 +101,6 @@ def position_around_unit( step_size: int = 1, exclude_out_of_bounds: bool = True, ): - # pyre-ignore[16] pos = pos.position.rounded positions = { pos.offset(Point2((x, y))) @@ -114,7 +113,6 @@ def position_around_unit( positions = { p for p in positions - # pyre-ignore[16] if 0 <= p[0] < self.game_info.pathing_grid.width and 0 <= p[1] < self.game_info.pathing_grid.height } return positions diff --git a/examples/competitive/bot.py b/examples/competitive/bot.py index 253337b6..f575ae79 100644 --- a/examples/competitive/bot.py +++ b/examples/competitive/bot.py @@ -11,7 +11,6 @@ async def on_step(self, iteration: int): # Populate this function with whatever your bot should do! pass - # pyre-ignore[11] async def on_end(self, game_result: Result): print("Game ended.") # Do things here after the game ends diff --git a/examples/competitive/run.py b/examples/competitive/run.py index 10984103..48d7fce3 100644 --- a/examples/competitive/run.py +++ b/examples/competitive/run.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16, 21] import sys from __init__ import run_ladder_game diff --git a/examples/distributed_workers.py b/examples/distributed_workers.py index 9e7940e5..fa54588c 100644 --- a/examples/distributed_workers.py +++ b/examples/distributed_workers.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] from sc2 import maps from sc2.bot_ai import BotAI from sc2.data import Difficulty, Race diff --git a/examples/protoss/find_adept_shades.py b/examples/protoss/find_adept_shades.py index d10896d3..9a51af33 100644 --- a/examples/protoss/find_adept_shades.py +++ b/examples/protoss/find_adept_shades.py @@ -26,7 +26,6 @@ async def on_step(self, iteration: int): if adepts and not self.shaded: # Wait for adepts to spawn and then cast ability for adept in adepts: - # pyre-ignore[16] adept(AbilityId.ADEPTPHASESHIFT_ADEPTPHASESHIFT, self.game_info.map_center) self.shaded = True elif self.shades_mapping: @@ -38,7 +37,6 @@ async def on_step(self, iteration: int): # logger.info(f"Remaining shade time: {shade.buff_duration_remain} / {shade.buff_duration_max}") pass if adept and shade: - # pyre-ignore[16] self.client.debug_line_out(adept, shade, (0, 255, 0)) # logger.info(self.shades_mapping) elif self.shaded: @@ -53,7 +51,7 @@ async def on_step(self, iteration: int): previous_shade_location = shade.position.towards( forward_position, -(self.client.game_step / 16) * shade.movement_speed ) # See docstring of movement_speed attribute - # pyre-ignore[6] + closest_adept = remaining_adepts.closest_to(previous_shade_location) self.shades_mapping[closest_adept.tag] = shade.tag diff --git a/examples/terran/onebase_battlecruiser.py b/examples/terran/onebase_battlecruiser.py index 47cd5f62..41ce9e43 100644 --- a/examples/terran/onebase_battlecruiser.py +++ b/examples/terran/onebase_battlecruiser.py @@ -24,7 +24,6 @@ def select_target(self) -> tuple[Point2, bool]: return targets.random.position, True if self.units and min(u.position.distance_to(self.enemy_start_locations[0]) for u in self.units) < 5: - # pyre-ignore[7] return self.enemy_start_locations[0].position, False return self.mineral_field.random.position, False diff --git a/examples/worker_stack_bot.py b/examples/worker_stack_bot.py index 0a0cbbb2..805b24d6 100644 --- a/examples/worker_stack_bot.py +++ b/examples/worker_stack_bot.py @@ -89,7 +89,7 @@ async def on_step(self, iteration: int): # Move worker in front of the nexus to avoid deceleration until the last moment if worker.distance_to(th) > th.radius + worker.radius + self.townhall_distance_threshold: pos: Point2 = th.position - # pyre-ignore[6] + worker.move(pos.towards(worker, th.radius * self.townhall_distance_factor)) worker.return_resource(queue=True) else: @@ -97,7 +97,7 @@ async def on_step(self, iteration: int): worker.gather(mineral, queue=True) # Print info every 30 game-seconds - # pyre-ignore[16] + if self.state.game_loop % (22.4 * 30) == 0: logger.info(f"{self.time_formatted} Mined a total of {int(self.state.score.collected_minerals)} minerals") diff --git a/examples/zerg/zerg_rush.py b/examples/zerg/zerg_rush.py index 15d0df50..e8a6fedb 100644 --- a/examples/zerg/zerg_rush.py +++ b/examples/zerg/zerg_rush.py @@ -136,7 +136,6 @@ def draw_creep_pixelmap(self): color = Point3((0, 255, 0)) self.client.debug_box2_out(pos, half_vertex_length=0.25, color=color) - # pyre-ignore[11] async def on_end(self, game_result: Result): self.on_end_called = True logger.info(f"{self.time_formatted} On end was called") diff --git a/sc2/bot_ai.py b/sc2/bot_ai.py index d98e72a0..0a5d10e5 100644 --- a/sc2/bot_ai.py +++ b/sc2/bot_ai.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 16] from __future__ import annotations import math @@ -71,7 +70,6 @@ def step_time(self) -> tuple[float, float, float, float]: self._last_step_step_time * 1000, ) - # pyre-ignore[11] def alert(self, alert_code: Alert) -> bool: """ Check if alert is triggered in the current step. @@ -1125,7 +1123,7 @@ def research(self, upgrade_type: UpgradeId) -> bool: return False research_structure_type: UnitTypeId = UPGRADE_RESEARCHED_FROM[upgrade_type] - # pyre-ignore[9] + required_tech_building: UnitTypeId | None = RESEARCH_INFO[research_structure_type][upgrade_type].get( "required_building", None ) @@ -1368,7 +1366,6 @@ async def on_step(self, iteration: int): """ raise NotImplementedError - # pyre-ignore[11] async def on_end(self, game_result: Result) -> None: """Override this in your bot class. This function is called at the end of a game. Unsure if this function will be called on the laddermanager client as the bot process may forcefully be terminated. diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index d2bde3f7..da0dcb7b 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 16, 29] from __future__ import annotations import itertools @@ -103,7 +102,7 @@ def _initialize_variables(self) -> None: self.warp_gate_count: int = 0 self.actions: list[UnitCommand] = [] self.blips: set[Blip] = set() - # pyre-ignore[11] + self.race: Race | None = None self.enemy_race: Race | None = None self._generated_frame = -100 diff --git a/sc2/cache.py b/sc2/cache.py index b1682fc5..1fcbd8ca 100644 --- a/sc2/cache.py +++ b/sc2/cache.py @@ -36,13 +36,13 @@ def __init__(self, func: Callable[[BotAI], T], name: str | None = None) -> None: def __set__(self, obj: BotAI, value: T) -> None: obj.cache[self.__name__] = value - # pyre-ignore[16] + obj.cache[self.__frame__] = obj.state.game_loop # pyre-fixme[34] def __get__(self, obj: BotAI, _type=None) -> T: value = obj.cache.get(self.__name__, None) - # pyre-ignore[16] + bot_frame = obj.state.game_loop if value is None or obj.cache[self.__frame__] < bot_frame: value = self.func(obj) diff --git a/sc2/client.py b/sc2/client.py index 8971b86f..4fe0cccc 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 9, 16, 29, 58] from __future__ import annotations from collections.abc import Iterable @@ -792,7 +791,7 @@ def to_debug_color(color: tuple[float, float, float] | list[float] | Point3 | No r = getattr(color, "r", getattr(color, "x", 255)) g = getattr(color, "g", getattr(color, "y", 255)) b = getattr(color, "b", getattr(color, "z", 255)) - # pyre-ignore[20] + if max(r, g, b) <= 1: r *= 255 g *= 255 diff --git a/sc2/constants.py b/sc2/constants.py index 6478ff01..a3be87c3 100644 --- a/sc2/constants.py +++ b/sc2/constants.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] from __future__ import annotations from collections import defaultdict @@ -494,7 +493,7 @@ def return_NOTAUNIT() -> UnitTypeId: UnitTypeId.EXTRACTOR, UnitTypeId.EXTRACTORRICH, } -# pyre-ignore[11] + DAMAGE_BONUS_PER_UPGRADE: dict[UnitTypeId, dict[int, Any]] = { # # Protoss diff --git a/sc2/data.py b/sc2/data.py index d376138b..e36bddd0 100644 --- a/sc2/data.py +++ b/sc2/data.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16, 19] """For the list of enums, see here https://github.com/Blizzard/s2client-proto/tree/bff45dae1fc685e6acbaae084670afb7d1c0832c/s2clientprotocol @@ -38,7 +37,7 @@ ActionResult = enum.Enum("ActionResult", error_pb.ActionResult.items()) -# pyre-ignore[11] + race_worker: dict[Race, UnitTypeId] = { Race.Protoss: UnitTypeId.PROBE, Race.Terran: UnitTypeId.SCV, diff --git a/sc2/expiring_dict.py b/sc2/expiring_dict.py index 7c6cc2c0..fd14d2a0 100644 --- a/sc2/expiring_dict.py +++ b/sc2/expiring_dict.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14, 15, 58] from __future__ import annotations from collections import OrderedDict @@ -42,7 +41,6 @@ def __init__(self, bot: BotAI, max_age_frames: int = 1) -> None: @property def frame(self) -> int: - # pyre-ignore[16] return self.bot.state.game_loop def __contains__(self, key: Hashable) -> bool: diff --git a/sc2/game_data.py b/sc2/game_data.py index 4be84ee2..4283892c 100644 --- a/sc2/game_data.py +++ b/sc2/game_data.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[29] from __future__ import annotations from bisect import bisect_left @@ -49,7 +48,6 @@ def calculate_ability_cost(self, ability: AbilityData | AbilityId | UnitCommand) if not AbilityData.id_exists(unit.creation_ability.id.value): continue - # pyre-ignore[16] if unit.creation_ability.is_free_morph: continue @@ -265,9 +263,7 @@ def morph_cost(self) -> Cost | None: self._game_data.units[tech_alias.value].cost.minerals for tech_alias in self.tech_alias ) tech_alias_cost_vespene = max( - self._game_data.units[tech_alias.value].cost.vespene - # pyre-ignore[16] - for tech_alias in self.tech_alias + self._game_data.units[tech_alias.value].cost.vespene for tech_alias in self.tech_alias ) return Cost( self._proto.mineral_cost - tech_alias_cost_minerals, diff --git a/sc2/game_info.py b/sc2/game_info.py index aab025d5..f9d1f0d0 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 11, 16, 58] from __future__ import annotations import heapq @@ -174,7 +173,7 @@ def protoss_wall_pylon(self) -> Point2 | None: middle = self.depot_in_middle # direction up the ramp direction = self.barracks_in_middle.negative_offset(middle) - # pyre-ignore[7] + return middle + 6 * direction @cached_property @@ -223,7 +222,7 @@ def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self.players: list[Player] = [Player.from_proto(p) for p in self._proto.player_info] self.map_name: str = self._proto.map_name self.local_map_path: str = self._proto.local_map_path - # pyre-ignore[8] + self.map_size: Size = Size.from_proto(self._proto.start_raw.map_size) # self.pathing_grid[point]: if 0, point is not pathable, if 1, point is pathable @@ -234,9 +233,9 @@ def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self.placement_grid: PixelMap = PixelMap(self._proto.start_raw.placement_grid, in_bits=True) self.playable_area = Rect.from_proto(self._proto.start_raw.playable_area) self.map_center = self.playable_area.center - # pyre-ignore[8] + self.map_ramps: list[Ramp] = None # Filled later by BotAI._prepare_first_step - # pyre-ignore[8] + self.vision_blockers: frozenset[Point2] = None # Filled later by BotAI._prepare_first_step self.player_races: dict[int, int] = { p.player_id: p.race_actual or p.race_requested for p in self._proto.player_info @@ -244,7 +243,7 @@ def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self.start_locations: list[Point2] = [ Point2.from_proto(sl).round(decimals=1) for sl in self._proto.start_raw.start_locations ] - # pyre-ignore[8] + self.player_start_location: Point2 = None # Filled later by BotAI._prepare_first_step def _find_ramps_and_vision_blockers(self) -> tuple[list[Ramp], frozenset[Point2]]: diff --git a/sc2/game_state.py b/sc2/game_state.py index 7ff4bb89..d9f37448 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[11, 16] from __future__ import annotations from dataclasses import dataclass @@ -321,7 +320,7 @@ def actions_unit_commands(self) -> list[ActionRawUnitCommand]: List of successful unit actions since last frame. See https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/raw.proto#L185-L193 """ - # pyre-ignore[7] + return list(filter(lambda action: isinstance(action, ActionRawUnitCommand), self.actions)) @cached_property @@ -330,7 +329,7 @@ def actions_toggle_autocast(self) -> list[ActionRawToggleAutocast]: List of successful autocast toggle actions since last frame. See https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/raw.proto#L199-L202 """ - # pyre-ignore[7] + return list(filter(lambda action: isinstance(action, ActionRawToggleAutocast), self.actions)) @cached_property diff --git a/sc2/generate_ids.py b/sc2/generate_ids.py index 1fd4e7d6..c5c27956 100644 --- a/sc2/generate_ids.py +++ b/sc2/generate_ids.py @@ -25,7 +25,7 @@ def __init__( self.game_version = game_version self.verbose = verbose - self.HEADER = f"""# pyre-ignore-all-errors[14] + self.HEADER = f""" from __future__ import annotations # DO NOT EDIT! # This file was automatically generated by "{Path(__file__).name}" diff --git a/sc2/ids/__init__.py b/sc2/ids/__init__.py index 24b6bc2a..a69ff863 100644 --- a/sc2/ids/__init__.py +++ b/sc2/ids/__init__.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/ids/ability_id.py b/sc2/ids/ability_id.py index 955be493..d895fa90 100644 --- a/sc2/ids/ability_id.py +++ b/sc2/ids/ability_id.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/ids/buff_id.py b/sc2/ids/buff_id.py index 5a7345a8..632965b1 100644 --- a/sc2/ids/buff_id.py +++ b/sc2/ids/buff_id.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/ids/effect_id.py b/sc2/ids/effect_id.py index 77aea24b..7bef1546 100644 --- a/sc2/ids/effect_id.py +++ b/sc2/ids/effect_id.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/ids/unit_typeid.py b/sc2/ids/unit_typeid.py index ec74ebe8..ab0facdf 100644 --- a/sc2/ids/unit_typeid.py +++ b/sc2/ids/unit_typeid.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/ids/upgrade_id.py b/sc2/ids/upgrade_id.py index 4be6cbfe..0f411d5e 100644 --- a/sc2/ids/upgrade_id.py +++ b/sc2/ids/upgrade_id.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14] from __future__ import annotations # DO NOT EDIT! diff --git a/sc2/main.py b/sc2/main.py index 8d07314e..89996f5f 100644 --- a/sc2/main.py +++ b/sc2/main.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 11, 16, 21, 29] from __future__ import annotations import asyncio diff --git a/sc2/observer_ai.py b/sc2/observer_ai.py index 7cd23c99..814fe455 100644 --- a/sc2/observer_ai.py +++ b/sc2/observer_ai.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 11, 16] """ This class is very experimental and probably not up to date and needs to be refurbished. If it works, you can watch replays with it. diff --git a/sc2/player.py b/sc2/player.py index bd1410a5..646aabe8 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[6, 11, 16, 29] from __future__ import annotations from abc import ABC diff --git a/sc2/proxy.py b/sc2/proxy.py index 340570cd..c022360d 100644 --- a/sc2/proxy.py +++ b/sc2/proxy.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16, 29] from __future__ import annotations import asyncio diff --git a/sc2/unit.py b/sc2/unit.py index 236116f9..2153decc 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[11, 16, 29] from __future__ import annotations import math @@ -1358,7 +1357,10 @@ def warp_in( :param queue: :param can_afford_check: """ - normal_creation_ability = self._bot_object.game_data.units[unit.value].creation_ability.id + creation_ability = self._bot_object.game_data.units[unit.value].creation_ability + if creation_ability is None: + return False + normal_creation_ability = creation_ability.id return self( warpgate_abilities[normal_creation_ability], target=position, diff --git a/sc2/units.py b/sc2/units.py index 1871dfc6..3e58b1d2 100644 --- a/sc2/units.py +++ b/sc2/units.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[14, 15, 16] from __future__ import annotations import random diff --git a/sc2/wsl.py b/sc2/wsl.py index 4f2a3cdd..dcdfecec 100644 --- a/sc2/wsl.py +++ b/sc2/wsl.py @@ -93,10 +93,11 @@ def detect() -> str | None: # Unix-style newlines for safety's sake. lines = re.sub(r"\000|\r", "", wsl_proc.stdout.decode("utf-8")).split("\n") - def line_has_proc(ln): - return re.search("^\\s*[*]?\\s+" + wsl_name, ln) + def line_has_proc(ln: str): + if wsl_name is not None: + return re.search("^\\s*[*]?\\s+" + wsl_name, ln) - def line_version(ln): + def line_version(ln: str): return re.sub("^.*\\s+(\\d+)\\s*$", "\\1", ln) versions = [line_version(ln) for ln in lines if line_has_proc(ln)] diff --git a/test/autotest_bot.py b/test/autotest_bot.py index a6a0dc23..d99dee3f 100644 --- a/test/autotest_bot.py +++ b/test/autotest_bot.py @@ -484,7 +484,6 @@ async def on_start(self): await self.client.debug_kill_unit(self.units) async def on_step(self, iteration: int): - # pyre-ignore[16] map_center = self.game_info.map_center enemies = self.enemy_units | self.enemy_structures if enemies: diff --git a/test/benchmark_distance_two_points.py b/test/benchmark_distance_two_points.py index 9527a107..66006f86 100644 --- a/test/benchmark_distance_two_points.py +++ b/test/benchmark_distance_two_points.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[21] from __future__ import annotations import math diff --git a/test/benchmark_distances_cdist.py b/test/benchmark_distances_cdist.py index fdcfd7b8..6314ee6b 100644 --- a/test/benchmark_distances_cdist.py +++ b/test/benchmark_distances_cdist.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[21] import random import numpy as np diff --git a/test/benchmark_distances_points_to_point.py b/test/benchmark_distances_points_to_point.py index cd36c8d8..f04d5849 100644 --- a/test/benchmark_distances_points_to_point.py +++ b/test/benchmark_distances_points_to_point.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[21] from __future__ import annotations import math diff --git a/test/benchmark_distances_units.py b/test/benchmark_distances_units.py index 11d81462..c281045a 100644 --- a/test/benchmark_distances_units.py +++ b/test/benchmark_distances_units.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[21] import math import random diff --git a/test/damagetest_bot.py b/test/damagetest_bot.py index 962c2bd8..9263de08 100644 --- a/test/damagetest_bot.py +++ b/test/damagetest_bot.py @@ -321,7 +321,6 @@ async def on_start(self): await self.client.debug_kill_unit(self.units) async def on_step(self, iteration: int): - # pyre-ignore[16] map_center = self.game_info.map_center enemies = self.enemy_units | self.enemy_structures if enemies: diff --git a/test/generate_pickle_files_bot.py b/test/generate_pickle_files_bot.py index 3b4158ee..0bfcf1c2 100644 --- a/test/generate_pickle_files_bot.py +++ b/test/generate_pickle_files_bot.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] """ This "bot" will loop over several available ladder maps and generate the pickle file in the "/test/pickle_data/" subfolder. These will then be used to run tests from the test script "test_pickled_data.py" diff --git a/test/queries_test_bot.py b/test/queries_test_bot.py index 72c3c7f8..70da2c18 100644 --- a/test/queries_test_bot.py +++ b/test/queries_test_bot.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] """ This testbot's purpose is to test the query behavior of the API. These query functions are: diff --git a/test/real_time_worker_production.py b/test/real_time_worker_production.py index d272e0ea..f0f69bdf 100644 --- a/test/real_time_worker_production.py +++ b/test/real_time_worker_production.py @@ -92,7 +92,6 @@ async def on_building_construction_complete(self, unit: Unit): if unit.is_structure: unit(AbilityId.RALLY_WORKERS, self.mineral_field.closest_to(unit)) - # pyre-ignore[11] async def on_end(self, game_result: Result): global on_end_was_called on_end_was_called = True diff --git a/test/run_example_bots_vs_computer.py b/test/run_example_bots_vs_computer.py index 9cb15576..3c144be8 100644 --- a/test/run_example_bots_vs_computer.py +++ b/test/run_example_bots_vs_computer.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] """ This script makes sure to run all bots in the examples folder to check if they can launch. """ @@ -115,7 +114,6 @@ # Run example bots for bot_info in bot_infos: - # pyre-ignore[11] bot_race: Race = bot_info["race"] bot_path: str = bot_info["path"] bot_class_name: str = bot_info["bot_class_name"] diff --git a/test/run_example_bots_vs_each_other.py b/test/run_example_bots_vs_each_other.py index b70f3ccb..88d8f445 100644 --- a/test/run_example_bots_vs_each_other.py +++ b/test/run_example_bots_vs_each_other.py @@ -1,4 +1,3 @@ -# pyre-ignore-all-errors[16] """ This script makes sure to run all bots in the examples folder to check if they can launch against each other. """ @@ -96,7 +95,6 @@ # Run bots against each other for bot_info1, bot_info2 in combinations(bot_infos, 2): - # pyre-ignore[11] bot_race1: Race = bot_info1["race"] bot_path: str = bot_info1["path"] bot_class_name: str = bot_info1["bot_class_name"] diff --git a/test/test_pickled_data.py b/test/test_pickled_data.py index 8a6c5e69..8c41baa8 100644 --- a/test/test_pickled_data.py +++ b/test/test_pickled_data.py @@ -291,7 +291,6 @@ def test_bot_ai(): def calc_cost(item_id) -> Cost: if isinstance(item_id, AbilityId): - # pyre-ignore[16] return bot.game_data.calculate_ability_cost(item_id) elif isinstance(item_id, UpgradeId): return bot.game_data.upgrades[item_id.value].cost diff --git a/test/test_pickled_ramp.py b/test/test_pickled_ramp.py index 0695fe65..e6e43051 100644 --- a/test/test_pickled_ramp.py +++ b/test/test_pickled_ramp.py @@ -41,7 +41,7 @@ class TestClass: def test_main_base_ramp(self, map_path: Path): bot = get_map_specific_bot(map_path) - # pyre-ignore[16] + bot.game_info.map_ramps, bot.game_info.vision_blockers = bot.game_info._find_ramps_and_vision_blockers() # Test if main ramp works for all spawns @@ -107,7 +107,7 @@ def test_bot_ai(self, map_path: Path): ) # On N player maps, it is expected that there are N*X bases because of symmetry, at least for maps designed for 1vs1 # Those maps in the list have an un-even expansion count - # pyre-ignore[16] + expect_even_expansion_count = 1 if bot.game_info.map_name in self.MAPS_WITH_ODD_EXPANSION_COUNT else 0 assert ( len(bot.expansion_locations_list) % (len(bot.enemy_start_locations) + 1) == expect_even_expansion_count @@ -120,7 +120,7 @@ def test_bot_ai(self, map_path: Path): for location in bot.enemy_start_locations: assert location in set(bot.expansion_locations_list), f"{location}, {bot.expansion_locations_list}" # Each expansion is supposed to have at least one geysir and 6-12 minerals - # pyre-ignore[16] + for expansion, resource_positions in bot.expansion_locations_dict.items(): assert isinstance(expansion, Point2) assert isinstance(resource_positions, Units) diff --git a/test/travis_test_script.py b/test/travis_test_script.py index 973141d6..cc1b036e 100644 --- a/test/travis_test_script.py +++ b/test/travis_test_script.py @@ -49,20 +49,18 @@ # Break as the bot run was successful break - # pyre-ignore[16] if process.returncode is not None: # Reformat the output into a list - # pyre-ignore[16] + logger.info_output = result linebreaks = [ - # pyre-ignore[16] ["\r\n", logger.info_output.count("\r\n")], ["\r", logger.info_output.count("\r")], ["\n", logger.info_output.count("\n")], ] most_linebreaks_type = max(linebreaks, key=lambda x: x[1]) linebreak_type, linebreak_count = most_linebreaks_type - # pyre-ignore[16] + output_as_list = logger.info_output.split(linebreak_type) logger.info("Travis test script, bot output:\r\n{}\r\nEnd of bot output".format("\r\n".join(output_as_list))) diff --git a/test/upgradestest_bot.py b/test/upgradestest_bot.py index 6267866b..623d5f71 100644 --- a/test/upgradestest_bot.py +++ b/test/upgradestest_bot.py @@ -174,7 +174,6 @@ async def on_start(self): await self.client.debug_kill_unit(self.units) async def on_step(self, iteration: int): - # pyre-ignore[16] map_center = self.game_info.map_center enemies = self.enemy_units | self.enemy_structures if enemies: From 94e94ede7f337f6fb98ccb98f9af8e7ad7bcd70e Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Fri, 2 Jan 2026 16:09:18 +0100 Subject: [PATCH 3/5] Fix most pyrefly issues --- .github/workflows/ci.yml | 4 +- .pre-commit-config.yaml | 21 ++++--- README.md | 2 +- pyproject.toml | 9 +++ s2clientprotocol/data_pb2.pyi | 18 +++--- s2clientprotocol/debug_pb2.pyi | 22 ++++---- s2clientprotocol/query_pb2.pyi | 30 +++++----- s2clientprotocol/raw_pb2.pyi | 58 +++++++++---------- s2clientprotocol/sc2api_pb2.pyi | 96 ++++++++++++++++---------------- s2clientprotocol/spatial_pb2.pyi | 6 +- s2clientprotocol/ui_pb2.pyi | 26 ++++----- sc2/action.py | 12 +++- sc2/bot_ai_internal.py | 2 +- sc2/cache.py | 1 + sc2/client.py | 4 +- sc2/controller.py | 9 ++- sc2/expiring_dict.py | 2 +- sc2/game_info.py | 20 ++++--- sc2/game_state.py | 22 ++++---- sc2/generate_ids.py | 4 +- sc2/main.py | 83 ++++++++++++++++----------- sc2/paths.py | 21 +++---- sc2/pixel_map.py | 13 +++-- sc2/player.py | 9 ++- sc2/portconfig.py | 8 +-- sc2/position.py | 16 ++++-- sc2/protocol.py | 2 +- sc2/proxy.py | 7 ++- sc2/renderer.py | 31 +++++++---- sc2/sc2process.py | 13 +++-- sc2/score.py | 3 +- sc2/unit.py | 41 +++++++++----- sc2/versions.py | 2 +- 33 files changed, 354 insertions(+), 263 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 645001d4..64a37745 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: uv run pre-commit install - name: Run pre-commit hooks - run: uv run pre-commit run --all-files --hook-stage push + run: uv run pre-commit run --all-files --hook-stage pre-push generate_dicts_from_data_json: name: Generate dicts from data.json @@ -80,7 +80,7 @@ jobs: run: | mv sc2/dicts sc2/dicts_old uv run python generate_dicts_from_data_json.py - uv run pre-commit run --all-files --hook-stage push || true + uv run pre-commit run --all-files --hook-stage pre-push || true rm -rf sc2/dicts/__pycache__ sc2/dicts_old/__pycache__ - name: Upload generated dicts folder as artifact diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c784213..03917649 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,22 +38,25 @@ repos: # Autoformat code - id: ruff-format-check name: Check if files are formatted - stages: [push] + stages: [pre-push] language: system + # Run the following command to fix: + # uv run ruff format . entry: uv run ruff format . --check --diff pass_filenames: false - id: ruff-lint name: Lint files - stages: [push] + stages: [pre-push] language: system + # Run the following command to fix: + # uv run ruff check . --fix entry: uv run ruff check . pass_filenames: false - # TODO Fix issues - # - id: pyrefly - # name: Static types checking with pyrefly - # stages: [push] - # language: system - # entry: uv run pyrefly check - # pass_filenames: false + - id: pyrefly + name: Static types checking with pyrefly + stages: [pre-push] + language: system + entry: uv run pyrefly check + pass_filenames: false diff --git a/README.md b/README.md index 0ae3d0b8..835e80a0 100644 --- a/README.md +++ b/README.md @@ -186,5 +186,5 @@ Git commit messages use [imperative-style messages](https://stackoverflow.com/a/ To run pre-commit hooks (which run autoformatting and autosort imports) you can run ```sh uv run pre-commit install -uv run pre-commit run --all-files --hook-stage push +uv run pre-commit run --all-files --hook-stage pre-push ``` diff --git a/pyproject.toml b/pyproject.toml index 3187f86d..0129eab6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,9 +98,18 @@ project_includes = [ # "examples", # "test" ] +project-excludes = [ + # Disable for those files and folders + "sc2/data.py", + # TODO Temp disable for those files and folders + "sc2/client.py", + "sc2/bot_ai_internal.py", + "sc2/bot_ai.py", +] [tool.pyrefly.errors] bad-override = false +inconsistent-overload = false [tool.ruff] target-version = 'py310' diff --git a/s2clientprotocol/data_pb2.pyi b/s2clientprotocol/data_pb2.pyi index 8b839cf5..dd663eb7 100644 --- a/s2clientprotocol/data_pb2.pyi +++ b/s2clientprotocol/data_pb2.pyi @@ -1,4 +1,4 @@ -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message @@ -71,7 +71,7 @@ class TargetType(Enum): class Weapon(Message): type: int damage: float - damage_bonus: Iterable[DamageBonus] + damage_bonus: Sequence[DamageBonus] attacks: int range: float speed: float @@ -79,7 +79,7 @@ class Weapon(Message): self, type: int = ..., damage: float = ..., - damage_bonus: Iterable[DamageBonus] = ..., + damage_bonus: Sequence[DamageBonus] = ..., attacks: int = ..., range: float = ..., speed: float = ..., @@ -100,14 +100,14 @@ class UnitTypeData(Message): has_vespene: bool has_minerals: bool sight_range: float - tech_alias: Iterable[int] + tech_alias: Sequence[int] unit_alias: int tech_requirement: int require_attached: bool - attributes: Iterable[int] + attributes: Sequence[int] movement_speed: float armor: float - weapons: Iterable[Weapon] + weapons: Sequence[Weapon] def __init__( self, unit_id: int = ..., @@ -124,14 +124,14 @@ class UnitTypeData(Message): has_vespene: bool = ..., has_minerals: bool = ..., sight_range: float = ..., - tech_alias: Iterable[int] = ..., + tech_alias: Sequence[int] = ..., unit_alias: int = ..., tech_requirement: int = ..., require_attached: bool = ..., - attributes: Iterable[int] = ..., + attributes: Sequence[int] = ..., movement_speed: float = ..., armor: float = ..., - weapons: Iterable[Weapon] = ..., + weapons: Sequence[Weapon] = ..., ) -> None: ... class UpgradeData(Message): diff --git a/s2clientprotocol/debug_pb2.pyi b/s2clientprotocol/debug_pb2.pyi index edc956ee..49f52525 100644 --- a/s2clientprotocol/debug_pb2.pyi +++ b/s2clientprotocol/debug_pb2.pyi @@ -1,4 +1,4 @@ -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message @@ -27,16 +27,16 @@ class DebugCommand(Message): ) -> None: ... class DebugDraw(Message): - text: Iterable[DebugText] - lines: Iterable[DebugLine] - boxes: Iterable[DebugBox] - spheres: Iterable[DebugSphere] + text: Sequence[DebugText] + lines: Sequence[DebugLine] + boxes: Sequence[DebugBox] + spheres: Sequence[DebugSphere] def __init__( self, - text: Iterable[DebugText] = ..., - lines: Iterable[DebugLine] = ..., - boxes: Iterable[DebugBox] = ..., - spheres: Iterable[DebugSphere] = ..., + text: Sequence[DebugText] = ..., + lines: Sequence[DebugLine] = ..., + boxes: Sequence[DebugBox] = ..., + spheres: Sequence[DebugSphere] = ..., ) -> None: ... class Line(Message): @@ -110,8 +110,8 @@ class DebugCreateUnit(Message): ) -> None: ... class DebugKillUnit(Message): - tag: Iterable[int] - def __init__(self, tag: Iterable[int] = ...) -> None: ... + tag: Sequence[int] + def __init__(self, tag: Sequence[int] = ...) -> None: ... class Test(Enum): hang: int diff --git a/s2clientprotocol/query_pb2.pyi b/s2clientprotocol/query_pb2.pyi index 746d86d9..67831b5f 100644 --- a/s2clientprotocol/query_pb2.pyi +++ b/s2clientprotocol/query_pb2.pyi @@ -1,31 +1,31 @@ -from collections.abc import Iterable +from collections.abc import Sequence from google.protobuf.message import Message from .common_pb2 import AvailableAbility, Point2D class RequestQuery(Message): - pathing: Iterable[RequestQueryPathing] - abilities: Iterable[RequestQueryAvailableAbilities] - placements: Iterable[RequestQueryBuildingPlacement] + pathing: Sequence[RequestQueryPathing] + abilities: Sequence[RequestQueryAvailableAbilities] + placements: Sequence[RequestQueryBuildingPlacement] ignore_resource_requirements: bool def __init__( self, - pathing: Iterable[RequestQueryPathing] = ..., - abilities: Iterable[RequestQueryAvailableAbilities] = ..., - placements: Iterable[RequestQueryBuildingPlacement] = ..., + pathing: Sequence[RequestQueryPathing] = ..., + abilities: Sequence[RequestQueryAvailableAbilities] = ..., + placements: Sequence[RequestQueryBuildingPlacement] = ..., ignore_resource_requirements: bool = ..., ) -> None: ... class ResponseQuery(Message): - pathing: Iterable[ResponseQueryPathing] - abilities: Iterable[ResponseQueryAvailableAbilities] - placements: Iterable[ResponseQueryBuildingPlacement] + pathing: Sequence[ResponseQueryPathing] + abilities: Sequence[ResponseQueryAvailableAbilities] + placements: Sequence[ResponseQueryBuildingPlacement] def __init__( self, - pathing: Iterable[ResponseQueryPathing] = ..., - abilities: Iterable[ResponseQueryAvailableAbilities] = ..., - placements: Iterable[ResponseQueryBuildingPlacement] = ..., + pathing: Sequence[ResponseQueryPathing] = ..., + abilities: Sequence[ResponseQueryAvailableAbilities] = ..., + placements: Sequence[ResponseQueryBuildingPlacement] = ..., ) -> None: ... class RequestQueryPathing(Message): @@ -48,12 +48,12 @@ class RequestQueryAvailableAbilities(Message): def __init__(self, unit_tag: int = ...) -> None: ... class ResponseQueryAvailableAbilities(Message): - abilities: Iterable[AvailableAbility] + abilities: Sequence[AvailableAbility] unit_tag: int unit_type_id: int def __init__( self, - abilities: Iterable[AvailableAbility] = ..., + abilities: Sequence[AvailableAbility] = ..., unit_tag: int = ..., unit_type_id: int = ..., ) -> None: ... diff --git a/s2clientprotocol/raw_pb2.pyi b/s2clientprotocol/raw_pb2.pyi index 34d89c6d..3db88e38 100644 --- a/s2clientprotocol/raw_pb2.pyi +++ b/s2clientprotocol/raw_pb2.pyi @@ -1,4 +1,4 @@ -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message @@ -11,7 +11,7 @@ class StartRaw(Message): terrain_height: ImageData placement_grid: ImageData playable_area: RectangleI - start_locations: Iterable[Point2D] + start_locations: Sequence[Point2D] def __init__( self, map_size: Size2DI = ..., @@ -19,24 +19,24 @@ class StartRaw(Message): terrain_height: ImageData = ..., placement_grid: ImageData = ..., playable_area: RectangleI = ..., - start_locations: Iterable[Point2D] = ..., + start_locations: Sequence[Point2D] = ..., ) -> None: ... class ObservationRaw(Message): player: PlayerRaw - units: Iterable[Unit] + units: Sequence[Unit] map_state: MapState event: Event - effects: Iterable[Effect] - radar: Iterable[RadarRing] + effects: Sequence[Effect] + radar: Sequence[RadarRing] def __init__( self, player: PlayerRaw = ..., - units: Iterable[Unit] = ..., + units: Sequence[Unit] = ..., map_state: MapState = ..., event: Event = ..., - effects: Iterable[Effect] = ..., - radar: Iterable[RadarRing] = ..., + effects: Sequence[Effect] = ..., + radar: Sequence[RadarRing] = ..., ) -> None: ... class RadarRing(Message): @@ -51,14 +51,14 @@ class PowerSource(Message): def __init__(self, pos: Point = ..., radius: float = ..., tag: int = ...) -> None: ... class PlayerRaw(Message): - power_sources: Iterable[PowerSource] + power_sources: Sequence[PowerSource] camera: Point - upgrade_ids: Iterable[int] + upgrade_ids: Sequence[int] def __init__( self, - power_sources: Iterable[PowerSource] = ..., + power_sources: Sequence[PowerSource] = ..., camera: Point = ..., - upgrade_ids: Iterable[int] = ..., + upgrade_ids: Sequence[int] = ..., ) -> None: ... class UnitOrder(Message): @@ -130,7 +130,7 @@ class Unit(Message): radius: float build_progress: float cloak: int - buff_ids: Iterable[int] + buff_ids: Sequence[int] detect_range: float radar_range: float is_selected: bool @@ -152,9 +152,9 @@ class Unit(Message): is_flying: bool is_burrowed: bool is_hallucination: bool - orders: Iterable[UnitOrder] + orders: Sequence[UnitOrder] add_on_tag: int - passengers: Iterable[PassengerUnit] + passengers: Sequence[PassengerUnit] cargo_space_taken: int cargo_space_max: int assigned_harvesters: int @@ -163,7 +163,7 @@ class Unit(Message): engaged_target_tag: int buff_duration_remain: int buff_duration_max: int - rally_targets: Iterable[RallyTarget] + rally_targets: Sequence[RallyTarget] def __init__( self, display_type: int = ..., @@ -176,7 +176,7 @@ class Unit(Message): radius: float = ..., build_progress: float = ..., cloak: int = ..., - buff_ids: Iterable[int] = ..., + buff_ids: Sequence[int] = ..., detect_range: float = ..., radar_range: float = ..., is_selected: bool = ..., @@ -198,9 +198,9 @@ class Unit(Message): is_flying: bool = ..., is_burrowed: bool = ..., is_hallucination: bool = ..., - orders: Iterable[UnitOrder] = ..., + orders: Sequence[UnitOrder] = ..., add_on_tag: int = ..., - passengers: Iterable[PassengerUnit] = ..., + passengers: Sequence[PassengerUnit] = ..., cargo_space_taken: int = ..., cargo_space_max: int = ..., assigned_harvesters: int = ..., @@ -209,7 +209,7 @@ class Unit(Message): engaged_target_tag: int = ..., buff_duration_remain: int = ..., buff_duration_max: int = ..., - rally_targets: Iterable[RallyTarget] = ..., + rally_targets: Sequence[RallyTarget] = ..., ) -> None: ... class MapState(Message): @@ -218,19 +218,19 @@ class MapState(Message): def __init__(self, visibility: ImageData = ..., creep: ImageData = ...) -> None: ... class Event(Message): - dead_units: Iterable[int] - def __init__(self, dead_units: Iterable[int] = ...) -> None: ... + dead_units: Sequence[int] + def __init__(self, dead_units: Sequence[int] = ...) -> None: ... class Effect(Message): effect_id: int - pos: Iterable[Point2D] + pos: Sequence[Point2D] alliance: int owner: int radius: float def __init__( self, effect_id: int = ..., - pos: Iterable[Point2D] = ..., + pos: Sequence[Point2D] = ..., alliance: int = ..., owner: int = ..., radius: float = ..., @@ -251,14 +251,14 @@ class ActionRawUnitCommand(Message): ability_id: int target_world_space_pos: Point2D target_unit_tag: int - unit_tags: Iterable[int] + unit_tags: Sequence[int] queue_command: bool def __init__( self, ability_id: int = ..., target_world_space_pos: Point2D = ..., target_unit_tag: int = ..., - unit_tags: Iterable[int] = ..., + unit_tags: Sequence[int] = ..., queue_command: bool = ..., ) -> None: ... @@ -268,5 +268,5 @@ class ActionRawCameraMove(Message): class ActionRawToggleAutocast(Message): ability_id: int - unit_tags: Iterable[int] - def __init__(self, ability_id: int = ..., unit_tags: Iterable[int] = ...) -> None: ... + unit_tags: Sequence[int] + def __init__(self, ability_id: int = ..., unit_tags: Sequence[int] = ...) -> None: ... diff --git a/s2clientprotocol/sc2api_pb2.pyi b/s2clientprotocol/sc2api_pb2.pyi index 67574e00..629d020e 100644 --- a/s2clientprotocol/sc2api_pb2.pyi +++ b/s2clientprotocol/sc2api_pb2.pyi @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message @@ -90,7 +90,7 @@ class Response(Message): ping: ResponsePing debug: ResponseDebug id: int - error: Iterable[str] + error: Sequence[str] status: int def __init__( self, @@ -117,7 +117,7 @@ class Response(Message): ping: ResponsePing = ..., debug: ResponseDebug = ..., id: int = ..., - error: Iterable[str] = ..., + error: Sequence[str] = ..., status: int = ..., ) -> None: ... @@ -133,7 +133,7 @@ class Status(Enum): class RequestCreateGame(Message): local_map: LocalMap battlenet_map_name: str - player_setup: Iterable[PlayerSetup] + player_setup: Sequence[PlayerSetup] disable_fog: bool random_seed: int realtime: bool @@ -141,7 +141,7 @@ class RequestCreateGame(Message): self, local_map: LocalMap = ..., battlenet_map_name: str = ..., - player_setup: Iterable[PlayerSetup] = ..., + player_setup: Sequence[PlayerSetup] = ..., disable_fog: bool = ..., random_seed: int = ..., realtime: bool = ..., @@ -172,7 +172,7 @@ class RequestJoinGame(Message): observed_player_id: int options: InterfaceOptions server_ports: PortSet - client_ports: Iterable[PortSet] + client_ports: Sequence[PortSet] shared_port: int player_name: str host_ip: str @@ -182,7 +182,7 @@ class RequestJoinGame(Message): observed_player_id: int = ..., options: InterfaceOptions = ..., server_ports: PortSet = ..., - client_ports: Iterable[PortSet] = ..., + client_ports: Sequence[PortSet] = ..., shared_port: int = ..., player_name: str = ..., host_ip: str = ..., @@ -302,17 +302,17 @@ class RequestGameInfo(Message): class ResponseGameInfo(Message): map_name: str - mod_names: Iterable[str] + mod_names: Sequence[str] local_map_path: str - player_info: Iterable[PlayerInfo] + player_info: Sequence[PlayerInfo] start_raw: StartRaw options: InterfaceOptions def __init__( self, map_name: str = ..., - mod_names: Iterable[str] = ..., + mod_names: Sequence[str] = ..., local_map_path: str = ..., - player_info: Iterable[PlayerInfo] = ..., + player_info: Sequence[PlayerInfo] = ..., start_raw: StartRaw = ..., options: InterfaceOptions = ..., ) -> None: ... @@ -323,18 +323,18 @@ class RequestObservation(Message): def __init__(self, disable_fog: bool = ..., game_loop: int = ...) -> None: ... class ResponseObservation(Message): - actions: Iterable[Action] - action_errors: Iterable[ActionError] + actions: Sequence[Action] + action_errors: Sequence[ActionError] observation: Observation - player_result: Iterable[PlayerResult] - chat: Iterable[ChatReceived] + player_result: Sequence[PlayerResult] + chat: Sequence[ChatReceived] def __init__( self, - actions: Iterable[Action] = ..., - action_errors: Iterable[ActionError] = ..., + actions: Sequence[Action] = ..., + action_errors: Sequence[ActionError] = ..., observation: Observation = ..., - player_result: Iterable[PlayerResult] = ..., - chat: Iterable[ChatReceived] = ..., + player_result: Sequence[PlayerResult] = ..., + chat: Sequence[ChatReceived] = ..., ) -> None: ... class ChatReceived(Message): @@ -343,16 +343,16 @@ class ChatReceived(Message): def __init__(self, player_id: int = ..., message: str = ...) -> None: ... class RequestAction(Message): - actions: Iterable[Action] - def __init__(self, actions: Iterable[Action] = ...) -> None: ... + actions: Sequence[Action] + def __init__(self, actions: Sequence[Action] = ...) -> None: ... class ResponseAction(Message): - result: Iterable[int] - def __init__(self, result: Iterable[int] = ...) -> None: ... + result: Sequence[int] + def __init__(self, result: Sequence[int] = ...) -> None: ... class RequestObserverAction(Message): - actions: Iterable[ObserverAction] - def __init__(self, actions: Iterable[ObserverAction] = ...) -> None: ... + actions: Sequence[ObserverAction] + def __init__(self, actions: Sequence[ObserverAction] = ...) -> None: ... class ResponseObserverAction(Message): def __init__(self) -> None: ... @@ -381,18 +381,18 @@ class RequestData(Message): ) -> None: ... class ResponseData(Message): - abilities: Iterable[AbilityData] - units: Iterable[UnitTypeData] - upgrades: Iterable[UpgradeData] - buffs: Iterable[BuffData] - effects: Iterable[EffectData] + abilities: Sequence[AbilityData] + units: Sequence[UnitTypeData] + upgrades: Sequence[UpgradeData] + buffs: Sequence[BuffData] + effects: Sequence[EffectData] def __init__( self, - abilities: Iterable[AbilityData] = ..., - units: Iterable[UnitTypeData] = ..., - upgrades: Iterable[UpgradeData] = ..., - buffs: Iterable[BuffData] = ..., - effects: Iterable[EffectData] = ..., + abilities: Sequence[AbilityData] = ..., + units: Sequence[UnitTypeData] = ..., + upgrades: Sequence[UpgradeData] = ..., + buffs: Sequence[BuffData] = ..., + effects: Sequence[EffectData] = ..., ) -> None: ... class RequestSaveReplay(Message): @@ -436,7 +436,7 @@ class ResponseReplayInfo(Message): map_name: str local_map_path: str - player_info: Iterable[PlayerInfoExtra] + player_info: Sequence[PlayerInfoExtra] game_duration_loops: int game_duration_seconds: float game_version: str @@ -449,7 +449,7 @@ class ResponseReplayInfo(Message): self, map_name: str = ..., local_map_path: str = ..., - player_info: Iterable[PlayerInfoExtra] = ..., + player_info: Sequence[PlayerInfoExtra] = ..., game_duration_loops: int = ..., game_duration_seconds: float = ..., game_version: str = ..., @@ -464,9 +464,9 @@ class RequestAvailableMaps(Message): def __init__(self) -> None: ... class ResponseAvailableMaps(Message): - local_map_paths: Iterable[str] - battlenet_map_names: Iterable[str] - def __init__(self, local_map_paths: Iterable[str] = ..., battlenet_map_names: Iterable[str] = ...) -> None: ... + local_map_paths: Sequence[str] + battlenet_map_names: Sequence[str] + def __init__(self, local_map_paths: Sequence[str] = ..., battlenet_map_names: Sequence[str] = ...) -> None: ... class RequestSaveMap(Message): map_path: str @@ -497,8 +497,8 @@ class ResponsePing(Message): ) -> None: ... class RequestDebug(Message): - debug: Iterable[DebugCommand] - def __init__(self, debug: Iterable[DebugCommand] = ...) -> None: ... + debug: Sequence[DebugCommand] + def __init__(self, debug: Sequence[DebugCommand] = ...) -> None: ... class ResponseDebug(Message): def __init__(self) -> None: ... @@ -630,8 +630,8 @@ class PlayerCommon(Message): class Observation(Message): game_loop: int player_common: PlayerCommon - alerts: Iterable[int] - abilities: Iterable[AvailableAbility] + alerts: Sequence[int] + abilities: Sequence[AvailableAbility] score: Score raw_data: ObservationRaw feature_layer_data: ObservationFeatureLayer @@ -641,8 +641,8 @@ class Observation(Message): self, game_loop: int = ..., player_common: PlayerCommon = ..., - alerts: Iterable[int] = ..., - abilities: Iterable[AvailableAbility] = ..., + alerts: Sequence[int] = ..., + abilities: Sequence[AvailableAbility] = ..., score: Score = ..., raw_data: ObservationRaw = ..., feature_layer_data: ObservationFeatureLayer = ..., @@ -709,8 +709,8 @@ class ActionObserverCameraFollowPlayer(Message): def __init__(self, player_id: int = ...) -> None: ... class ActionObserverCameraFollowUnits(Message): - unit_tags: Iterable[int] - def __init__(self, unit_tags: Iterable[int] = ...) -> None: ... + unit_tags: Sequence[int] + def __init__(self, unit_tags: Sequence[int] = ...) -> None: ... class Alert(Enum): AlertError: int diff --git a/s2clientprotocol/spatial_pb2.pyi b/s2clientprotocol/spatial_pb2.pyi index a1b72a29..92f6afe9 100644 --- a/s2clientprotocol/spatial_pb2.pyi +++ b/s2clientprotocol/spatial_pb2.pyi @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message @@ -149,6 +149,6 @@ class ActionSpatialUnitSelectionPoint(Message): def __init__(self, selection_screen_coord: PointI = ..., type: int = ...) -> None: ... class ActionSpatialUnitSelectionRect(Message): - selection_screen_coord: Iterable[RectangleI] + selection_screen_coord: Sequence[RectangleI] selection_add: bool - def __init__(self, selection_screen_coord: Iterable[RectangleI] = ..., selection_add: bool = ...) -> None: ... + def __init__(self, selection_screen_coord: Sequence[RectangleI] = ..., selection_add: bool = ...) -> None: ... diff --git a/s2clientprotocol/ui_pb2.pyi b/s2clientprotocol/ui_pb2.pyi index dbf39f3b..82706576 100644 --- a/s2clientprotocol/ui_pb2.pyi +++ b/s2clientprotocol/ui_pb2.pyi @@ -1,19 +1,19 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Sequence from enum import Enum from google.protobuf.message import Message class ObservationUI(Message): - groups: Iterable[ControlGroup] + groups: Sequence[ControlGroup] single: SinglePanel multi: MultiPanel cargo: CargoPanel production: ProductionPanel def __init__( self, - groups: Iterable[ControlGroup] = ..., + groups: Sequence[ControlGroup] = ..., single: SinglePanel = ..., multi: MultiPanel = ..., cargo: CargoPanel = ..., @@ -63,28 +63,28 @@ class SinglePanel(Message): attack_upgrade_level: int armor_upgrade_level: int shield_upgrade_level: int - buffs: Iterable[int] + buffs: Sequence[int] def __init__( self, unit: UnitInfo = ..., attack_upgrade_level: int = ..., armor_upgrade_level: int = ..., shield_upgrade_level: int = ..., - buffs: Iterable[int] = ..., + buffs: Sequence[int] = ..., ) -> None: ... class MultiPanel(Message): - units: Iterable[UnitInfo] - def __init__(self, units: Iterable[UnitInfo] = ...) -> None: ... + units: Sequence[UnitInfo] + def __init__(self, units: Sequence[UnitInfo] = ...) -> None: ... class CargoPanel(Message): unit: UnitInfo - passengers: Iterable[UnitInfo] + passengers: Sequence[UnitInfo] slots_available: int def __init__( self, unit: UnitInfo = ..., - passengers: Iterable[UnitInfo] = ..., + passengers: Sequence[UnitInfo] = ..., slots_available: int = ..., ) -> None: ... @@ -95,13 +95,13 @@ class BuildItem(Message): class ProductionPanel(Message): unit: UnitInfo - build_queue: Iterable[UnitInfo] - production_queue: Iterable[BuildItem] + build_queue: Sequence[UnitInfo] + production_queue: Sequence[BuildItem] def __init__( self, unit: UnitInfo = ..., - build_queue: Iterable[UnitInfo] = ..., - production_queue: Iterable[BuildItem] = ..., + build_queue: Sequence[UnitInfo] = ..., + production_queue: Sequence[BuildItem] = ..., ) -> None: ... class ActionUI(Message): diff --git a/sc2/action.py b/sc2/action.py index b43124ae..019f3803 100644 --- a/sc2/action.py +++ b/sc2/action.py @@ -33,7 +33,10 @@ def combine_actions(action_iter: list[UnitCommand]): if combineable: # Combine actions with no target, e.g. lift, burrowup, burrowdown, siege, unsiege, uproot spines cmd = raw_pb.ActionRawUnitCommand( - ability_id=ability.value, unit_tags={u.unit.tag for u in items}, queue_command=queue + ability_id=ability.value, + # pyrefly: ignore + unit_tags={u.unit.tag for u in items}, + queue_command=queue, ) # Combine actions with target point, e.g. attack_move or move commands on a position if isinstance(target, Point2): @@ -58,13 +61,17 @@ def combine_actions(action_iter: list[UnitCommand]): if target is None: for u in items: cmd = raw_pb.ActionRawUnitCommand( - ability_id=ability.value, unit_tags={u.unit.tag}, queue_command=queue + ability_id=ability.value, + # pyrefly: ignore + unit_tags={u.unit.tag}, + queue_command=queue, ) yield raw_pb.ActionRaw(unit_command=cmd) elif isinstance(target, Point2): for u in items: cmd = raw_pb.ActionRawUnitCommand( ability_id=ability.value, + # pyrefly: ignore unit_tags={u.unit.tag}, queue_command=queue, target_world_space_pos=target.as_Point2D, @@ -74,6 +81,7 @@ def combine_actions(action_iter: list[UnitCommand]): for u in items: cmd = raw_pb.ActionRawUnitCommand( ability_id=ability.value, + # pyrefly: ignore unit_tags={u.unit.tag}, queue_command=queue, target_unit_tag=target.tag, diff --git a/sc2/bot_ai_internal.py b/sc2/bot_ai_internal.py index da0dcb7b..7f68a271 100644 --- a/sc2/bot_ai_internal.py +++ b/sc2/bot_ai_internal.py @@ -103,7 +103,7 @@ def _initialize_variables(self) -> None: self.actions: list[UnitCommand] = [] self.blips: set[Blip] = set() - self.race: Race | None = None + self.race: Race = None self.enemy_race: Race | None = None self._generated_frame = -100 self._units_created: Counter = Counter() diff --git a/sc2/cache.py b/sc2/cache.py index 1fcbd8ca..ad06ce99 100644 --- a/sc2/cache.py +++ b/sc2/cache.py @@ -32,6 +32,7 @@ class property_cache_once_per_frame(property): # noqa: N801 def __init__(self, func: Callable[[BotAI], T], name: str | None = None) -> None: self.__name__ = name or func.__name__ self.__frame__ = f"__frame__{self.__name__}" + # pyrefly: ignore self.func = func def __set__(self, obj: BotAI, value: T) -> None: diff --git a/sc2/client.py b/sc2/client.py index 4fe0cccc..6aa7f0ff 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -36,8 +36,8 @@ def __init__(self, ws: ClientWebSocketResponse, save_replay_path: str | None = N # How many frames will be waited between iterations before the next one is called self.game_step: int = 4 self.save_replay_path: str | None = save_replay_path - self._player_id = None - self._game_result = None + self._player_id: int = None + self._game_result: dict[int, Result] = None # Store a hash value of all the debug requests to prevent sending the same ones again if they haven't changed last frame self._debug_hash_tuple_last_iteration: tuple[int, int, int, int] = (0, 0, 0, 0) self._debug_draw_last_frame = False diff --git a/sc2/controller.py b/sc2/controller.py index e068aa3f..45b7dec5 100644 --- a/sc2/controller.py +++ b/sc2/controller.py @@ -26,18 +26,23 @@ def running(self) -> bool: async def create_game(self, game_map, players, realtime: bool, random_seed=None, disable_fog=None): req = sc_pb.RequestCreateGame( - local_map=sc_pb.LocalMap(map_path=str(game_map.relative_path)), realtime=realtime, disable_fog=disable_fog + local_map=sc_pb.LocalMap(map_path=str(game_map.relative_path)), + realtime=realtime, + # pyrefly: ignore + disable_fog=disable_fog, ) if random_seed is not None: req.random_seed = random_seed for player in players: + # pyrefly: ignore p = req.player_setup.add() p.type = player.type.value if isinstance(player, Computer): p.race = player.race.value p.difficulty = player.difficulty.value - p.ai_build = player.ai_build.value + if player.ai_build is not None: + p.ai_build = player.ai_build.value logger.info("Creating new game") logger.info(f"Map: {game_map.name}") diff --git a/sc2/expiring_dict.py b/sc2/expiring_dict.py index fd14d2a0..d80ffa8a 100644 --- a/sc2/expiring_dict.py +++ b/sc2/expiring_dict.py @@ -73,7 +73,7 @@ def __setitem__(self, key: Hashable, value: Any) -> None: def __repr__(self) -> str: """Printable version of the dict instead of getting memory adress""" - print_list = [] + print_list: list[str] = [] with self.lock: for key, value in OrderedDict.items(self): if self.frame - value[1] < self.max_age: diff --git a/sc2/game_info.py b/sc2/game_info.py index f9d1f0d0..fa853fd9 100644 --- a/sc2/game_info.py +++ b/sc2/game_info.py @@ -30,7 +30,7 @@ def y_offset(self) -> float: return 0.5 @cached_property - def _height_map(self): + def _height_map(self) -> PixelMap: return self.game_info.terrain_height @cached_property @@ -145,6 +145,7 @@ def barracks_can_fit_addon(self) -> bool: """Test if a barracks can fit an addon at natural ramp""" # https://i.imgur.com/4b2cXHZ.png if len(self.upper2_for_ramp_wall) == 2: + # pyrefly: ignore return self.barracks_in_middle.x + 1 > max(self.corner_depots, key=lambda depot: depot.x).x raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") @@ -172,6 +173,7 @@ def protoss_wall_pylon(self) -> Point2 | None: raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") middle = self.depot_in_middle # direction up the ramp + # pyrefly: ignore direction = self.barracks_in_middle.negative_offset(middle) return middle + 6 * direction @@ -187,12 +189,14 @@ def protoss_wall_buildings(self) -> frozenset[Point2]: if len(self.upper2_for_ramp_wall) == 2: middle = self.depot_in_middle # direction up the ramp + # pyrefly: ignore direction = self.barracks_in_middle.negative_offset(middle) # sort depots based on distance to start to get wallin orientation sorted_depots = sorted( self.corner_depots, key=lambda depot: depot.distance_to(self.game_info.player_start_location) ) wall1: Point2 = sorted_depots[1].offset(direction) + # pyrefly: ignore wall2 = middle + direction + (middle - wall1) / 1.5 return frozenset([wall1, wall2]) @@ -210,6 +214,7 @@ def protoss_wall_warpin(self) -> Point2 | None: raise Exception("Not implemented. Trying to access a ramp that has a wrong amount of upper points.") middle = self.depot_in_middle # direction up the ramp + # pyrefly: ignore direction = self.barracks_in_middle.negative_offset(middle) # sort depots based on distance to start to get wallin orientation sorted_depots = sorted(self.corner_depots, key=lambda x: x.distance_to(self.game_info.player_start_location)) @@ -233,9 +238,9 @@ def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self.placement_grid: PixelMap = PixelMap(self._proto.start_raw.placement_grid, in_bits=True) self.playable_area = Rect.from_proto(self._proto.start_raw.playable_area) self.map_center = self.playable_area.center - + # pyrefly: ignore self.map_ramps: list[Ramp] = None # Filled later by BotAI._prepare_first_step - + # pyrefly: ignore self.vision_blockers: frozenset[Point2] = None # Filled later by BotAI._prepare_first_step self.player_races: dict[int, int] = { p.player_id: p.race_actual or p.race_requested for p in self._proto.player_info @@ -243,7 +248,7 @@ def __init__(self, proto: sc2api_pb2.ResponseGameInfo) -> None: self.start_locations: list[Point2] = [ Point2.from_proto(sl).round(decimals=1) for sl in self._proto.start_raw.start_locations ] - + # pyrefly: ignore self.player_start_location: Point2 = None # Filled later by BotAI._prepare_first_step def _find_ramps_and_vision_blockers(self) -> tuple[list[Ramp], frozenset[Point2]]: @@ -269,10 +274,10 @@ def equal_height_around(tile): # divide points into ramp points and vision blockers ramp_points = [point for point in points if not equal_height_around(point)] vision_blockers = frozenset(point for point in points if equal_height_around(point)) - ramps = [Ramp(group, self) for group in self._find_groups(ramp_points)] + ramps = [Ramp(frozenset(group), self) for group in self._find_groups(ramp_points)] return ramps, vision_blockers - def _find_groups(self, points: frozenset[Point2], minimum_points_per_group: int = 8) -> Iterable[frozenset[Point2]]: + def _find_groups(self, points: Iterable[Point2], minimum_points_per_group: int = 8) -> Iterable[frozenset[Point2]]: """ From a set of points, this function will try to group points together by painting clusters of points in a rectangular map using flood fill algorithm. @@ -286,6 +291,7 @@ def _find_groups(self, points: frozenset[Point2], minimum_points_per_group: int picture: list[list[int]] = [[-2 for _ in range(map_width)] for _ in range(map_height)] def paint(pt: Point2) -> None: + # pyrefly: ignore picture[pt.y][pt.x] = current_color nearby: list[tuple[int, int]] = [(a, b) for a in [-1, 0, 1] for b in [-1, 0, 1] if a != 0 or b != 0] @@ -309,12 +315,12 @@ def paint(pt: Point2) -> None: # Do we ever reach out of map bounds? if not (0 <= px < map_width and 0 <= py < map_height): continue + # pyrefly: ignore if picture[py][px] != NOT_COLORED_YET: continue point: Point2 = Point2((px, py)) remaining.discard(point) paint(point) - queue.append(point) current_group.add(point) if len(current_group) >= minimum_points_per_group: yield frozenset(current_group) diff --git a/sc2/game_state.py b/sc2/game_state.py index d9f37448..fd7af3db 100644 --- a/sc2/game_state.py +++ b/sc2/game_state.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain +from s2clientprotocol.raw_pb2 import Effect, Unit from loguru import logger @@ -45,7 +46,7 @@ def is_visible(self) -> bool: return self._proto.display_type == DisplayType.Visible.value @property - def alliance(self) -> Alliance: + def alliance(self) -> int: return self._proto.alliance @property @@ -96,7 +97,7 @@ def __init__(self, proto: raw_pb2.Effect | raw_pb2.Unit, fake: bool = False) -> :param proto: :param fake: """ - self._proto = proto + self._proto: Effect | Unit = proto self.fake = fake @property @@ -132,7 +133,7 @@ def owner(self) -> int: @property def radius(self) -> float: - if self.fake: + if isinstance(self._proto, Unit): return FakeEffectRadii[self._proto.unit_type] return self._proto.radius @@ -148,6 +149,8 @@ class ChatMessage: @dataclass class AbilityLookupTemplateClass: + ability_id: int + @property def exact_id(self) -> AbilityId: return AbilityId(self.ability_id) @@ -163,7 +166,6 @@ def generic_id(self) -> AbilityId: @dataclass class ActionRawUnitCommand(AbilityLookupTemplateClass): game_loop: int - ability_id: int unit_tags: list[int] queue_command: bool target_world_space_pos: Point2 | None @@ -173,7 +175,6 @@ class ActionRawUnitCommand(AbilityLookupTemplateClass): @dataclass class ActionRawToggleAutocast(AbilityLookupTemplateClass): game_loop: int - ability_id: int unit_tags: list[int] @@ -184,7 +185,6 @@ class ActionRawCameraMove: @dataclass class ActionError(AbilityLookupTemplateClass): - ability_id: int unit_tag: int # See here for the codes of 'result': https://github.com/Blizzard/s2client-proto/blob/01ab351e21c786648e4c6693d4aad023a176d45c/s2clientprotocol/error.proto#L6 result: int @@ -211,7 +211,7 @@ def __init__( self.common: Common = Common(self.observation.player_common) # Area covered by Pylons and Warpprisms - self.psionic_matrix: PsionicMatrix = PsionicMatrix.from_proto(self.observation_raw.player.power_sources) + self.psionic_matrix: PsionicMatrix = PsionicMatrix.from_proto(list(self.observation_raw.player.power_sources)) # 22.4 per second on faster game speed self.game_loop: int = self.observation.game_loop @@ -258,7 +258,7 @@ def alerts(self) -> list[int]: """ if self.previous_observation is not None: return list(chain(self.previous_observation.observation.alerts, self.observation.alerts)) - return self.observation.alerts + return list(self.observation.alerts) @cached_property def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | ActionRawCameraMove]: @@ -282,7 +282,7 @@ def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | Actio ActionRawUnitCommand( game_loop, raw_unit_command.ability_id, - raw_unit_command.unit_tags, + list(raw_unit_command.unit_tags), raw_unit_command.queue_command, Point2.from_proto(raw_unit_command.target_world_space_pos), ) @@ -293,7 +293,7 @@ def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | Actio ActionRawUnitCommand( game_loop, raw_unit_command.ability_id, - raw_unit_command.unit_tags, + list(raw_unit_command.unit_tags), raw_unit_command.queue_command, None, raw_unit_command.target_unit_tag, @@ -306,7 +306,7 @@ def actions(self) -> list[ActionRawUnitCommand | ActionRawToggleAutocast | Actio ActionRawToggleAutocast( game_loop, raw_toggle_autocast_action.ability_id, - raw_toggle_autocast_action.unit_tags, + list(raw_toggle_autocast_action.unit_tags), ) ) else: diff --git a/sc2/generate_ids.py b/sc2/generate_ids.py index c5c27956..af0b5fe1 100644 --- a/sc2/generate_ids.py +++ b/sc2/generate_ids.py @@ -194,8 +194,8 @@ def _missing_(cls, value: int) -> {class_name}: f.write(f'ID_VERSION_STRING = "{self.game_version}"\n') def update_ids_from_stableid_json(self) -> None: - if self.game_version is None or ID_VERSION_STRING is None or self.game_version != ID_VERSION_STRING: - if self.verbose and self.game_version is not None and ID_VERSION_STRING is not None: + if self.game_version is None or self.game_version != ID_VERSION_STRING: + if self.verbose and self.game_version is not None: logger.info( f"Game version is different (Old: {self.game_version}, new: {ID_VERSION_STRING}. Updating ids to match game version" ) diff --git a/sc2/main.py b/sc2/main.py index 89996f5f..cb18c0d3 100644 --- a/sc2/main.py +++ b/sc2/main.py @@ -24,7 +24,7 @@ from sc2.game_state import GameState from sc2.maps import Map from sc2.observer_ai import ObserverAI -from sc2.player import AbstractPlayer, Bot, BotProcess, Human +from sc2.player import AbstractPlayer, Bot, BotProcess, Computer, Human from sc2.portconfig import Portconfig from sc2.protocol import ConnectionAlreadyClosedError, ProtocolError from sc2.proxy import Proxy @@ -55,7 +55,12 @@ class GameMatch: def __post_init__(self) -> None: # avoid players sharing names - if len(self.players) > 1 and self.players[0].name is not None and self.players[0].name == self.players[1].name: + if ( + len(self.players) > 1 + and self.players[0].name is not None + and self.players[1].name is not None + and self.players[0].name == self.players[1].name + ): self.players[1].name += "2" if self.sc2_config is not None: @@ -106,7 +111,8 @@ async def _play_game_human(client, player_id, realtime, game_time_limit): async def _play_game_ai( client: Client, player_id: int, ai: BotAI, realtime: bool, game_time_limit: int | None ) -> Result: - gs: GameState | None = None + # pyrefly: ignore + gs: GameState = None async def initialize_first_step() -> Result | None: nonlocal gs @@ -204,10 +210,10 @@ async def run_bot_iteration(iteration: int): async def _play_game( - player: AbstractPlayer, + player: Human | Bot, client: Client, realtime: bool, - portconfig: Portconfig, + portconfig: Portconfig | None = None, game_time_limit: int | None = None, rgb_render_config: dict[str, Any] | None = None, ) -> Result: @@ -335,7 +341,7 @@ async def _setup_host_game( async def _host_game( map_settings: Map, - players: list[AbstractPlayer], + players: list[Human | Bot | Computer] | list[Human | Bot], realtime: bool = False, portconfig: Portconfig | None = None, save_replay_as: str | None = None, @@ -348,6 +354,7 @@ async def _host_game( assert players, "Can't create a game without players" assert any((isinstance(p, (Human, Bot))) for p in players) + assert isinstance(players[0], (Human, Bot)), "First player needs to be a Human or a Bot" async with SC2Process( fullscreen=players[0].fullscreen, render=rgb_render_config is not None, sc2_version=sc2_version @@ -358,7 +365,7 @@ async def _host_game( server, map_settings, players, realtime, random_seed, disable_fog, save_replay_as ) # Bot can decide if it wants to launch with 'raw_affects_selection=True' - if not isinstance(players[0], Human) and getattr(players[0].ai, "raw_affects_selection", None) is not None: + if isinstance(players[0], Bot) and getattr(players[0].ai, "raw_affects_selection", None) is not None: client.raw_affects_selection = players[0].ai.raw_affects_selection result = await _play_game(players[0], client, realtime, portconfig, game_time_limit, rgb_render_config) @@ -377,7 +384,7 @@ async def _host_game_aiter( map_settings, players, realtime, - portconfig=None, + portconfig, save_replay_as=None, game_time_limit=None, ): @@ -416,9 +423,9 @@ def _host_game_iter(*args, **kwargs): async def _join_game( - players: list[AbstractPlayer], - realtime: bool = False, - portconfig: Portconfig | None = None, + players: list[Human | Bot], + realtime: bool, + portconfig: Portconfig, save_replay_as: str | None = None, game_time_limit: int | None = None, sc2_version: str | None = None, @@ -464,6 +471,7 @@ def get_replay_version(replay_path: str | Path) -> tuple[str, str]: replay_io.write(replay_data) replay_io.seek(0) archive = mpyq.MPQArchive(replay_io).extract() + # pyrefly: ignore metadata = json.loads(archive[b"replay.gamemetadata.json"].decode("utf-8")) return metadata["BaseBuild"], metadata["DataVersion"] @@ -471,7 +479,7 @@ def get_replay_version(replay_path: str | Path) -> tuple[str, str]: # TODO Deprecate run_game function in favor of run_multiple_games def run_game( map_settings: Map, - players: list[AbstractPlayer], + players: list[Human | Bot | Computer], realtime: bool, portconfig: Portconfig | None = None, save_replay_as: str | None = None, @@ -485,14 +493,16 @@ def run_game( Returns a single Result enum if the game was against the built-in computer. Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot". """ + result: Result | list[Result | None] if sum(isinstance(p, (Human, Bot)) for p in players) > 1: portconfig = Portconfig() + players_non_computer: list[Human | Bot] = [p for p in players if isinstance(p, (Human, Bot))] async def run_host_and_join(): return await asyncio.gather( _host_game( map_settings, - players, + players_non_computer, realtime=realtime, portconfig=portconfig, save_replay_as=save_replay_as, @@ -503,7 +513,7 @@ async def run_host_and_join(): disable_fog=disable_fog, ), _join_game( - players, + players_non_computer, realtime=realtime, portconfig=portconfig, save_replay_as=save_replay_as, @@ -513,11 +523,12 @@ async def run_host_and_join(): return_exceptions=True, ) - result: list[Result] = asyncio.run(run_host_and_join()) + # pyrefly: ignore + result = asyncio.run(run_host_and_join()) assert isinstance(result, list) assert all(isinstance(r, Result) for r in result) else: - result: Result = asyncio.run( + result = asyncio.run( _host_game( map_settings, players, @@ -550,9 +561,9 @@ def run_replay(ai: ObserverAI, replay_path: Path | str, realtime: bool = False, async def play_from_websocket( ws_connection: str | ClientWebSocketResponse, - player: AbstractPlayer, - realtime: bool = False, - portconfig: Portconfig | None = None, + player: Human | Bot, + realtime: bool, + portconfig: Portconfig, save_replay_as: str | None = None, game_time_limit: int | None = None, should_close: bool = True, @@ -568,6 +579,7 @@ async def play_from_websocket( try: if isinstance(ws_connection, str): session = ClientSession() + # pyrefly: ignore ws_connection = await session.ws_connect(ws_connection, timeout=120) should_close = True client = Client(ws_connection) @@ -591,7 +603,7 @@ async def run_match(controllers: list[Controller], match: GameMatch, close_ws: b # Setup portconfig beforehand, so all players use the same ports startport = None - portconfig = None + portconfig: Portconfig = None # pyrefly: ignore if match.needed_sc2_count > 1: if any(isinstance(player, BotProcess) for player in match.players): portconfig = Portconfig.contiguous_ports() @@ -623,12 +635,12 @@ async def run_match(controllers: list[Controller], match: GameMatch, close_ws: b async_results = await asyncio.gather(*coros, return_exceptions=True) - if not isinstance(async_results, list): - async_results = [async_results] for i, a in enumerate(async_results): if isinstance(a, Exception): logger.error(f"Exception[{a}] thrown by {[p for p in match.players if p.needs_sc2][i]}") + # TODO async_results may contain exceptions + # pyrefly: ignore return process_results(match.players, async_results) @@ -643,9 +655,10 @@ def process_results(players: list[AbstractPlayer], async_results: list[Result]) else: result[player] = Result.Undecided i += 1 - else: # computer + else: + # Computer other_result = async_results[0] - result[player] = None + result[player] = Result.Undecided if other_result in opp_res: result[player] = opp_res[other_result] @@ -658,12 +671,14 @@ async def maintain_SCII_count(count: int, controllers: list[Controller], proc_ar if controllers: to_remove = [] alive = await asyncio.wait_for( - asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), timeout=20 + # pyrefly: ignore + asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), + timeout=20, ) i = 0 # for alive for controller in controllers: if controller._ws.closed: - if not controller._process._session.closed: + if controller._process._session is not None and not controller._process._session.closed: await controller._process._session.close() to_remove.append(controller) else: @@ -697,12 +712,14 @@ async def maintain_SCII_count(count: int, controllers: list[Controller], proc_ar else: # Doesnt seem to work on linux: starting 2 clients nearly at the same time new_controllers = await asyncio.wait_for( + # pyrefly: ignore asyncio.gather(*[sc.__aenter__() for sc in extra], return_exceptions=True), timeout=50, ) controllers.extend(c for c in new_controllers if isinstance(c, Controller)) if len(controllers) == count: + # pyrefly: ignore await asyncio.wait_for(asyncio.gather(*(c.ping() for c in controllers)), timeout=20) break extra = [ @@ -738,8 +755,8 @@ async def a_run_multiple_games(matches: list[GameMatch]) -> list[dict[AbstractPl if not matches: return [] - results = [] - controllers = [] + results: list[dict[AbstractPlayer, Result]] = [] + controllers: list[Controller] = [] for m in matches: result = None dont_restart = m.needed_sc2_count == 2 @@ -754,7 +771,8 @@ async def a_run_multiple_games(matches: list[GameMatch]) -> list[dict[AbstractPl finally: if dont_restart: # Keeping them alive after a non-computer match can cause crashes await maintain_SCII_count(0, controllers, m.sc2_config) - results.append(result) + if result is not None: + results.append(result) KillSwitch.kill_all() return results @@ -771,7 +789,7 @@ async def a_run_multiple_games_nokill(matches: list[GameMatch]) -> list[dict[Abs return [] # Start the matches - results = [] + results: list[dict[AbstractPlayer, Result]] = [] controllers: list[Controller] = [] for m in matches: logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}") @@ -794,10 +812,11 @@ async def a_run_multiple_games_nokill(matches: list[GameMatch]) -> list[dict[Abs logger.exception(f"Caught unknown exception: {e}") if not (isinstance(e, ProtocolError) and e.is_game_over_error): logger.info(f"controller {c.__dict__} threw {e}") - - results.append(result) + if result is not None: + results.append(result) # Fire the killswitch manually, instead of letting the winning player fire it. + # pyrefly: ignore await asyncio.wait_for(asyncio.gather(*(c._process._close_connection() for c in controllers)), timeout=50) KillSwitch.kill_all() signal.signal(signal.SIGINT, signal.SIG_DFL) diff --git a/sc2/paths.py b/sc2/paths.py index 10ec26bb..86eac639 100644 --- a/sc2/paths.py +++ b/sc2/paths.py @@ -55,7 +55,7 @@ def platform_detect(): return pf -PF = platform_detect() +PF: str = platform_detect() def get_home(): @@ -68,12 +68,12 @@ def get_home(): def get_user_sc2_install(): """Attempts to find a user's SC2 install if their OS has ExecuteInfo.txt""" if USERPATH[PF]: - einfo = str(get_home() / Path(USERPATH[PF])) + einfo = str(get_home() / Path(USERPATH[PF])) # pyrefly: ignore if Path(einfo).is_file(): with Path(einfo).open() as f: content = f.read() if content: - base = re.search(r" = (.*)Versions", content).group(1) + base = re.search(r" = (.*)Versions", content).group(1) # pyrefly: ignore if PF in {"WSL1", "WSL2"}: base = str(wsl.win_path_to_wsl_path(base)) @@ -88,8 +88,9 @@ def get_env() -> None: def get_runner_args(cwd): - if "WINE" in os.environ: - runner_file = Path(os.environ.get("WINE")) + wine_path = os.environ.get("WINE") + if wine_path is not None: + runner_file = Path(wine_path) runner_file = runner_file if runner_file.is_file() else runner_file / "wine" """ TODO Is converting linux path really necessary? @@ -133,16 +134,16 @@ def __setup(cls): try: base = os.environ.get("SC2PATH") or get_user_sc2_install() or BASEDIR[PF] - cls.BASE = Path(base).expanduser() + cls.BASE = Path(base).expanduser() # pyrefly: ignore cls.EXECUTABLE = latest_executeble(cls.BASE / "Versions") - cls.CWD = cls.BASE / CWD[PF] if CWD[PF] else None + cls.CWD = cls.BASE / CWD[PF] if CWD[PF] else None # pyrefly: ignore - cls.REPLAYS = cls.BASE / "Replays" + cls.REPLAYS = cls.BASE / "Replays" # pyrefly: ignore if (cls.BASE / "maps").exists(): - cls.MAPS = cls.BASE / "maps" + cls.MAPS = cls.BASE / "maps" # pyrefly: ignore else: - cls.MAPS = cls.BASE / "Maps" + cls.MAPS = cls.BASE / "Maps" # pyrefly: ignore except FileNotFoundError as e: logger.critical(f"SC2 installation not found: File '{e.filename}' does not exist.") sys.exit(1) diff --git a/sc2/pixel_map.py b/sc2/pixel_map.py index c6925d80..4b4af8f0 100644 --- a/sc2/pixel_map.py +++ b/sc2/pixel_map.py @@ -6,7 +6,7 @@ import numpy as np from s2clientprotocol.common_pb2 import ImageData -from sc2.position import Point2 +from sc2.position import Point2, _PointLike class PixelMap: @@ -43,13 +43,14 @@ def bits_per_pixel(self) -> int: def bytes_per_pixel(self) -> int: return self._proto.bits_per_pixel // 8 - def __getitem__(self, pos: tuple[int, int]) -> int: + def __getitem__(self, pos: _PointLike) -> int: """Example usage: is_pathable = self._game_info.pathing_grid[Point2((20, 20))] != 0""" assert 0 <= pos[0] < self.width, f"x is {pos[0]}, self.width is {self.width}" assert 0 <= pos[1] < self.height, f"y is {pos[1]}, self.height is {self.height}" + # pyrefly: ignore return int(self.data_numpy[pos[1], pos[0]]) - def __setitem__(self, pos: tuple[int, int], value: int) -> None: + def __setitem__(self, pos: _PointLike, value: int) -> None: """Example usage: self._game_info.pathing_grid[Point2((20, 20))] = 255""" assert 0 <= pos[0] < self.width, f"x is {pos[0]}, self.width is {self.width}" assert 0 <= pos[1] < self.height, f"y is {pos[1]}, self.height is {self.height}" @@ -57,6 +58,7 @@ def __setitem__(self, pos: tuple[int, int], value: int) -> None: f"value is {value}, it should be between 0 and {254 * self._in_bits + 1}" ) assert isinstance(value, int), f"value is of type {type(value)}, it should be an integer" + # pyrefly: ignore self.data_numpy[pos[1], pos[0]] = value def is_set(self, p: tuple[int, int]) -> bool: @@ -70,7 +72,8 @@ def copy(self) -> PixelMap: def flood_fill(self, start_point: Point2, pred: Callable[[int], bool]) -> set[Point2]: nodes: set[Point2] = set() - queue: list[Point2] = [start_point] + # pyrefly: ignore + queue: list[tuple[int, int]] = [start_point] while queue: x, y = queue.pop() @@ -83,7 +86,7 @@ def flood_fill(self, start_point: Point2, pred: Callable[[int], bool]) -> set[Po if pred(self[x, y]): nodes.add(Point2((x, y))) - queue += [Point2((x + a, y + b)) for a in [-1, 0, 1] for b in [-1, 0, 1] if not (a == 0 and b == 0)] + queue += [(x + a, y + b) for a in [-1, 0, 1] for b in [-1, 0, 1] if not (a == 0 and b == 0)] return nodes def flood_fill_all(self, pred: Callable[[int], bool]) -> set[frozenset[Point2]]: diff --git a/sc2/player.py b/sc2/player.py index 646aabe8..e624d41c 100644 --- a/sc2/player.py +++ b/sc2/player.py @@ -67,7 +67,7 @@ def __init__(self, race: Race, ai: BotAI, name: str | None = None, fullscreen: b """ assert isinstance(ai, BotAI) or ai is None, f"ai is of type {type(ai)}, inherit BotAI from bot_ai.py" super().__init__(PlayerType.Participant, race, name=name, fullscreen=fullscreen) - self.ai = ai + self.ai: BotAI = ai def __str__(self) -> str: if self.name is not None: @@ -82,7 +82,9 @@ def __init__( super().__init__(PlayerType.Computer, race, difficulty=difficulty, ai_build=ai_build) def __str__(self) -> str: - return f"Computer {self.difficulty._name_}({self.race._name_}, {self.ai_build.name})" + if self.ai_build is not None: + return f"Computer {self.difficulty._name_}({self.race._name_}, {self.ai_build.name})" + return f"Computer {self.difficulty._name_}({self.race._name_})" class Observer(AbstractPlayer): @@ -98,7 +100,8 @@ def __init__( self, player_id: int, p_type: PlayerType, - requested_race: Race, + # None in case of observer + requested_race: Race | None, difficulty: Difficulty | None = None, actual_race: Race | None = None, name: str | None = None, diff --git a/sc2/portconfig.py b/sc2/portconfig.py index 9041b90f..ca022c16 100644 --- a/sc2/portconfig.py +++ b/sc2/portconfig.py @@ -26,20 +26,20 @@ class Portconfig: """ def __init__( - self, guests: int = 1, server_ports: list[int] | None = None, player_ports: list[int] | None = None + self, guests: int = 1, server_ports: list[int] | None = None, player_ports: list[list[int]] | None = None ) -> None: self.shared = None self._picked_ports: list[int] = [] if server_ports: - self.server = server_ports + self.server: list[int] = server_ports else: self.server = [portpicker.pick_unused_port() for _ in range(2)] self._picked_ports.extend(self.server) if player_ports: - self.players = player_ports + self.players: list[list[int]] = player_ports else: self.players = [[portpicker.pick_unused_port() for _ in range(2)] for _ in range(guests)] - self._picked_ports.extend(port for player in self.players for port in player) + self._picked_ports.extend([port for player in self.players for port in player]) def clean(self) -> None: while self._picked_ports: diff --git a/sc2/position.py b/sc2/position.py index f3d9bd70..d323f8ba 100644 --- a/sc2/position.py +++ b/sc2/position.py @@ -43,6 +43,7 @@ def distance_to(self, target: _PosLike) -> float: """Calculate a single distance from a point or unit to another point or unit :param target:""" + # pyrefly: ignore p: tuple[float, ...] = target if isinstance(target, tuple) else target.position return math.hypot(self[0] - p[0], self[1] - p[1]) @@ -82,6 +83,7 @@ def distance_to_closest(self, ps: Iterable[_TPosLike]) -> float: assert ps, "ps is empty" closest_distance = math.inf for p in ps: + # pyrefly: ignore p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position distance = self.distance_to_point2(p2) if distance <= closest_distance: @@ -103,6 +105,7 @@ def distance_to_furthest(self, ps: Iterable[_PosLike]) -> float: assert ps, "ps is empty" furthest_distance = -math.inf for p in ps: + # pyrefly: ignore p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position distance = self.distance_to_point2(p2) if distance >= furthest_distance: @@ -130,6 +133,7 @@ def towards(self: T, p: _PosLike, distance: float = 1, limit: bool = False) -> T :param distance: :param limit: """ + # pyrefly: ignore p2: tuple[float, ...] = p if isinstance(p, tuple) else p.position # assert self != p, f"self is {self}, p is {p}" # TODO test and fix this if statement @@ -284,7 +288,7 @@ def neighbors8(self: T) -> set[T]: def negative_offset(self: T, other: Point2) -> T: return self.__class__((self[0] - other[0], self[1] - other[1])) - def __add__(self, other: Point2) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + def __add__(self, other: Point2) -> Point2: return self.offset(other) def __sub__(self, other: Point2) -> Point2: @@ -299,12 +303,12 @@ def __abs__(self) -> float: def __bool__(self) -> bool: return self[0] != 0 or self[1] != 0 - def __mul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + def __mul__(self, other: _PointLike | float) -> Point2: if isinstance(other, (int, float)): return Point2((self[0] * other, self[1] * other)) return Point2((self[0] * other[0], self[1] * other[1])) - def __rmul__(self, other: _PointLike | float) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + def __rmul__(self, other: _PointLike | float) -> Point2: return self.__mul__(other) def __truediv__(self, other: float | Point2) -> Point2: @@ -338,7 +342,7 @@ def center(points: list[Point2]) -> Point2: class Point3(Point2): @classmethod - def from_proto(cls, data: common_pb.Point | Point3) -> Point3: # pyright: ignore[reportIncompatibleMethodOverride] + def from_proto(cls, data: common_pb.Point | Point3) -> Point3: """ :param data: """ @@ -387,7 +391,7 @@ def height(self) -> float: class Rect(Point2): @classmethod - def from_proto(cls, data: common_pb.RectangleI) -> Rect: # pyright: ignore[reportIncompatibleMethodOverride] + def from_proto(cls, data: common_pb.RectangleI) -> Rect: """ :param data: """ @@ -425,7 +429,7 @@ def size(self) -> Size: return Size((self[2], self[3])) @property - def center(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + def center(self) -> Point2: return Point2((self.x + self.width / 2, self.y + self.height / 2)) def offset(self, p: _PointLike) -> Rect: diff --git a/sc2/protocol.py b/sc2/protocol.py index 2722abe0..d2fae1aa 100644 --- a/sc2/protocol.py +++ b/sc2/protocol.py @@ -113,7 +113,7 @@ async def _execute(self, debug: sc_pb.RequestDebug) -> sc_pb.Response: ... async def _execute(self, **kwargs) -> sc_pb.Response: assert len(kwargs) == 1, "Only one request allowed by the API" - response = await self.__request(sc_pb.Request(**kwargs)) + response: sc_pb.Response = await self.__request(sc_pb.Request(**kwargs)) new_status = Status(response.status) if new_status != self._status: diff --git a/sc2/proxy.py b/sc2/proxy.py index c022360d..4ea563f5 100644 --- a/sc2/proxy.py +++ b/sc2/proxy.py @@ -4,6 +4,7 @@ import os import platform import subprocess +import sys import time import traceback from pathlib import Path @@ -58,7 +59,8 @@ async def parse_request(self, msg) -> None: elif self.controller._status == Status.ended: await self.get_response() elif request.HasField("join_game") and not request.join_game.HasField("player_name"): - request.join_game.player_name = self.player.name + if self.player.name is not None: + request.join_game.player_name = self.player.name await self.controller._ws.send_bytes(request.SerializeToString()) # TODO Catching too general exception Exception (broad-except) @@ -174,8 +176,9 @@ async def play_with_proxy(self, startport): subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT} if platform.system() == "Linux": + # pyrefly: ignore subproc_args["preexec_fn"] = os.setpgrp - elif platform.system() == "Windows": + elif platform.system() == "Windows" and sys.platform == "win32": subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime) diff --git a/sc2/renderer.py b/sc2/renderer.py index 17e3599e..933b1c14 100644 --- a/sc2/renderer.py +++ b/sc2/renderer.py @@ -1,4 +1,12 @@ from __future__ import annotations +from pyglet.image import ImageData + + +from pyglet.text import Label + + +from pyglet.window import Window + import datetime from typing import TYPE_CHECKING @@ -15,17 +23,17 @@ class Renderer: def __init__(self, client: Client, map_size: tuple[float, float], minimap_size: tuple[float, float]) -> None: self._client = client - self._window = None + self._window: Window = None # pyrefly: ignore self._map_size = map_size - self._map_image = None + self._map_image: ImageData = None # pyrefly: ignore self._minimap_size = minimap_size - self._minimap_image = None + self._minimap_image: ImageData = None # pyrefly: ignore self._mouse_x, self._mouse_y = None, None - self._text_supply = None - self._text_vespene = None - self._text_minerals = None - self._text_score = None - self._text_time = None + self._text_supply: Label = None # pyrefly: ignore + self._text_vespene: Label = None # pyrefly: ignore + self._text_minerals: Label = None # pyrefly: ignore + self._text_score: Label = None # pyrefly: ignore + self._text_time: Label = None # pyrefly: ignore async def render(self, observation: ResponseObservation) -> None: render_data = observation.observation.render_data @@ -47,11 +55,11 @@ async def render(self, observation: ResponseObservation) -> None: from pyglet.window import Window self._window = Window(width=map_width, height=map_height) - # pyre-fixme[16] + # pyrefly: ignore self._window.on_mouse_press = self._on_mouse_press - # pyre-fixme[16] + # pyrefly: ignore self._window.on_mouse_release = self._on_mouse_release - # pyre-fixme[16] + # pyrefly: ignore self._window.on_mouse_drag = self._on_mouse_drag self._map_image = ImageData(map_width, map_height, "RGB", map_data, map_pitch) self._minimap_image = ImageData(minimap_width, minimap_height, "RGB", minimap_data, minimap_pitch) @@ -114,6 +122,7 @@ async def render(self, observation: ResponseObservation) -> None: self._text_vespene.text = str(observation.observation.player_common.vespene) self._text_minerals.text = str(observation.observation.player_common.minerals) if observation.observation.HasField("score"): + # pyrefly: ignore self._text_score.text = f"{score_pb._SCORE_SCORETYPE.values_by_number[observation.observation.score.score_type].name} score: {observation.observation.score.score}" await self._update_window() diff --git a/sc2/sc2process.py b/sc2/sc2process.py index 391d30b1..0017c01d 100644 --- a/sc2/sc2process.py +++ b/sc2/sc2process.py @@ -132,9 +132,9 @@ def versions(self): def find_data_hash(self, target_sc2_version: str) -> str | None: """Returns the data hash from the matching version string.""" - version: dict for version in self.versions: if version["label"] == target_sc2_version: + # pyrefly: ignore return version["data-hash"] return None @@ -154,7 +154,7 @@ def _launch(self): else: executable = str(Paths.EXECUTABLE) - if self._port is None: + if self._port == -1: self._port = portpicker.pick_unused_port() self._used_portpicker = True args = paths.get_runner_args(Paths.CWD) + [ @@ -222,6 +222,7 @@ async def _connect(self) -> ClientWebSocketResponse: await asyncio.sleep(1) try: self._session = aiohttp.ClientSession() + # pyrefly: ignore ws = await self._session.ws_connect(self.ws_url, timeout=120) # FIXME fix deprecation warning in for future aiohttp version # ws = await self._session.ws_connect( @@ -229,8 +230,9 @@ async def _connect(self) -> ClientWebSocketResponse: # ) logger.debug("Websocket connection ready") return ws - except aiohttp.client_exceptions.ClientConnectorError: - await self._session.close() + except aiohttp.ClientConnectorError: + if self._session is not None: + await self._session.close() if i > 15: logger.debug("Connection refused (startup not complete (yet))") @@ -266,6 +268,7 @@ def _clean(self, verbose: bool = True) -> None: self._process.wait() logger.error("KILLED") # Try to kill wineserver on linux + # pyrefly: ignore if paths.PF in {"Linux", "WineLinux"}: # Command wineserver not detected with suppress(FileNotFoundError), subprocess.Popen(["wineserver", "-k"]) as p: @@ -278,6 +281,6 @@ def _clean(self, verbose: bool = True) -> None: self._ws = None if self._used_portpicker and self._port is not None: portpicker.return_port(self._port) - self._port = None + self._port = -1 if verbose: logger.info("Cleanup complete") diff --git a/sc2/score.py b/sc2/score.py index aba9c8ff..82e8a823 100644 --- a/sc2/score.py +++ b/sc2/score.py @@ -13,7 +13,7 @@ def __init__(self, proto: score_pb2.Score) -> None: self._proto = proto.score_details @property - def summary(self) -> list[list[int | float]]: + def summary(self) -> list[list[float]]: """ TODO this is super ugly, how can we improve this summary? Print summary to file with: @@ -105,6 +105,7 @@ def summary(self) -> list[list[int | float]]: "current_apm", "current_effective_apm", ] + # pyrefly: ignore return [[value, getattr(self, value)] for value in values] @property diff --git a/sc2/unit.py b/sc2/unit.py index 2153decc..d2caaa5f 100644 --- a/sc2/unit.py +++ b/sc2/unit.py @@ -52,7 +52,7 @@ UNIT_PHOTONCANNON, transforming, ) -from sc2.data import Alliance, Attribute, CloakState, Race, Target, race_gas, warpgate_abilities +from sc2.data import Attribute, CloakState, Race, Target, race_gas, warpgate_abilities from sc2.ids.ability_id import AbilityId from sc2.ids.buff_id import BuffId from sc2.ids.unit_typeid import UnitTypeId @@ -286,7 +286,7 @@ def air_range(self) -> float: return 0 @cached_property - def bonus_damage(self) -> tuple[int, str] | None: + def bonus_damage(self) -> tuple[float, str] | None: """Returns a tuple of form '(bonus damage, armor type)' if unit does 'bonus damage' against 'armor type'. Possible armor typs are: 'Light', 'Armored', 'Biological', 'Mechanical', 'Psionic', 'Massive', 'Structure'.""" # TODO: Consider units with ability attacks (Oracle, Baneling) or multiple attacks (Thor). @@ -471,7 +471,8 @@ def is_snapshot(self) -> bool: if self.base_build >= 82457: return self._proto.display_type == IS_SNAPSHOT # TODO: Fixed in version 5.0.4, remove if a new linux binary is released: https://github.com/Blizzard/s2client-proto/issues/167 - position = self.position.rounded + # pyrefly: ignore + position: tuple[int, int] = self.position.rounded return self._bot_object.state.visibility.data_numpy[position[1], position[0]] != 2 @cached_property @@ -503,7 +504,7 @@ def is_placeholder(self) -> bool: return self._proto.display_type == IS_PLACEHOLDER @property - def alliance(self) -> Alliance: + def alliance(self) -> int: """Returns the team the unit belongs to.""" return self._proto.alliance @@ -528,7 +529,7 @@ def position_tuple(self) -> tuple[float, float]: return self._proto.pos.x, self._proto.pos.y @cached_property - def position(self) -> Point2: # pyright: ignore[reportIncompatibleMethodOverride] + def position(self) -> Point2: """Returns the 2d position of the unit.""" return Point2.from_proto(self._proto.pos) @@ -593,9 +594,7 @@ def in_ability_cast_range(self, ability_id: AbilityId, target: Unit | Point2, bo <= (cast_range + self.radius + target.radius + bonus_distance) ** 2 ) # For casting abilities on the ground, like queen creep tumor, ravager bile, HT storm - if ability_target_type in {Target.Point.value, Target.PointOrUnit.value} and isinstance( - target, (Point2, tuple) - ): + if ability_target_type in {Target.Point.value, Target.PointOrUnit.value} and isinstance(target, Point2): return ( self._bot_object._distance_pos_to_pos(self.position_tuple, target) <= cast_range + self.radius + bonus_distance @@ -1027,6 +1026,7 @@ def buff_duration_max(self) -> int: def orders(self) -> list[UnitOrder]: """Returns the a list of the current orders.""" # TODO: add examples on how to use unit orders + # pyrefly: ignore return [UnitOrder.from_proto(order, self._bot_object) for order in self._proto.orders] @cached_property @@ -1153,6 +1153,7 @@ def add_on_position(self) -> Point2: @cached_property def passengers(self) -> set[Unit]: """Returns the units inside a Bunker, CommandCenter, PlanetaryFortress, Medivac, Nydus, Overlord or WarpPrism.""" + # pyrefly: ignore return {Unit(unit, self._bot_object) for unit in self._proto.passengers} @cached_property @@ -1258,8 +1259,12 @@ def train( :param queue: :param can_afford_check: """ + creation_ability = self._bot_object.game_data.units[unit.value].creation_ability + if creation_ability is None: + return False + return self( - self._bot_object.game_data.units[unit.value].creation_ability.id, + creation_ability.id, queue=queue, subtract_cost=True, can_afford_check=can_afford_check, @@ -1289,8 +1294,11 @@ def build( assert isinstance(position, Unit), ( "When building the gas structure, the target needs to be a unit (the vespene geysir) not the position of the vespene geysir." ) + creation_ability = self._bot_object.game_data.units[unit.value].creation_ability + if creation_ability is None: + return False return self( - self._bot_object.game_data.units[unit.value].creation_ability.id, + creation_ability.id, target=position, queue=queue, subtract_cost=True, @@ -1317,8 +1325,11 @@ def build_gas( assert isinstance(target_geysir, Unit), ( "When building the gas structure, the target needs to be a unit (the vespene geysir) not the position of the vespene geysir." ) + creation_ability = self._bot_object.game_data.units[gas_structure_type_id.value].creation_ability + if creation_ability is None: + return False return self( - self._bot_object.game_data.units[gas_structure_type_id.value].creation_ability.id, + creation_ability.id, target=target_geysir, queue=queue, subtract_cost=True, @@ -1338,8 +1349,11 @@ def research( :param queue: :param can_afford_check: """ + research_ability = self._bot_object.game_data.upgrades[upgrade.value].research_ability + if research_ability is None: + return False return self( - self._bot_object.game_data.upgrades[upgrade.value].research_ability.exact_id, + research_ability.exact_id, queue=queue, subtract_cost=True, can_afford_check=can_afford_check, @@ -1360,9 +1374,8 @@ def warp_in( creation_ability = self._bot_object.game_data.units[unit.value].creation_ability if creation_ability is None: return False - normal_creation_ability = creation_ability.id return self( - warpgate_abilities[normal_creation_ability], + warpgate_abilities[creation_ability.id], target=position, subtract_cost=True, subtract_supply=True, diff --git a/sc2/versions.py b/sc2/versions.py index 96c2f35f..a242b1d8 100644 --- a/sc2/versions.py +++ b/sc2/versions.py @@ -1,4 +1,4 @@ -VERSIONS = [ +VERSIONS: list[dict[str, int | str]] = [ { "base-version": 52910, "data-hash": "8D9FEF2E1CF7C6C9CBE4FBCA830DDE1C", From d3f6fef13bacf59d14d2854b0e094512597791a0 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Fri, 2 Jan 2026 16:13:43 +0100 Subject: [PATCH 4/5] Fix pyglet ImageData import --- sc2/renderer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sc2/renderer.py b/sc2/renderer.py index 933b1c14..68ae1d29 100644 --- a/sc2/renderer.py +++ b/sc2/renderer.py @@ -1,5 +1,4 @@ from __future__ import annotations -from pyglet.image import ImageData from pyglet.text import Label @@ -17,6 +16,7 @@ if TYPE_CHECKING: from sc2.client import Client + from pyglet.image import ImageData class Renderer: From dd9e9dbdbd34a3d3616fa209560b529a9ff7bbe6 Mon Sep 17 00:00:00 2001 From: burnysc2 Date: Fri, 2 Jan 2026 16:17:44 +0100 Subject: [PATCH 5/5] Move renderer pyglet imports into TYPECHECKING check --- sc2/renderer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sc2/renderer.py b/sc2/renderer.py index 68ae1d29..7b7d1165 100644 --- a/sc2/renderer.py +++ b/sc2/renderer.py @@ -1,12 +1,6 @@ from __future__ import annotations -from pyglet.text import Label - - -from pyglet.window import Window - - import datetime from typing import TYPE_CHECKING @@ -17,6 +11,8 @@ if TYPE_CHECKING: from sc2.client import Client from pyglet.image import ImageData + from pyglet.text import Label + from pyglet.window import Window class Renderer: